| | import logging |
| | import math |
| | from dataclasses import dataclass |
| | from typing import Any, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| |
|
| |
|
@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
    to adapt the rotary embeddings to larger lengths than what was used for training.
    One of this strategy is presented in the Yarn paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa

    Args:
        rescaling_factor: length-extension rescaling factor; ``None`` disables
            rescaling and uses the vanilla RoPE frequencies.
    """

    rescaling_factor: Optional[float]
|
| |
|
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int, rotary_embedding_config: "RotaryEmbeddingConfig"):
        """
        Args:
            dim: per-head dimension to rotate (split into two halves, so it
                must be even).
            rotary_embedding_config: carries the optional rescaling factor.
        """
        super().__init__()

        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        # The inverse frequencies depend only on `dim` and the rescaling
        # factor, both fixed at construction time — compute them once here
        # instead of on every forward pass (the original recomputed them
        # per call, always on CPU).
        if self.rescaling_factor is None:
            inv_freq = 1.0 / (
                self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
            )
        else:
            # NTK-style base rescaling (cf. https://arxiv.org/pdf/2309.00071.pdf).
            updated_base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
            inv_freq = 1.0 / (
                updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
            )
        # Non-persistent buffer: follows the module across .to()/.cuda() but is
        # not stored in the state dict (it is fully derived from `dim`).
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._seq_len_cached: Optional[int] = None
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate `heads` by splitting the last dim in halves (RoFormer style)."""
        x_first, x_second = (
            heads[..., : heads.shape[-1] // 2],
            heads[..., heads.shape[-1] // 2 :],
        )

        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin

        return torch.cat((first_part, second_part), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build (and cache) the cos/sin tables for x's sequence length."""
        seq_len = x.shape[seq_dimension]
        # Reuse the cached tables when both the sequence length and the device
        # match (the original set the cache fields but never consulted them).
        if (
            seq_len != self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != x.device
        ):
            self._seq_len_cached = seq_len
            # Keep everything on x's device; the original mixed a CPU
            # `inv_freq` with a `x.device` time index inside the einsum,
            # which fails on GPU inputs.
            inv_freq = inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            freqs = torch.einsum("i, j -> ij", t, inv_freq)

            # Shape (1, seq_len, 1, dim // 2): broadcasts over batch and heads.
            self._cos_cached = torch.cos(freqs)[None, :, None, :]
            self._sin_cached = torch.sin(freqs)[None, :, None, :]
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q, k: head tensors whose dim -3 is the sequence axis and whose last
                axis has size `dim`.

        Returns:
            The rotated (q, k) pair, same shapes as the inputs.
        """
        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            self.inv_freq,
            seq_dimension=-3,
        )

        return (
            self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
| |
|
| |
|
class ResidualConvBlock(nn.Module):
    """Conv block wrapped with a residual (skip) connection."""

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_block(x)
        # The skip path is reshaped to the conv output's shape before the add
        # (a no-op when the shapes already agree).
        return out + x.reshape(out.shape)
|
| |
|
class ConvBlock(nn.Module):
    """
    Pre-norm 1D convolution block: LayerNorm over channels, then Conv1d,
    then tanh-approximated GELU. Input is channels-first, (batch, dim_in, length).
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last for LayerNorm, then restore channels-first
        # before the convolution.
        normed = self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
        return F.gelu(self.conv(normed), approximate="tanh")
|
| |
|
class ConvTowerBlock(nn.Module):
    """
    One downsampling stage: ConvBlock -> residual conv block -> 2x average pool.

    `forward` also returns its unmodified input, later consumed as a skip
    connection by the deconvolution tower.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        kernel_size: int,
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        skip = x
        h = self.res_conv(self.conv_layer(x))
        return self.avg_pool(h), skip
| |
|
| |
|
class ResidualDeConvBlock(nn.Module):
    """Deconv block wrapped with a residual (skip) connection."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.deconv_block(x)
        # Skip path is reshaped to match the deconv output before the add
        # (a no-op when the shapes already agree).
        return out + x.reshape(out.shape)
| |
|
| |
|
class DeConvBlock(nn.Module):
    """
    Pre-norm transposed-convolution block: LayerNorm over channels, then
    ConvTranspose1d, then tanh-approximated GELU. Input is channels-first.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channels last for LayerNorm, channels first for the deconvolution.
        normed = self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
        out = self.deconv(normed)
        if self.kernel_size == 5:
            # Trim the transposed-conv overhang: with stride 2 and kernel 5 the
            # output length is 2L + 3, so dropping 1 left / 2 right restores 2L.
            out = out[:, :, 1:-2]
        return F.gelu(out, approximate="tanh")
| |
|
| |
|
class DeConvTowerBlock(nn.Module):
    """
    One upsampling stage: stride-2 DeConvBlock -> residual deconv block ->
    addition of the matching encoder skip tensor.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        stride: int = 2,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        h = self.deconv_block(x)
        h = self.res_deconv_block(h)
        return h + res
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional rotary position
    embeddings applied to queries and keys.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            num_heads: number of attention heads.
            key_size: per-head query/key dimension.
            rotary_embedding_config: if given, rotary embeddings are applied to
                queries and keys.
            add_bias_kv: stored only; not used by this implementation.
            value_size: per-head value dimension (defaults to `key_size`).
            model_size: input/output embedding width (defaults to `key_size`).
            name: optional layer name (metadata only).
        """
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to the query and key heads."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            query, key, value: (..., seq, model_size) activations.
            attention_mask: boolean tensor broadcastable to the attention
                weights; positions where it is False are masked out.
            attention_weight_bias: additive bias applied before the softmax.

        Returns:
            dictionary containing attention weights
            and outputs.
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        # (..., heads, q_len, k_len)
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size
        # Fix: `if attention_mask:` raises "Boolean value of Tensor ... is
        # ambiguous" for multi-element tensors; test for presence explicitly.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = F.softmax(
                attention_weights + attention_weight_bias, dim=-1
            )
        else:
            attention_weights = F.softmax(attention_weights, dim=-1)
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge heads back into a single embedding dimension.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}
| |
|
| |
|
class SelfAttentionBlock(nn.Module):
    """
    Transformer block: multi-head self-attention followed by an MLP (optionally
    GLU-gated), each wrapped in a residual connection with pre- or post-LayerNorm.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        if use_glu_in_ffn:
            # GLU variant: fc1 produces both the gate and the value halves in
            # one projection (split again in `mlp`), hence the doubled width.
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)

        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        # Fix: `layer_norm_eps` was accepted but never forwarded to the norms
        # (behavior-identical for callers passing the default 1e-5).
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            self._ffn_activation_fn = lambda x: F.gelu(x, approximate="none")
        else:
            # Fix: instantiate the activation module; the original stored the
            # class itself, which cannot be applied to a tensor.
            self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block (pre- or post-norm, optional GLU gating)."""
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            # Post-norm: the residual is added inside the final norm here,
            # so `forward` must not add it again.
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Returns:
            The MHA output dict with "embeddings" replaced by the block output
            (annotation fixed: the original claimed a bare Tensor).
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
        else:
            x = output["embeddings"]
            x = res + x

        # Post-norm: `mlp` already applies the residual inside its final norm.
        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output
| |
|
| |
|
class LMHead(nn.Module):
    """
    Language-model head: GELU-activated stack of linear layers mapping
    `dim_in` -> `embed_dim` (x num_hidden_layers) -> `dim_out` logits.
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        """
        Args:
            dim_in: input feature width.
            embed_dim: hidden width of the intermediate layers.
            dim_out: output (vocabulary) width.
            num_hidden_layers: total number of hidden linear layers (>= 1).
        """
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        # First layer maps dim_in -> embed_dim; the remaining hidden layers map
        # embed_dim -> embed_dim. Fix: the original extended the ModuleList
        # with one-element *lists* of modules, which raises TypeError whenever
        # num_hidden_layers > 1.
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        self.linear_layers.extend(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GELU before the first layer and after every hidden layer.
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = layer(x)
            x = F.gelu(x, approximate="tanh")
        out = self.linear_out(x)
        return out
| |
|
| |
|
class MOJOConfig(PretrainedConfig):
    """Configuration for the MOJO multi-omics model."""

    model_type = "MOJO"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Per-omic vocabulary sizes.
        self.alphabet_size = kwargs.get(
            "alphabet_size", {"rnaseq": 66, "methylation": 66}
        )
        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.sequence_length = kwargs.get("sequence_length", 17_116)
        # NOTE: recomputed in __post_init__ from sequence_length and
        # num_downsamples; any user-supplied value is overwritten there.
        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
        self.embed_dim = kwargs.get("embed_dim", 512)
        # NOTE: also derived in __post_init__ (overwritten).
        self.filter_list = kwargs.get("filter_list", [])
        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
        self.num_layers = kwargs.get("num_layers", 8)
        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)

        # Intermediate activations / attention maps to expose in forward().
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.__post_init__()

    def __post_init__(self) -> None:
        """Validate and derive fields that depend on other fields."""
        # Derive the per-head key size when not given explicitly.
        key_size = self.key_size
        if key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if not embed_dim % num_attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Force gene-embedding projection when the dimensions differ.
        use_gene_embedding = self.use_gene_embedding
        if use_gene_embedding:
            init_gene_embed_dim = self.init_gene_embed_dim
            token_embed_dim = self.token_embed_dim
            if init_gene_embed_dim != token_embed_dim:
                project_gene_embedding = self.project_gene_embedding
                if not project_gene_embedding:
                    # Fix: the original f-string fragments were concatenated
                    # without separating spaces, producing a garbled message;
                    # also use lazy %-style args per logging best practice.
                    logging.warning(
                        "Init gene embedding dimension (%s) different than token "
                        "embedding dimension (%s). Setting `project_gene_embedding` "
                        "to True",
                        init_gene_embed_dim,
                        token_embed_dim,
                    )
                    self.project_gene_embedding = True

        # Round the sequence length up to a multiple of the total downsampling
        # factor so the conv tower halvings divide evenly.
        num_downsamples = self.num_downsamples
        sequence_length = self.sequence_length
        downsample_factor = 2**num_downsamples
        fixed_sequence_length = (
            math.ceil(sequence_length / downsample_factor) * downsample_factor
        )
        self.fixed_sequence_length = fixed_sequence_length

        # Linearly interpolated channel counts for the conv/deconv towers.
        num_downsamples = self.num_downsamples
        filter_list = (
            np.linspace(
                self.conv_init_embed_dim,
                self.embed_dim,
                num_downsamples + 1,
            )
            .astype(int)
            .tolist()
        )
        self.filter_list = filter_list
| |
|
| |
|
class MOJO(PreTrainedModel):
    """
    Multi-omics model: per-omic token embeddings (summed) + gene embeddings ->
    convolutional downsampling tower -> transformer stack with rotary attention
    -> deconvolutional upsampling tower with encoder skips -> one LM head per
    omic.
    """

    config_class = MOJOConfig

    def __init__(self, config: MOJOConfig):
        super().__init__(config=config)

        # One token-embedding table per omic (keys of `config.alphabet_size`).
        self.embedding_layers = nn.ModuleDict(
            {
                omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim)
                for omic in config.alphabet_size
            }
        )

        # Learned per-position ("gene") embedding over the fixed sequence
        # length, optionally projected to the token embedding width.
        self.gene_embedding_layer = nn.Embedding(
            config.fixed_sequence_length,
            config.init_gene_embed_dim,
        )
        self.fc_gene_embedding = nn.Linear(
            config.init_gene_embed_dim, config.token_embed_dim
        )

        # Stem: widen token embeddings to the conv tower's initial channels.
        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.conv_init_embed_dim,
                kernel_size=config.stem_kernel_shape,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )

        # Downsampling tower; each block halves the sequence length and
        # returns its input as a skip tensor for the deconv tower.
        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=config.filter_list[i],
                    dim_out=config.filter_list[i + 1],
                    kernel_size=5,
                    conv_layer_norm_shape=config.filter_list[i],
                    resconv_layer_norm_shape=config.filter_list[i + 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # Attention-map bookkeeping: which layers, and which map index within
        # each layer, should be exposed in forward()'s output dict.
        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})

        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }

        # Validate requested layers against the model depth (layers are
        # 1-indexed in the saving logic below).
        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        # Upsampling tower: mirrors the conv tower with the filter list
        # traversed in reverse.
        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=config.filter_list[-1 - i],
                    dim_out=config.filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    conv_layer_norm_shape=config.filter_list[-1 - i],
                    resconv_layer_norm_shape=config.filter_list[-1 - i - 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # One language-model head per omic.
        self.omic_lm_heads = nn.ModuleDict(
            {
                omic: LMHead(
                    dim_in=config.conv_init_embed_dim,
                    embed_dim=config.embed_dim,
                    dim_out=config.alphabet_size[omic],
                    num_hidden_layers=config.num_hidden_layers_head,
                )
                for omic in self.config.alphabet_size
            }
        )

    def get_embeddings(
        self,
        input_ids: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Embed each omic's token ids with that omic's own embedding table."""
        omic_embeddings = {}
        for omic, omic_tokens in input_ids.items():
            omic_embeddings[omic] = self.embedding_layers[omic](omic_tokens)
        return omic_embeddings

    def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Args:
            input_ids: mapping of omic name -> token-id tensor; each omic must
                have an embedding table registered in `__init__`.

        Returns:
            Dict of intermediate activations plus per-omic "logits".
        """
        outs = {}
        embeddings = self.get_embeddings(input_ids)
        outs["omic_embeddings"] = embeddings
        # Fuse the omics by summing their token embeddings position-wise.
        x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0)
        outs["embeddings"] = x

        if self.config.use_gene_embedding:
            gene_indices = torch.arange(
                self.config.fixed_sequence_length, device=x.device
            )
            gene_embedding = self.gene_embedding_layer(gene_indices)
            if self.config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            # assumes x is (batch, fixed_sequence_length, token_embed_dim) so
            # the gene embedding broadcasts over the batch — TODO confirm.
            x = x + gene_embedding
            outs["embeddings_with_gene_embedding"] = x

        # Channels-first layout for the convolutional stem and tower.
        x = x.permute(0, 2, 1)
        x = self.stem_conv(x)
        outs["stem"] = x

        residuals = []
        for conv_block in self.conv_tower:
            x, res = conv_block(x)
            residuals.append(res)
        x = x.permute(0, 2, 1)
        outs["conv_tower"] = x
        outs["conv_tower_residuals"] = residuals
        # Deepest skip first, pairing with the deconv tower order below.
        residuals = residuals[::-1]

        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            # Layers are 1-indexed in the config's save lists.
            if (layer_idx + 1) in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]
            if (layer_idx + 1) in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                    dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                    # NOTE(review): `map_number + 1` skips index 0 and shifts
                    # every saved map by one — looks like an off-by-one;
                    # confirm the intended head/map indexing.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]
        outs["after_transformer_embedding"] = x

        # Mirror of the conv tower: upsample back, adding the saved skips.
        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)
        x = x.permute(0, 2, 1)
        outs["deconv_tower"] = x

        outs["logits"] = {
            omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size
        }

        return outs
| |
|