import math
from typing import Literal, Union, Iterable, Tuple

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# wavelength bounds for the sinusoidal m/z embedding (1e-3 to 1e3)
LAMBDA_MIN = math.pow(10, -3.0)
LAMBDA_MAX = math.pow(10, 3.0)


class MultiFeedForwardModule(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: Union[int, Iterable[int], None],
        output_size: int,
        *,
        activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super().__init__()
        if activation == 'relu':
            self._activation = nn.ReLU()
        elif activation == 'selu':
            self._activation = nn.SELU()
        elif activation == 'gelu':
            self._activation = nn.GELU()
        else:
            raise ValueError('activation must be relu, selu, or gelu')
        # accept a single int, an iterable of ints, or None
        if hidden_size is None:
            # default to one hidden layer of size output_size
            hidden_size = [output_size]
        elif not hasattr(hidden_size, '__iter__'):
            hidden_size = [hidden_size]
        self._layers = []
        layer_dims = [input_size] + list(hidden_size) + [output_size]
        # hidden blocks: Linear -> activation -> Dropout
        for i in range(1, len(layer_dims) - 1):
            self._layers.append(nn.Linear(layer_dims[i - 1], layer_dims[i]))
            self._layers.append(self._activation)
            self._layers.append(nn.Dropout(dropout))
        # output projection, optionally followed by dropout
        self._layers.append(nn.Linear(layer_dims[-2], layer_dims[-1]))
        if dropout_last_layer:
            self._layers.append(nn.Dropout(dropout))
        self._layers = nn.Sequential(*self._layers)

    def forward(self, x):
        return self._layers(x)
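
# Example (sketch, illustrative sizes): MultiFeedForwardModule(128, [256, 256], 64)
# builds Linear(128, 256) -> act -> Dropout -> Linear(256, 256) -> act -> Dropout
# -> Linear(256, 64), plus a trailing Dropout when dropout_last_layer=True.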


class SinusodialMz(nn.Module):
    def __init__(self, embedding_dim: int, *, lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX)) -> None:
        super().__init__()
        self.lambda_min, self.lambda_max = lambda_params
        self.lambda_div_value = self.lambda_max / self.lambda_min
        # angular frequencies 2*pi/lambda for wavelengths log-spaced between
        # lambda_min and lambda_max, one per sin/cos pair
        x = torch.arange(0, embedding_dim, 2)
        x = (
            2 * math.pi *
            (
                self.lambda_min *
                self.lambda_div_value ** (x / (embedding_dim - 2))
            ) ** -1
        )
        # register as a buffer so the tensor follows the module across devices
        self.register_buffer('x', x)

    def forward(self, mz: torch.Tensor):
        freq = self.x.to(mz.device)
        # one phase per (batch, peak, frequency)
        x = torch.einsum('bl,d->bld', mz, freq)
        sin_embedding = torch.sin(x)
        cos_embedding = torch.cos(x)
        # interleave: sin at even indices, cos at odd indices
        b, l, d = sin_embedding.shape
        x = torch.zeros(b, l, 2 * d, dtype=mz.dtype, device=mz.device)
        x[:, :, ::2] = sin_embedding
        x[:, :, 1::2] = cos_embedding
        return x
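
# Example (sketch): for mz of shape (batch, n_peaks) and embedding_dim=64,
# SinusodialMz.forward returns (batch, n_peaks, 64): 32 log-spaced frequencies,
# each contributing one sin channel (even index) and one cos channel (odd index).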


class SinusodialMzEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError('embedding_dim must be even')
        self.embedding = SinusodialMz(
            embedding_dim, lambda_params=lambda_params)
        self.feedward_layers = MultiFeedForwardModule(
            embedding_dim, embedding_dim, embedding_dim,
            activation=feedward_activation, dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor):
        x = self.embedding(mz)
        x = self.feedward_layers(x)
        return x


class PeaksEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = False
    ) -> None:
        super().__init__()
        self.mz_embedding = SinusodialMzEmbedding(
            embedding_dim,
            lambda_params=lambda_params,
            feedward_activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )
        # input size is embedding_dim + 1: the m/z embedding concatenated
        # with the raw intensity value as one extra channel
        self.intensity_embedding = MultiFeedForwardModule(
            embedding_dim + 1, embedding_dim, embedding_dim,
            activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor, intensity: torch.Tensor):
        mz_tensor = self.mz_embedding(mz)
        intensity_tensor = torch.unsqueeze(intensity, dim=-1)
        x = self.intensity_embedding(
            torch.cat([mz_tensor, intensity_tensor], dim=-1))
        return x
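
# Example (sketch): mz and intensity of shape (batch, n_peaks) yield a fused
# per-peak embedding of shape (batch, n_peaks, embedding_dim).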


class SiameseModel(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        n_head: int,
        n_layer: int,
        dim_feedward: int,
        dim_target: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = False,
        norm_first: bool = True
    ) -> None:
        super().__init__()
        if embedding_dim % n_head != 0:
            raise ValueError('embedding_dim must be divisible by n_head')
        self.embedding = PeaksEmbedding(
            embedding_dim,
            lambda_params=lambda_params,
            feedward_activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )
        if feedward_activation == 'selu':
            # TransformerEncoderLayer only supports relu or gelu,
            # so fall back to gelu inside the encoder
            self.activation = 'gelu'
        else:
            self.activation = feedward_activation
        if feedward_activation == 'relu':
            self._activation = nn.ReLU()
        elif feedward_activation == 'selu':
            self._activation = nn.SELU()
        elif feedward_activation == 'gelu':
            self._activation = nn.GELU()
        else:
            raise ValueError('activation must be relu, selu, or gelu')
        encoder_layer = TransformerEncoderLayer(
            embedding_dim,
            n_head,
            dim_feedforward=dim_feedward,
            dropout=dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=norm_first
        )
        self._encoder = TransformerEncoder(
            encoder_layer,
            n_layer,
            enable_nested_tensor=False
        )
        self._decoder = MultiFeedForwardModule(
            embedding_dim,
            dim_feedward,
            dim_target,
            activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
        x = self.embedding(mz, intensity)
        x = self._encoder(x, src_key_padding_mask=mask)
        # masked mean pooling over non-padded positions (mask is True at
        # padding, per the src_key_padding_mask convention); a CLS position
        # vector would be an alternative
        valid = (~mask).unsqueeze(-1).to(x.dtype)
        x = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)
        x = self._activation(self._decoder(x))
        return x


# Experimental Mamba-based variant, kept commented out. It would additionally
# require the Mamba2 block (e.g. `from mamba_ssm import Mamba2`).
# class MambaSiameseModel(nn.Module):
#     def __init__(
#         self,
#         embedding_dim: int,
#         n_layer: int,
#         dim_feedward: int,
#         dim_target: int,
#         *,
#         lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
#         feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
#         dropout: float = 0.1,
#         dropout_last_layer: bool = False,
#     ):
#         super().__init__()
#         self.embedding = PeaksEmbedding(
#             embedding_dim,
#             lambda_params=lambda_params,
#             feedward_activation=feedward_activation,
#             dropout=dropout,
#             dropout_last_layer=dropout_last_layer
#         )
#         if feedward_activation == 'relu':
#             self._activation = nn.ReLU()
#         elif feedward_activation == 'selu':
#             self._activation = nn.SELU()
#         elif feedward_activation == 'gelu':
#             self._activation = nn.GELU()
#         else:
#             raise ValueError('activation must be relu, selu, or gelu')
#         self._encoder = nn.Sequential(*[
#             Mamba2(
#                 d_model=embedding_dim,
#                 d_state=64,
#                 d_conv=4,
#                 expand=2
#             )
#             for _ in range(n_layer)
#         ])
#         self._decoder = MultiFeedForwardModule(
#             embedding_dim,
#             dim_feedward,
#             dim_target,
#             activation=feedward_activation,
#             dropout=dropout,
#             dropout_last_layer=dropout_last_layer
#         )
#
#     def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
#         x = self.embedding(mz, intensity)
#         x = self._encoder(x)
#         # mean pooling or cls position vector
#         x = torch.mean(x, dim=1)
#         x = self._activation(self._decoder(x))
#         return x
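

# Minimal smoke test (sketch): the sizes below are illustrative assumptions,
# not values from any training configuration.
if __name__ == '__main__':
    batch, n_peaks = 4, 32
    model = SiameseModel(
        embedding_dim=64,
        n_head=4,
        n_layer=2,
        dim_feedward=128,
        dim_target=16,
    )
    mz = torch.rand(batch, n_peaks) * 2000.0               # m/z values
    intensity = torch.rand(batch, n_peaks)                 # normalized intensities
    mask = torch.zeros(batch, n_peaks, dtype=torch.bool)   # True marks padding
    mask[:, n_peaks // 2:] = True                          # pad out the second half
    out = model(mz, intensity, mask)
    print(out.shape)  # expected: torch.Size([4, 16])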