import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput

try:
    from .configuration_rnamsm import RNAMSMConfig
except ImportError:
    from configuration_rnamsm import RNAMSMConfig


def gelu(x):
    """Exact (erf-based) GELU activation: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class RNAMSMLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> LayerNorm -> projection onto the vocabulary.

    The output projection reuses ``embed_tokens_weight`` (weight tying) instead of
    owning a separate matrix; only the output bias is a fresh parameter.
    """

    def __init__(self, config: RNAMSMConfig, embed_tokens_weight: nn.Parameter):
        super().__init__()
        self.dense = nn.Linear(config.embed_dim, config.embed_dim)
        self.layer_norm = nn.LayerNorm(config.embed_dim)
        # Shared with the token embedding table (tied weights).
        self.weight = embed_tokens_weight
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, x):
        """Map hidden states ``(..., embed_dim)`` to vocabulary logits ``(..., vocab_size)``."""
        x = self.dense(x)
        x = gelu(x)
        x = self.layer_norm(x)
        return F.linear(x, self.weight) + self.bias


class LearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding that skips padding tokens.

    Along dim 1, non-padding tokens get positions ``padding_idx + 1``,
    ``padding_idx + 2``, ... (padding is not counted). Padding tokens are mapped
    to ``padding_idx`` itself. The table therefore has
    ``num_embeddings + padding_idx + 1`` rows so that ``num_embeddings`` real
    positions fit after the padding offset.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        num_embeddings_ = num_embeddings + padding_idx + 1
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed the positions of ``tokens`` (a 2D ``(N, seqlen)`` tensor of token ids).

        Raises:
            ValueError: if ``seqlen`` exceeds ``max_positions``. Without this
                check the lookup indexes past the table, which surfaces as an
                opaque IndexError on CPU or a device-side assert on CUDA.
        """
        if tokens.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {tokens.size(1)} exceeds the maximum of "
                f"{self.max_positions} supported positions."
            )
        mask = tokens.ne(self.padding_idx).int()
        # cumsum numbers the non-padding tokens 1, 2, ...; multiplying by mask
        # zeroes padding, and the offset shifts everything past padding_idx.
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class NormalizedResidualBlock(nn.Module):
    """Pre-LayerNorm residual wrapper: ``x + dropout(layer(layer_norm(x), ...))``.

    If the wrapped layer returns a tuple, its first element is the residual
    update and the remaining elements (e.g. attention probabilities) are
    returned unchanged after the new ``x``.
    """

    def __init__(self, layer: nn.Module, embedding_dim: int, dropout: float):
        super().__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout_module = nn.Dropout(dropout)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x, out = outputs, None
        x = self.dropout_module(x)
        x = residual + x
        if out is not None:
            return (x,) + tuple(out)
        return x


class FeedForwardNetwork(nn.Module):
    """Position-wise MLP: ``fc2(dropout(GELU(fc1(x))))``."""

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float,
        max_tokens_per_msa: int,
    ):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        # Stored for parity with the attention layers; not used by forward().
        self.max_tokens_per_msa = max_tokens_per_msa

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        return self.fc2(x)


class RowSelfAttention(nn.Module):
    """Self-attention across columns (sequence positions), summed over MSA rows.

    ``x`` is laid out as ``(num_rows, num_cols, batch, embed_dim)``. Query-key
    scores are summed over all rows (tied row attention), so every row shares a
    single ``(num_heads, batch, num_cols, num_cols)`` attention map. The head
    scaling is further divided by ``sqrt(num_rows)`` to offset that sum.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the per-head scaling divided by ``sqrt(q.size(0))`` (the row count)."""
        return self.scaling / math.sqrt(q.size(0))

    def compute_attention_weights(self, x, scaling, padding_mask=None):
        """Return pre-softmax scores of shape ``(heads, batch, cols, cols)``.

        ``padding_mask`` is a bool tensor ``(batch, rows, cols)``, True at padding.
        Padded queries are zeroed before the row sum. Key columns padded in the
        mask's first row are filled with -10000 so softmax gives them ~0 weight.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = q * scaling
        if padding_mask is not None:
            # (B, R, C) -> (R, C, B, 1, 1) to broadcast over heads and head_dim.
            q = q * (1 - padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q))

        # Contraction over r sums the scores across MSA rows.
        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000.0,
            )
        return attn_weights

    def compute_attention_update(self, x, attn_probs):
        """Apply shared attention probabilities to every row's values and project the result."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(context)

    def _batched_forward(self, x, padding_mask=None):
        """Memory-bounded variant: accumulate scores over row chunks, softmax once,
        then compute the update chunk by chunk."""
        num_rows, num_cols, batch_size, _ = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        scaling = self.align_scaling(x)
        attns = 0
        for start in range(0, num_rows, max_rows):
            pm = padding_mask[:, start:start + max_rows] if padding_mask is not None else None
            attns = attns + self.compute_attention_weights(x[start:start + max_rows], scaling, pm)
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = [
            self.compute_attention_update(x[start:start + max_rows], attn_probs)
            for start in range(0, num_rows, max_rows)
        ]
        return torch.cat(outputs, 0), attn_probs

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Return ``(output, attn_probs)``; ``self_attn_mask`` is accepted but unused.

        The chunked path is used only when ``num_rows * num_cols`` exceeds
        ``max_tokens_per_msa`` and gradients are disabled.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        if num_rows * num_cols > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_padding_mask)

        scaling = self.align_scaling(x)
        attn_weights = self.compute_attention_weights(x, scaling, self_attn_padding_mask)
        attn_probs = attn_weights.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)
        output = self.compute_attention_update(x, attn_probs)
        return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Self-attention across MSA rows (alignment depth) per sequence position.

    ``x`` is laid out as ``(num_rows, num_cols, batch, embed_dim)``. Attention
    probabilities have shape ``(heads, cols, batch, rows, rows)``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def compute_attention_update(self, x, self_attn_padding_mask=None):
        """Return ``(output, attn_probs)`` for the (possibly column-sliced) input ``x``.

        ``self_attn_padding_mask`` is a bool tensor ``(batch, rows, cols)``, True at padding.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # A single row can only attend to itself: probabilities are all ones
            # and the output reduces to out_proj(v_proj(x)).
            attn_probs = torch.ones(
                self.num_heads, num_cols, batch_size, 1, 1,
                device=x.device, dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q = q * self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_padding_mask is not None:
                # (B, R, C) -> (1, C, B, 1, R): mask padded key rows.
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000.0,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def _batched_forward(self, x, self_attn_padding_mask=None):
        """Memory-bounded variant: process column chunks independently and concatenate."""
        num_rows, num_cols, batch_size, _ = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs, attns = [], []
        for start in range(0, num_cols, max_cols):
            pm = (
                self_attn_padding_mask[:, :, start:start + max_cols]
                if self_attn_padding_mask is not None
                else None
            )
            out, attn = self.compute_attention_update(x[:, start:start + max_cols], pm)
            outputs.append(out)
            attns.append(attn)
        return torch.cat(outputs, 1), torch.cat(attns, 1)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Return ``(output, attn_probs)``; ``self_attn_mask`` is accepted but unused.

        The chunked path is used only when ``num_rows * num_cols`` exceeds
        ``max_tokens_per_msa`` and gradients are disabled.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        if num_rows * num_cols > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_padding_mask)
        return self.compute_attention_update(x, self_attn_padding_mask)


class AxialTransformerLayer(nn.Module):
    """One axial block: row attention, then column attention, then FFN.

    Each sub-layer is wrapped in a pre-LayerNorm residual block.
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__()
        self.row_self_attention = NormalizedResidualBlock(
            RowSelfAttention(
                config.embed_dim,
                config.num_attention_heads,
                config.attention_dropout,
                config.max_tokens_per_msa,
            ),
            config.embed_dim,
            config.dropout,
        )
        self.column_self_attention = NormalizedResidualBlock(
            ColumnSelfAttention(
                config.embed_dim,
                config.num_attention_heads,
                config.attention_dropout,
                config.max_tokens_per_msa,
            ),
            config.embed_dim,
            config.dropout,
        )
        self.feed_forward_layer = NormalizedResidualBlock(
            FeedForwardNetwork(
                config.embed_dim,
                config.ffn_embed_dim,
                config.activation_dropout,
                config.max_tokens_per_msa,
            ),
            config.embed_dim,
            config.dropout,
        )

    def forward(self, x, padding_mask=None, output_attentions=False):
        """Return ``(x, row_attn, col_attn)``; ``output_attentions`` is accepted but unused."""
        x, row_attn = self.row_self_attention(x, self_attn_padding_mask=padding_mask)
        x, col_attn = self.column_self_attention(x, self_attn_padding_mask=padding_mask)
        x = self.feed_forward_layer(x)
        return x, row_attn, col_attn


class RNAMSMPreTrainedModel(PreTrainedModel):
    """Base class wiring RNA-MSM into the Hugging Face ``PreTrainedModel`` machinery."""

    config_class = RNAMSMConfig
    base_model_prefix = "rnamsm"

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights with zero biases and a zeroed
        embedding padding row; LayerNorm reset to weight 1, bias 0."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


class RNAMSMModel(RNAMSMPreTrainedModel):
    """
    RNA-MSM backbone: MSA Transformer that processes multiple-sequence-aligned
    RNA sequences and produces per-position embeddings for each alignment row.

    Input: input_ids of shape (batch, num_alignments, seqlen)
    Output: last_hidden_state of shape (batch, num_alignments, seqlen, embed_dim)
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.embed_dim, padding_idx=config.padding_idx
        )
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_positions, config.embed_dim, config.padding_idx
        )
        if config.embed_positions_msa:
            # One learned scalar per MSA row, broadcast over positions and channels.
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, config.max_alignments, 1, 1)
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        self.dropout_module = nn.Dropout(config.dropout)
        self.emb_layer_norm_before = nn.LayerNorm(config.embed_dim)
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.layers = nn.ModuleList(
            [AxialTransformerLayer(config) for _ in range(config.num_layers)]
        )
        self.post_init()

    def get_input_embeddings(self):
        # Required by PreTrainedModel.tie_weights()/resize_token_embeddings();
        # the transformers default raises NotImplementedError on the base model.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode an MSA batch.

        Args:
            input_ids: token ids, ``(batch, num_alignments, seqlen)``.
            attention_mask: HF convention, 1 = attend, 0 = padding; same shape as
                ``input_ids``. If omitted, padding is inferred from ``padding_idx``.

        Returns:
            ``BaseModelOutput`` (or a tuple when ``return_dict`` is false). ``attentions``
            holds the per-layer row attentions of shape ``(heads, batch, seqlen, seqlen)``.
            The last hidden state is reported after the final LayerNorm.

        Raises:
            RuntimeError: if ``num_alignments`` exceeds ``config.max_alignments``
                while MSA row embeddings are enabled.
            ValueError: if ``seqlen`` exceeds ``config.max_positions``.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions if output_attentions is not None
            else self.config.output_attentions
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert input_ids.ndim == 3, (
            "RNA-MSM expects 3D input_ids of shape (batch, num_alignments, seqlen). "
            "For single sequences, use tokenizer which produces (batch, 1, seqlen).")
        batch_size, num_alignments, seqlen = input_ids.size()

        # HF convention: attention_mask 1=attend, 0=pad -> padding_mask True=padding
        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
        else:
            padding_mask = input_ids.eq(self.config.padding_idx)
        # An all-False mask is a no-op; drop it so attention skips the masking work.
        if not padding_mask.any():
            padding_mask = None

        # (B, R, C) -> embed: (B, R, C, D)
        x = self.embed_tokens(input_ids)
        # The positional embedding works on 2D (N, C) input. reshape (not view)
        # accepts non-contiguous input_ids, e.g. MSAs sliced along the row axis.
        x = x + self.embed_positions(
            input_ids.reshape(batch_size * num_alignments, seqlen)
        ).view(batch_size, num_alignments, seqlen, self.config.embed_dim)

        if self.msa_position_embedding is not None:
            if num_alignments > self.config.max_alignments:
                raise RuntimeError(
                    f"MSA depth {num_alignments} exceeds max_alignments "
                    f"{self.config.max_alignments}.")
            x = x + self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)
        x = self.dropout_module(x)

        if padding_mask is not None:
            # Zero the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).to(x))

        all_hidden_states = []
        all_row_attentions = []
        if output_hidden_states:
            all_hidden_states.append(x)

        # (B, R, C, D) -> (R, C, B, D) for axial attention
        x = x.permute(1, 2, 0, 3)
        for layer in self.layers:
            # Column attentions are not exposed: BaseModelOutput has one
            # `attentions` field, which carries the row attentions.
            x, row_attn, _col_attn = layer(
                x, padding_mask=padding_mask, output_attentions=output_attentions
            )
            if output_hidden_states:
                all_hidden_states.append(x.permute(2, 0, 1, 3))
            if output_attentions:
                all_row_attentions.append(row_attn)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # (R, C, B, D) -> (B, R, C, D)
        if output_hidden_states:
            # Report the last layer after the final LayerNorm.
            all_hidden_states[-1] = x

        hidden_states = tuple(all_hidden_states) if output_hidden_states else None
        attentions = tuple(all_row_attentions) if output_attentions else None

        if not return_dict:
            return tuple(v for v in (x, hidden_states, attentions) if v is not None)
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class RNAMSMForMaskedLM(RNAMSMPreTrainedModel):
    """RNA-MSM backbone with a masked-language-modeling head.

    The LM head's output projection is tied to the backbone token embedding
    (``lm_head.weight`` is ``rnamsm.embed_tokens.weight``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.rnamsm = RNAMSMModel(config)
        self.lm_head = RNAMSMLMHead(config, self.rnamsm.embed_tokens.weight)
        self.post_init()

    def get_input_embeddings(self):
        # post_init() -> tie_weights() asks for the input embeddings. The
        # transformers default delegates to the base model and raises
        # NotImplementedError when no accessor exists, so provide it here.
        return self.rnamsm.embed_tokens

    def set_input_embeddings(self, value):
        self.rnamsm.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Compute vocabulary logits and, if ``labels`` is given, the MLM loss.

        Args:
            input_ids: token ids, ``(batch, num_alignments, seqlen)``.
            attention_mask: optional HF-style mask (1 = attend, 0 = padding).
            labels: target ids with the same shape as ``input_ids``; positions
                set to -100 are ignored by the cross-entropy loss.

        Returns:
            ``MaskedLMOutput``, or a tuple ``([loss,] logits, ...)`` when
            ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        out = self.rnamsm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        sequence_output = out.last_hidden_state if return_dict else out[0]
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # reshape (not view): labels may be a non-contiguous slice.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + out[1:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )