Instructions to use Taykhoom/RNA-MSM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/RNA-MSM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/RNA-MSM", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput | |
| try: | |
| from .configuration_rnamsm import RNAMSMConfig | |
| except ImportError: | |
| from configuration_rnamsm import RNAMSMConfig | |
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit.

    Computes ``x * Phi(x)``, where ``Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))`` is the
    standard normal CDF.
    """
    half_x = x * 0.5
    cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * cdf_term
class RNAMSMLMHead(nn.Module):
    """Masked-LM projection head whose vocabulary matrix is shared with the token embedding.

    Pipeline: dense -> GELU -> LayerNorm -> linear map onto the vocabulary using
    ``embed_tokens_weight`` (the same Parameter object, not a copy), plus a learned
    per-token bias.
    """

    def __init__(self, config: RNAMSMConfig, embed_tokens_weight: nn.Parameter):
        super().__init__()
        hidden = config.embed_dim
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden)
        # Registering the caller's Parameter directly ties this head to the input embedding.
        self.weight = embed_tokens_weight
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, x):
        transformed = self.layer_norm(gelu(self.dense(x)))
        vocab_scores = F.linear(transformed, self.weight)
        return vocab_scores + self.bias
class LearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings that skip over padding tokens.

    Along dim 1, the k-th non-padding token (1-based) of each row gets position
    ``padding_idx + k``; padding tokens get ``padding_idx`` itself, whose row is the
    embedding's padding row. The table therefore holds
    ``num_embeddings + padding_idx + 1`` rows, supporting up to ``num_embeddings``
    real tokens per row.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        num_embeddings_ = num_embeddings + padding_idx + 1
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed the position of every token in ``tokens``.

        Args:
            tokens: integer token ids; positions are counted along dim 1.

        Returns:
            Position embeddings with shape ``tokens.shape + (embedding_dim,)``.

        Raises:
            ValueError: if a row holds more than ``max_positions`` non-padding
                tokens, which would otherwise index past the end of the table.
        """
        mask = tokens.ne(self.padding_idx).int()
        # Only pay for the reduction (a device sync on GPU) when a row could
        # actually be too long. Padded rows longer than max_positions remain
        # valid as long as their real-token count fits.
        if tokens.size(1) > self.max_positions and tokens.numel() > 0:
            longest = int(mask.sum(dim=1).max())
            if longest > self.max_positions:
                raise ValueError(
                    f"Sequence has {longest} non-padding tokens, which exceeds the "
                    f"maximum of {self.max_positions} positions supported by this model."
                )
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(positions, self.weight, self.padding_idx,
                           self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
class NormalizedResidualBlock(nn.Module):
    """Pre-LayerNorm residual wrapper computing ``x + dropout(layer(layer_norm(x)))``.

    If the wrapped layer returns a tuple, only its first element enters the
    residual sum. Any remaining elements (e.g. attention probabilities) are
    returned unchanged after the updated tensor, as a tuple.
    """

    def __init__(self, layer: nn.Module, embedding_dim: int, dropout: float):
        super().__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout_module = nn.Dropout(dropout)

    def forward(self, x, *args, **kwargs):
        result = self.layer(self.layer_norm(x), *args, **kwargs)
        if not isinstance(result, tuple):
            return x + self.dropout_module(result)
        primary, *extras = result
        return (x + self.dropout_module(primary), *extras)
class FeedForwardNetwork(nn.Module):
    """Position-wise two-layer MLP: ``fc2(dropout(gelu(fc1(x))))``.

    ``max_tokens_per_msa`` is stored for parity with the attention modules'
    constructors but is not used by ``forward``.
    """

    def __init__(self, embedding_dim: int, ffn_embedding_dim: int,
                 activation_dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.max_tokens_per_msa = max_tokens_per_msa

    def forward(self, x):
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        return self.fc2(hidden)
class RowSelfAttention(nn.Module):
    """Self-attention across columns (sequence positions), summed over MSA rows.

    Per-row attention logits are summed over every row of the alignment, so all
    rows share a single attention map. Inputs use the layout
    (rows, cols, batch, embed_dim). Per the ``"hnij"`` einsum output, the attention
    map has shape (heads, batch, cols, cols).
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the per-head query scale divided additionally by sqrt(num rows).

        Presumably this offsets the growth of logits that are summed over rows.
        """
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)

    def compute_attention_weights(self, x, scaling, padding_mask=None):
        """Row-summed attention logits of shape (heads, batch, cols, cols).

        ``padding_mask`` (batch, rows, cols), True at padding: zeroes padded queries
        and masks padded keys (taken from row 0) with -10000.
        """
        rows, cols, batch, _ = x.size()
        head_shape = (rows, cols, batch, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*head_shape) * scaling
        k = self.k_proj(x).view(*head_shape)
        if padding_mask is not None:
            # Padded queries must contribute nothing to the sum over rows.
            keep = 1 - padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
            q = q * keep
        logits = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
        if padding_mask is None:
            return logits
        key_is_pad = padding_mask[:, 0].unsqueeze(0).unsqueeze(2)
        return logits.masked_fill(key_is_pad, -10000.0)

    def compute_attention_update(self, x, attn_probs):
        """Mix value vectors with ``attn_probs`` and apply the output projection."""
        rows, cols, batch, width = x.size()
        v = self.v_proj(x).view(rows, cols, batch, self.num_heads, self.head_dim)
        mixed = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        return self.out_proj(mixed.contiguous().view(rows, cols, batch, width))

    def _batched_forward(self, x, padding_mask=None):
        """Memory-bounded variant that processes rows in chunks.

        Chunks are sized so that each holds at most ``max_tokens_per_msa`` tokens
        (minimum one row).
        """
        rows, cols, _, _ = x.size()
        chunk = max(1, self.max_tokens_per_msa // cols)
        starts = range(0, rows, chunk)
        scaling = self.align_scaling(x)
        logits = 0
        for s in starts:
            sub_mask = None if padding_mask is None else padding_mask[:, s:s + chunk]
            logits = logits + self.compute_attention_weights(x[s:s + chunk], scaling, sub_mask)
        probs = self.dropout_module(logits.softmax(-1))
        pieces = [self.compute_attention_update(x[s:s + chunk], probs) for s in starts]
        return torch.cat(pieces, 0), probs

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Return ``(output, attn_probs)``.

        ``self_attn_mask`` is accepted for interface compatibility and ignored.
        """
        rows, cols, _, _ = x.size()
        too_big = rows * cols > self.max_tokens_per_msa
        if too_big and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_padding_mask)
        logits = self.compute_attention_weights(x, self.align_scaling(x), self_attn_padding_mask)
        probs = self.dropout_module(logits.softmax(-1))
        return self.compute_attention_update(x, probs), probs
class ColumnSelfAttention(nn.Module):
    """Self-attention across MSA rows (alignment depth) per sequence position.

    Inputs use the layout (rows, cols, batch, embed_dim). Per the ``"hcnij"``
    einsum output, the attention map has shape (heads, cols, batch, rows, rows).
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def compute_attention_update(self, x, self_attn_padding_mask=None):
        """Attend over rows independently for every column; return ``(output, attn_probs)``.

        ``self_attn_padding_mask`` (batch, rows, cols), True at padding, masks
        padded keys with -10000.
        """
        rows, cols, batch, width = x.size()
        if rows == 1:
            # A softmax over a single row is identically 1, so attention collapses to
            # out_proj(v_proj(x)). The padding mask is not consulted on this path.
            ones = torch.ones(self.num_heads, cols, batch, 1, 1,
                              device=x.device, dtype=x.dtype)
            return self.out_proj(self.v_proj(x)), ones

        head_shape = (rows, cols, batch, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*head_shape) * self.scaling
        k = self.k_proj(x).view(*head_shape)
        v = self.v_proj(x).view(*head_shape)
        logits = torch.einsum("icnhd,jcnhd->hcnij", q, k)
        if self_attn_padding_mask is not None:
            key_is_pad = self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3)
            logits = logits.masked_fill(key_is_pad, -10000.0)
        probs = self.dropout_module(logits.softmax(-1))
        context = torch.einsum("hcnij,jcnhd->icnhd", probs, v)
        merged = context.contiguous().view(rows, cols, batch, width)
        return self.out_proj(merged), probs

    def _batched_forward(self, x, self_attn_padding_mask=None):
        """Memory-bounded variant that processes columns in chunks.

        Chunks are sized so that each holds at most ``max_tokens_per_msa`` tokens
        (minimum one column).
        """
        rows, cols, _, _ = x.size()
        chunk = max(1, self.max_tokens_per_msa // rows)

        def mask_slice(s):
            if self_attn_padding_mask is None:
                return None
            return self_attn_padding_mask[:, :, s:s + chunk]

        pairs = [self.compute_attention_update(x[:, s:s + chunk], mask_slice(s))
                 for s in range(0, cols, chunk)]
        outputs = torch.cat([out for out, _ in pairs], 1)
        attentions = torch.cat([prob for _, prob in pairs], 1)
        return outputs, attentions

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Return ``(output, attn_probs)``.

        ``self_attn_mask`` is accepted for interface compatibility and ignored.
        """
        rows, cols, _, _ = x.size()
        use_chunks = rows * cols > self.max_tokens_per_msa and not torch.is_grad_enabled()
        if use_chunks:
            return self._batched_forward(x, self_attn_padding_mask)
        return self.compute_attention_update(x, self_attn_padding_mask)
class AxialTransformerLayer(nn.Module):
    """One MSA Transformer block: tied row attention, then column attention, then FFN.

    Each sub-layer is wrapped in a pre-LayerNorm residual block. Tensors use the
    (rows, cols, batch, embed_dim) layout.
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__()
        dim = config.embed_dim
        heads = config.num_attention_heads
        attn_drop = config.attention_dropout
        token_budget = config.max_tokens_per_msa
        self.row_self_attention = NormalizedResidualBlock(
            RowSelfAttention(dim, heads, attn_drop, token_budget), dim, config.dropout)
        self.column_self_attention = NormalizedResidualBlock(
            ColumnSelfAttention(dim, heads, attn_drop, token_budget), dim, config.dropout)
        self.feed_forward_layer = NormalizedResidualBlock(
            FeedForwardNetwork(dim, config.ffn_embed_dim, config.activation_dropout, token_budget),
            dim, config.dropout)

    def forward(self, x, padding_mask=None, output_attentions=False):
        """Return ``(x, row_attn, col_attn)``.

        ``output_attentions`` is accepted for API symmetry; both attention maps
        are always returned.
        """
        x, row_attn = self.row_self_attention(x, self_attn_padding_mask=padding_mask)
        x, col_attn = self.column_self_attention(x, self_attn_padding_mask=padding_mask)
        return self.feed_forward_layer(x), row_attn, col_attn
class RNAMSMPreTrainedModel(PreTrainedModel):
    """Base class connecting RNA-MSM to the Hugging Face ``PreTrainedModel`` machinery."""

    config_class = RNAMSMConfig
    base_model_prefix = "rnamsm"

    def _init_weights(self, module):
        """Per-module weight initialisation hook.

        Linear and Embedding weights are drawn from N(0, 0.02). Linear biases are
        zeroed, as is the embedding row at ``padding_idx``. LayerNorm is reset to
        the identity (weight 1, bias 0). Other modules are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, std=0.02)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
class RNAMSMModel(RNAMSMPreTrainedModel):
    """
    RNA-MSM backbone: MSA Transformer that processes multiple-sequence-aligned RNA
    sequences and produces per-position embeddings for each alignment row.

    Input:  input_ids of shape (batch, num_alignments, seqlen)
    Output: last_hidden_state of shape (batch, num_alignments, seqlen, embed_dim)
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim,
                                         padding_idx=config.padding_idx)
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_positions, config.embed_dim, config.padding_idx)
        if config.embed_positions_msa:
            # One learned scalar per alignment row, broadcast over columns and features.
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, config.max_alignments, 1, 1))
        else:
            self.register_parameter("msa_position_embedding", None)
        self.dropout_module = nn.Dropout(config.dropout)
        self.emb_layer_norm_before = nn.LayerNorm(config.embed_dim)
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.layers = nn.ModuleList([AxialTransformerLayer(config)
                                     for _ in range(config.num_layers)])
        self.post_init()

    def get_input_embeddings(self):
        # Needed by HF tie_weights()/resize_token_embeddings(); the base-class default
        # raises NotImplementedError for a base model.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _embed_inputs(self, input_ids, padding_mask):
        """Return normalised token+position(+MSA-row) embeddings with padding zeroed.

        Output layout is (batch, num_alignments, seqlen, embed_dim).
        """
        batch_size, num_alignments, seqlen = input_ids.size()
        x = self.embed_tokens(input_ids)
        # reshape (not view): input_ids may be a non-contiguous slice.
        x = x + self.embed_positions(
            input_ids.reshape(batch_size * num_alignments, seqlen)
        ).view(batch_size, num_alignments, seqlen, self.config.embed_dim)
        if self.msa_position_embedding is not None:
            if num_alignments > self.config.max_alignments:
                raise RuntimeError(
                    f"MSA depth {num_alignments} exceeds max_alignments "
                    f"{self.config.max_alignments}.")
            x = x + self.msa_position_embedding[:, :num_alignments]
        x = self.emb_layer_norm_before(x)
        x = self.dropout_module(x)
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).to(x))
        return x

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the MSA Transformer.

        Args:
            input_ids: token ids of shape (batch, num_alignments, seqlen).
            attention_mask: optional mask, 1 = attend and 0 = padding. It is used as a
                (batch, num_alignments, seqlen) padding mask downstream. When omitted,
                padding is inferred from ``config.padding_idx``.
            output_hidden_states: return the embedding output plus every layer's
                output, each (batch, num_alignments, seqlen, embed_dim). The last
                entry has the final LayerNorm applied.
            output_attentions: return the per-layer row-attention maps. They are in
                the (heads, batch, seqlen, seqlen) layout, i.e. NOT batch-first.
                Column attentions are not returned.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Raises:
            ValueError: if ``input_ids`` is not 3-D.
            RuntimeError: if MSA position embeddings are enabled and
                num_alignments exceeds ``config.max_alignments``.
        """
        output_hidden_states = (output_hidden_states if output_hidden_states is not None
                                else self.config.output_hidden_states)
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if input_ids.ndim != 3:
            raise ValueError(
                "RNA-MSM expects 3D input_ids of shape (batch, num_alignments, seqlen). "
                "For single sequences, use tokenizer which produces (batch, 1, seqlen). "
                f"Got shape {tuple(input_ids.shape)}.")
        # HF convention: attention_mask 1=attend, 0=pad -> padding_mask True=padding
        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
        else:
            padding_mask = input_ids.eq(self.config.padding_idx)
        if not padding_mask.any():
            padding_mask = None

        x = self._embed_inputs(input_ids, padding_mask)

        all_hidden_states = []
        all_row_attentions = []
        if output_hidden_states:
            all_hidden_states.append(x)

        # (B, R, C, D) -> (R, C, B, D) for axial attention
        x = x.permute(1, 2, 0, 3)
        for layer in self.layers:
            x, row_attn, _col_attn = layer(x, padding_mask=padding_mask,
                                           output_attentions=output_attentions)
            if output_hidden_states:
                all_hidden_states.append(x.permute(2, 0, 1, 3))
            if output_attentions:
                all_row_attentions.append(row_attn)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # (R, C, B, D) -> (B, R, C, D)
        if output_hidden_states:
            # The last reported hidden state carries the final LayerNorm.
            all_hidden_states[-1] = x

        hidden_states = tuple(all_hidden_states) if output_hidden_states else None
        attentions = tuple(all_row_attentions) if output_attentions else None
        if not return_dict:
            return tuple(v for v in [x, hidden_states, attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class RNAMSMForMaskedLM(RNAMSMPreTrainedModel):
    """RNA-MSM backbone plus a masked-LM head that shares the token-embedding matrix."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.rnamsm = RNAMSMModel(config)
        # The head receives the embedding Parameter itself, so the two stay tied.
        self.lm_head = RNAMSMLMHead(config, self.rnamsm.embed_tokens.weight)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Compute vocabulary logits and, when ``labels`` is given, the MLM loss.

        Args:
            input_ids: token ids of shape (batch, num_alignments, seqlen).
            attention_mask: optional 1 = attend / 0 = padding mask, forwarded to the backbone.
            labels: optional target ids, one per input token. Entries equal to -100
                are ignored by the loss.
            output_hidden_states, output_attentions, return_dict: forwarded to the backbone.

        Returns:
            ``MaskedLMOutput``, or a tuple ``(loss?, logits, ...)`` when ``return_dict``
            is false. Logits have shape (batch, num_alignments, seqlen, vocab).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        out = self.rnamsm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        logits = self.lm_head(out[0] if not return_dict else out.last_hidden_state)
        loss = None
        if labels is not None:
            # reshape (not view): labels produced by slicing/transposing may be
            # non-contiguous. The vocab size comes from the logits themselves so a
            # swapped/resized head still works; labels follow the logits' device.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1).to(logits.device),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,) + out[1:]
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )