Instructions to use Taykhoom/mRNA-FM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/mRNA-FM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/mRNA-FM", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Taykhoom/mRNA-FM", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput | |
| try: | |
| from .configuration_rnafm import RnaFmConfig | |
| except ImportError: | |
| from configuration_rnafm import RnaFmConfig | |
def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))`` and returns a tensor with
    the same shape as ``x``.
    """
    # Keep the original left-to-right multiplication order so results match
    # bit-for-bit.
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
class RnaFmLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose indices are offset by ``padding_idx``.

    The table holds ``num_embeddings + padding_idx + 1`` rows. Non-padding
    tokens are numbered ``padding_idx + 1, padding_idx + 2, ...`` along
    dimension 1, while padding tokens all map to row ``padding_idx``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        # Reserve the rows up to and including ``padding_idx`` in addition to
        # the requested number of real positions.
        table_size = num_embeddings + padding_idx + 1
        super().__init__(table_size, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Return position embeddings for a tensor of token ids.

        Positions are derived from ``input`` itself: entries equal to
        ``self.padding_idx`` are treated as padding, every other entry is
        counted cumulatively along dimension 1.
        """
        not_pad = input.ne(self.padding_idx).int()
        # Running count of real tokens; multiplying by the mask resets padded
        # slots to 0 so that they land on ``padding_idx`` after the offset.
        running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        positions = (running_count * not_pad).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
class RnaFmAttention(nn.Module):
    """Eager multi-head self-attention built from explicit batched matmuls.

    Inputs are laid out as ``(seq_len, batch, embed_dim)``. This variant can
    return the attention probabilities.
    """

    def __init__(self, config: RnaFmConfig):
        super().__init__()
        embed_dim = config.embed_dim
        num_heads = config.attention_heads
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scaling = self.head_dim ** -0.5
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _project(self, x):
        """Project ``x`` to per-head queries, keys and values.

        Returns ``(q, k, v, tgt_len, bsz)`` where q/k/v have shape
        ``(bsz * num_heads, tgt_len, head_dim)`` and q is already scaled.
        """
        seq_len, batch, _ = x.size()
        flat_heads = batch * self.num_heads

        def split_heads(t):
            # (seq, batch, embed) -> (batch * heads, seq, head_dim)
            return t.contiguous().view(seq_len, flat_heads, self.head_dim).transpose(0, 1)

        query = split_heads(self.q_proj(x) * self.scaling)
        key = split_heads(self.k_proj(x))
        value = split_heads(self.v_proj(x))
        return query, key, value, seq_len, batch

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Run self-attention over ``x``.

        Args:
            x: tensor of shape ``(seq_len, batch, embed_dim)``.
            key_padding_mask: optional ``(batch, seq_len)`` mask; positions
                where it is true are filled with ``-inf`` before the softmax.
            output_attentions: when true, also return the float32 attention
                probabilities shaped ``(batch, num_heads, seq_len, seq_len)``.

        Returns:
            ``(output, weights)``; ``weights`` is ``None`` unless
            ``output_attentions`` is true.
        """
        query, key, value, seq_len, batch = self._project(x)
        scores = torch.bmm(query, key.transpose(1, 2))

        if key_padding_mask is not None:
            # Expose the batch and head axes so the (batch, seq) mask
            # broadcasts over heads and query positions.
            per_head = scores.view(batch, self.num_heads, seq_len, seq_len)
            blocked = key_padding_mask.unsqueeze(1).unsqueeze(2)
            per_head = per_head.masked_fill(blocked, float("-inf"))
            scores = per_head.view(batch * self.num_heads, seq_len, seq_len)

        # Softmax runs in float32, then is cast back to the score dtype.
        probs_fp32 = F.softmax(scores, dim=-1, dtype=torch.float32)
        probs = probs_fp32.type_as(scores)

        context = torch.bmm(probs, value)
        context = context.transpose(0, 1).contiguous().view(seq_len, batch, self.embed_dim)
        output = self.out_proj(context)

        if not output_attentions:
            return output, None
        return output, probs_fp32.view(batch, self.num_heads, seq_len, seq_len)
class RnaFmSdpaAttention(RnaFmAttention):
    """Self-attention backed by ``F.scaled_dot_product_attention``.

    This path never returns attention probabilities, so requests for them
    are delegated to the eager parent implementation.
    """

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Run self-attention over ``x`` of shape ``(seq_len, batch, embed_dim)``.

        Returns ``(output, None)``, or the parent's ``(output, weights)`` when
        ``output_attentions`` is true.
        """
        if output_attentions:
            return super().forward(x, key_padding_mask, output_attentions=True)

        seq_len, batch, _ = x.size()

        def to_heads(t):
            # (seq, batch, embed) -> (batch, heads, seq, head_dim)
            return t.view(seq_len, batch, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

        query = to_heads(self.q_proj(x))
        key = to_heads(self.k_proj(x))
        value = to_heads(self.v_proj(x))

        if key_padding_mask is None:
            additive_mask = None
        else:
            # Additive mask: 0 where attention is allowed, -inf on masked
            # keys; it broadcasts over heads and query positions.
            additive_mask = torch.zeros(
                batch, 1, 1, seq_len, dtype=query.dtype, device=query.device
            )
            additive_mask = additive_mask.masked_fill(
                key_padding_mask[:, None, None], float("-inf")
            )

        context = F.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask)
        # (batch, heads, seq, head_dim) -> (seq, batch, embed)
        context = context.permute(2, 0, 1, 3).contiguous().view(seq_len, batch, self.embed_dim)
        return self.out_proj(context), None
