# SpecEmbedding/src/model.py
import math
from typing import Literal, Union, Iterable, Tuple

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

LAMBDA_MIN = math.pow(10, -3.0)
LAMBDA_MAX = math.pow(10, 3.0)


class MultiFeedForwardModule(nn.Module):
    """Stack of Linear -> activation -> Dropout blocks followed by a final Linear."""

    def __init__(
        self,
        input_size: int,
        hidden_size: Union[int, Iterable[int], None],
        output_size: int,
        *,
        activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super(MultiFeedForwardModule, self).__init__()
        if activation == 'relu':
            self._activation = nn.ReLU()
        elif activation == 'selu':
            self._activation = nn.SELU()
        elif activation == 'gelu':
            self._activation = nn.GELU()
        else:
            raise ValueError('activation must be relu, selu or gelu')
        if not hasattr(hidden_size, '__iter__'):
            if hidden_size is None:
                hidden_size = [output_size]
            else:
                hidden_size = [hidden_size]
        else:
            hidden_size = list(hidden_size)
        layers = []
        layer_dims = [input_size] + hidden_size + [output_size]
        # hidden blocks: Linear -> activation -> Dropout
        for i in range(1, len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i - 1], layer_dims[i]))
            layers.append(self._activation)
            layers.append(nn.Dropout(dropout))
        # output projection, optionally followed by dropout
        layers.append(nn.Linear(layer_dims[-2], layer_dims[-1]))
        if dropout_last_layer:
            layers.append(nn.Dropout(dropout))
        self._layers = nn.Sequential(*layers)

    def forward(self, x):
        return self._layers(x)
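
# Illustrative sketch (sizes below are assumptions, not from the original file): with
# hidden_size=[64, 32] the module builds
#     Linear(in, 64) -> act -> Dropout -> Linear(64, 32) -> act -> Dropout -> Linear(32, out) [-> Dropout],
# so MultiFeedForwardModule(16, [64, 32], 8)(torch.randn(4, 16)) returns a (4, 8) tensor.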


class SinusodialMz(nn.Module):
    """Sinusoidal encoding of m/z values with geometrically spaced wavelengths."""

    def __init__(self, embedding_dim: int, *, lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX)) -> None:
        super(SinusodialMz, self).__init__()
        self.lambda_min, self.lambda_max = lambda_params
        self.lambda_div_value = self.lambda_max / self.lambda_min
        # angular frequencies 2*pi / lambda_k for the sin/cos channel pairs,
        # with lambda_k spaced geometrically between lambda_min and lambda_max
        self.x = torch.arange(0, embedding_dim, 2)
        self.x = (
            2 * math.pi *
            (
                self.lambda_min *
                self.lambda_div_value ** (self.x / (embedding_dim - 2))
            ) ** -1
        )

    def forward(self, mz: torch.Tensor):
        self.x = self.x.to(mz.device)
        # outer product over the frequency axis: (batch, length, embedding_dim // 2)
        x = torch.einsum('bl,d->bld', mz, self.x)
        sin_embedding = torch.sin(x)
        cos_embedding = torch.cos(x)
        b, l, d = sin_embedding.shape
        # interleave sin and cos channels into the final embedding
        x = torch.zeros(b, l, 2 * d, dtype=mz.dtype, device=mz.device)
        x[:, :, ::2] = sin_embedding
        x[:, :, 1::2] = cos_embedding
        return x
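
# Illustrative note (not part of the original file): SinusodialMz maps each m/z value m
# to the interleaved vector
#     [sin(2*pi*m / lambda_0), cos(2*pi*m / lambda_0), sin(2*pi*m / lambda_1), cos(2*pi*m / lambda_1), ...],
# where lambda_k = LAMBDA_MIN * (LAMBDA_MAX / LAMBDA_MIN) ** (2k / (embedding_dim - 2)),
# i.e. the wavelengths run geometrically from 1e-3 up to 1e3 across the channel pairs.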


class SinusodialMzEmbedding(nn.Module):
    """Sinusoidal m/z encoding refined by a feed-forward projection."""

    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super(SinusodialMzEmbedding, self).__init__()
        if embedding_dim % 2 != 0:
            raise ValueError('embedding_dim must be even')
        self.embedding = SinusodialMz(
            embedding_dim, lambda_params=lambda_params)
        self.feedward_layers = MultiFeedForwardModule(
            embedding_dim, embedding_dim, embedding_dim,
            activation=feedward_activation, dropout=dropout, dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor):
        x = self.embedding(mz)
        x = self.feedward_layers(x)
        return x


class PeaksEmbedding(nn.Module):
    """Joint embedding of m/z and intensity for each peak."""

    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = False
    ) -> None:
        super(PeaksEmbedding, self).__init__()
        self.mz_embedding = SinusodialMzEmbedding(
            embedding_dim,
            lambda_params=lambda_params,
            feedward_activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )
        self.intensity_embedding = MultiFeedForwardModule(
            embedding_dim + 1, embedding_dim, embedding_dim,
            activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor, intensity: torch.Tensor):
        mz_tensor = self.mz_embedding(mz)
        # append the raw intensity as one extra channel per peak
        intensity_tensor = torch.unsqueeze(intensity, dim=-1)
        x = self.intensity_embedding(
            torch.cat([mz_tensor, intensity_tensor], dim=-1))
        return x
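
# Illustrative note (shapes are assumptions, not stated in the original file): given
#     mz, intensity: (batch, n_peaks) float tensors,
# PeaksEmbedding first turns mz into a (batch, n_peaks, embedding_dim) sinusoidal
# embedding, concatenates the raw intensity as one extra channel, and projects the
# (embedding_dim + 1)-wide vectors back to embedding_dim with the feed-forward stack.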


class SiameseModel(nn.Module):
    """Transformer encoder over embedded peaks, mean-pooled into a spectrum embedding."""

    def __init__(
        self,
        embedding_dim: int,
        n_head: int,
        n_layer: int,
        dim_feedward: int,
        dim_target: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = False,
        norm_first: bool = True
    ) -> None:
        super(SiameseModel, self).__init__()
        if embedding_dim % n_head != 0:
            raise ValueError('embedding_dim must be divisible by n_head')
        self.embedding = PeaksEmbedding(
            embedding_dim,
            lambda_params=lambda_params,
            feedward_activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )
        if feedward_activation == 'selu':
            # the transformer encoder layer only supports relu or gelu,
            # so selu falls back to gelu inside the encoder
            self.activation = 'gelu'
        else:
            self.activation = feedward_activation
        if feedward_activation == 'relu':
            self._activation = nn.ReLU()
        elif feedward_activation == 'selu':
            self._activation = nn.SELU()
        elif feedward_activation == 'gelu':
            self._activation = nn.GELU()
        else:
            raise ValueError('activation must be relu, selu or gelu')
        encoder_layer = TransformerEncoderLayer(
            embedding_dim,
            n_head,
            dim_feedforward=dim_feedward,
            dropout=dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=norm_first
        )
        self._encoder = TransformerEncoder(
            encoder_layer,
            n_layer,
            enable_nested_tensor=False
        )
        self._decoder = MultiFeedForwardModule(
            embedding_dim,
            dim_feedward,
            dim_target,
            activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
        x = self.embedding(mz, intensity)
        x = self._encoder(x, src_key_padding_mask=mask)
        # mean pooling or cls position vector
        x = torch.mean(x, dim=1)
        x = self._activation(self._decoder(x))
        return x
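
# Illustrative note (not part of the original file): `mask` is passed as
# `src_key_padding_mask`, so in PyTorch's convention a True entry marks a padded peak
# that the attention layers ignore; the final embedding is the mean over all positions
# of the encoder output, projected to dim_target through the decoder and activation.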


# class MambaSiameseModel(nn.Module):
#     def __init__(
#         self,
#         embedding_dim: int,
#         n_layer: int,
#         dim_feedward: int,
#         dim_target: int,
#         *,
#         lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
#         feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
#         dropout: float = 0.1,
#         dropout_last_layer: bool = False,
#     ):
#         super(MambaSiameseModel, self).__init__()
#         self.embedding = PeaksEmbedding(
#             embedding_dim,
#             lambda_params=lambda_params,
#             feedward_activation=feedward_activation,
#             dropout=dropout,
#             dropout_last_layer=dropout_last_layer
#         )
#         if feedward_activation == 'relu':
#             self._activation = nn.ReLU()
#         elif feedward_activation == 'selu':
#             self._activation = nn.SELU()
#         elif feedward_activation == 'gelu':
#             self._activation = nn.GELU()
#         else:
#             raise ValueError('activation must be relu or selu or gelu')
#         self._encoder = nn.Sequential(*[
#             Mamba2(
#                 d_model=embedding_dim,
#                 d_state=64,
#                 d_conv=4,
#                 expand=2
#             )
#             for _ in range(n_layer)
#         ])
#         self._decoder = MultiFeedForwardModule(
#             embedding_dim,
#             dim_feedward,
#             dim_target,
#             activation=feedward_activation,
#             dropout=dropout,
#             dropout_last_layer=dropout_last_layer
#         )
#
#     def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
#         x = self.embedding(mz, intensity)
#         x = self._encoder(x)
#         # mean pooling or cls position vector
#         x = torch.mean(x, dim=1)
#         x = self._activation(self._decoder(x))
#         return x
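

# Minimal smoke test (a hedged sketch, not part of the original repository): the
# hyperparameters and tensor shapes below are illustrative assumptions only.
if __name__ == '__main__':
    model = SiameseModel(
        embedding_dim=64,
        n_head=4,
        n_layer=2,
        dim_feedward=128,
        dim_target=32,
    )
    batch, n_peaks = 2, 10
    mz = torch.rand(batch, n_peaks) * 1000.0               # m/z values
    intensity = torch.rand(batch, n_peaks)                  # peak intensities
    mask = torch.zeros(batch, n_peaks, dtype=torch.bool)    # False = real peak, True = padding
    out = model(mz, intensity, mask)
    print(out.shape)  # expected: torch.Size([2, 32])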