class RnaFmFlashAttention2(RnaFmAttention):
    """Self-attention backed by the ``flash_attn`` kernels.

    Requests for attention probabilities are delegated to the eager parent
    implementation, since this path does not return them.
    """

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Run self-attention over ``x`` of shape ``(seq_len, batch, embed_dim)``.

        Args:
            x: input tensor laid out as ``(seq_len, batch, embed_dim)``.
            key_padding_mask: optional boolean ``(batch, seq_len)`` mask whose
                true entries mark positions to drop; it is inverted to build
                the keep-mask passed to ``unpad_input``.
            output_attentions: when true, fall back to the eager parent.

        Returns:
            ``(output, None)``, or the parent's ``(output, weights)`` when
            ``output_attentions`` is true.

        Raises:
            ImportError: if ``flash_attn`` cannot be imported.
        """
        if output_attentions:
            return super().forward(x, key_padding_mask, output_attentions=True)

        # Import everything up front so a missing or broken install always
        # produces the same actionable error, whichever branch runs below.
        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import pad_input, unpad_input
        except ImportError as e:
            raise ImportError(
                "flash_attn is required for attn_implementation='flash_attention_2'. "
                "Install with: pip install flash-attn --no-build-isolation"
            ) from e

        tgt_len, bsz, _ = x.size()

        def to_heads(t):
            # (seq, batch, embed) -> (batch, seq, heads, head_dim)
            return t.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 0, 2, 3)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))

        # Inputs in any other dtype are cast to bfloat16 for the kernel call
        # and the result is cast back afterwards.
        orig_dtype = q.dtype
        if orig_dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)

        softmax_scale = self.head_dim ** -0.5

        if key_padding_mask is not None and key_padding_mask.any():
            keep_mask = ~key_padding_mask
            # ``unpad_input`` returns four values in older flash-attn releases
            # and five in newer ones; ``*_`` accepts both.
            q_unpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(q, keep_mask)
            k_unpad, *_ = unpad_input(k, keep_mask)
            v_unpad, *_ = unpad_input(v, keep_mask)
            out_unpad = flash_attn_varlen_func(
                q_unpad, k_unpad, v_unpad,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                softmax_scale=softmax_scale,
                causal=False,
            )
            # Scatter the packed tokens back to (batch, seq, heads, head_dim).
            out = pad_input(out_unpad, indices, bsz, tgt_len)
        else:
            out = flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=False)

        # (batch, seq, heads, head_dim) -> (seq, batch, embed)
        out = out.to(orig_dtype).permute(1, 0, 2, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
        return self.out_proj(out), None
# Registry mapping an attention-implementation name to the attention module
# class that provides it.
RNAFM_ATTENTION_CLASSES = dict(
    eager=RnaFmAttention,
    sdpa=RnaFmSdpaAttention,
    flash_attention_2=RnaFmFlashAttention2,
)
class RnaFmLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then a GELU feed-forward.

    Each sub-layer normalises its input first and adds its output back onto
    the un-normalised input (residual connection).
    """

    def __init__(self, config: RnaFmConfig):
        super().__init__()
        # The attention backend is chosen from the config, defaulting to
        # "eager" when the attribute is absent.
        impl_name = getattr(config, "_attn_implementation", "eager")
        self.self_attn = RNAFM_ATTENTION_CLASSES[impl_name](config)
        self.self_attn_layer_norm = nn.LayerNorm(config.embed_dim)
        self.fc1 = nn.Linear(config.embed_dim, config.ffn_embed_dim)
        self.fc2 = nn.Linear(config.ffn_embed_dim, config.embed_dim)
        self.final_layer_norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Apply the block to ``x`` and return ``(hidden, attn)``.

        ``attn`` is whatever the attention module returns as its second
        element (``None`` unless attention weights were requested).
        """
        # Self-attention sub-layer.
        attn_out, attn = self.self_attn(
            self.self_attn_layer_norm(x),
            key_padding_mask=key_padding_mask,
            output_attentions=output_attentions,
        )
        hidden = x + attn_out

        # Feed-forward sub-layer.
        ffn_out = self.fc2(gelu(self.fc1(self.final_layer_norm(hidden))))
        return hidden + ffn_out, attn
class RnaFmPreTrainedModel(PreTrainedModel):
    """Base class carrying the config binding and weight initialisation."""

    config_class = RnaFmConfig
    base_model_prefix = "rnafm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear layers get Xavier-uniform weights and zero biases. Embeddings
        get N(0, 0.02) weights with the padding row zeroed. Any other module
        type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
            return

        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            pad_row = module.padding_idx
            if pad_row is not None:
                # Re-zero the padding row that the normal init overwrote.
                module.weight.data[pad_row].zero_()
class RnaFmModel(RnaFmPreTrainedModel):
    """Encoder stack: token + position embeddings followed by ``RnaFmLayer`` blocks."""

    def __init__(self, config: RnaFmConfig):
        super().__init__(config)
        dim = config.embed_dim
        pad = config.padding_idx
        self.embed_tokens = nn.Embedding(config.vocab_size, dim, padding_idx=pad)
        self.embed_positions = RnaFmLearnedPositionalEmbedding(config.model_max_length, dim, pad)
        # The pre-stack LayerNorm is optional and controlled by the config.
        self.emb_layer_norm_before = nn.LayerNorm(dim) if config.emb_layer_norm_before else None
        self.layers = nn.ModuleList([RnaFmLayer(config) for _ in range(config.num_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(dim)
        self.post_init()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
    ):
        """Encode ``input_ids`` and return a ``BaseModelOutput``.

        Args:
            input_ids: token ids; indexed as ``(batch, seq_len)`` by the code
                below.
            attention_mask: optional mask; entries equal to 0 are treated as
                padding. When omitted, padding is inferred from
                ``config.padding_idx``.
            output_hidden_states: collect the embedding output and every
                layer output (config default when ``None``).
            output_attentions: collect per-layer attention weights (config
                default when ``None``).
            return_dict: resolved against the config but not otherwise used;
                a ``BaseModelOutput`` is always returned.
        """
        cfg = self.config
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if attention_mask is None:
            padding_mask = input_ids.eq(cfg.padding_idx)
        else:
            padding_mask = attention_mask.eq(0)

        x = self.embed_tokens(input_ids)

        if cfg.token_dropout:
            # Zero the mask-token embeddings, then rescale each sequence by
            # (1 - train mask ratio) / (1 - observed mask ratio).
            is_mask_token = input_ids == cfg.mask_idx
            x.masked_fill_(is_mask_token.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = is_mask_token.sum(-1).to(x.dtype) / src_lengths.to(x.dtype)
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(input_ids)
        if self.emb_layer_norm_before is not None:
            x = self.emb_layer_norm_before(x)

        if padding_mask.any():
            # Zero out the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).to(x.dtype))
        else:
            # Nothing is padded, so the layers can skip masking entirely.
            padding_mask = None

        all_hidden_states = []
        all_attentions = []
        if output_hidden_states:
            all_hidden_states.append(x)

        # The layers work on (seq_len, batch, dim).
        x = x.transpose(0, 1)
        for block in self.layers:
            x, attn = block(x, key_padding_mask=padding_mask, output_attentions=output_attentions)
            if output_hidden_states:
                all_hidden_states.append(x.transpose(0, 1))
            if output_attentions and attn is not None:
                all_attentions.append(attn)

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)

        if output_hidden_states:
            # The last recorded state is replaced by its normalised version.
            all_hidden_states[-1] = x

        hidden_states = tuple(all_hidden_states) if output_hidden_states else None
        attentions = tuple(all_attentions) if output_attentions else None
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class RnaFmLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> LayerNorm -> projection to the vocabulary."""

    def __init__(self, config: RnaFmConfig):
        super().__init__()
        dim = config.embed_dim
        self.dense = nn.Linear(dim, dim)
        self.layer_norm = nn.LayerNorm(dim)
        self.decoder = nn.Linear(dim, config.vocab_size, bias=True)

    def forward(self, features):
        """Map hidden features to vocabulary logits."""
        hidden = gelu(self.dense(features))
        return self.decoder(self.layer_norm(hidden))
class RnaFmForMaskedLM(RnaFmPreTrainedModel):
    """``RnaFmModel`` encoder with a masked-language-modelling head on top."""

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config: RnaFmConfig):
        super().__init__(config)
        self.rnafm = RnaFmModel(config)
        self.lm_head = RnaFmLMHead(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.rnafm.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the encoder's token embedding module."""
        self.rnafm.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head's final projection layer."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's final projection layer."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
    ):
        """Compute vocabulary logits and, when ``labels`` is given, the MLM loss.

        The loss is a cross-entropy over all positions, ignoring those whose
        label is ``-100``. A ``MaskedLMOutput`` is always returned.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.rnafm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        logits = self.lm_head(encoder_out.last_hidden_state)

        if labels is None:
            loss = None
        else:
            # Flatten to (tokens, vocab) against (tokens,) for the loss.
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.view(-1)
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )