# RNA-MSM / modeling_rnamsm.py
# Uploaded by Taykhoom using huggingface_hub (commit 00e6e55, verified).
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
try:
from .configuration_rnamsm import RNAMSMConfig
except ImportError:
from configuration_rnamsm import RNAMSMConfig
def gelu(x):
    """Exact (erf-based) GELU activation: ``x * Phi(x)``.

    Args:
        x: Input tensor of any shape.

    Returns:
        Tensor of the same shape as ``x`` with GELU applied element-wise.
    """
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF.
    erf_term = torch.erf(x / math.sqrt(2.0))
    return (x * 0.5) * (1.0 + erf_term)
class RNAMSMLMHead(nn.Module):
    """Masked-language-model head mapping hidden states to vocabulary logits.

    The head applies ``dense -> GELU -> LayerNorm`` and then a linear projection
    whose weight matrix is the parameter supplied at construction time, followed
    by the addition of a separately learned output bias.

    Args:
        config: Model configuration; ``embed_dim`` and ``vocab_size`` are read.
        embed_tokens_weight: Projection matrix used for the output layer. It is
            stored by reference (not copied), so it stays shared with whatever
            module owns it.
    """

    def __init__(self, config: RNAMSMConfig, embed_tokens_weight: nn.Parameter):
        super().__init__()
        hidden_size = config.embed_dim
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        # Same Parameter object as the one passed in; no new storage is created.
        self.weight = embed_tokens_weight
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, x):
        """Return logits with the last dimension of size ``weight.shape[0]``."""
        hidden = self.layer_norm(gelu(self.dense(x)))
        logits = F.linear(hidden, self.weight)
        return logits + self.bias
class LearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding that skips padding tokens.

    Position ids are derived from the token ids: padding tokens map to
    ``padding_idx`` and the k-th non-padding token of a row (1-based) maps to
    ``padding_idx + k``. The table therefore holds
    ``num_embeddings + padding_idx + 1`` rows.

    Args:
        num_embeddings: Maximum number of non-padding positions per sequence.
        embedding_dim: Size of each embedding vector.
        padding_idx: Token id that marks padding.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        # Positions start at padding_idx + 1, so reserve the extra leading rows.
        num_embeddings_ = num_embeddings + padding_idx + 1
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed the positions of ``tokens``.

        Args:
            tokens: Token ids of shape ``(batch, seqlen)``.

        Returns:
            Tensor of shape ``(batch, seqlen, embedding_dim)``.

        Raises:
            IndexError: If any row has more than ``max_positions`` non-padding
                tokens (the lookup would otherwise fail with an opaque
                out-of-range error inside ``F.embedding``).
        """
        mask = tokens.ne(self.padding_idx).int()
        # Only pay for the reduction when the row length could actually exceed
        # the table; rows padded beyond max_positions remain valid.
        if tokens.size(1) > self.max_positions and tokens.numel() > 0:
            longest = int(mask.sum(dim=1).max())
            if longest > self.max_positions:
                raise IndexError(
                    f"Sequence has {longest} non-padding tokens, but the positional "
                    f"embedding supports at most max_positions={self.max_positions}."
                )
        # Running count of non-padding tokens; multiplying by the mask sends
        # padding slots back to 0 so that they land exactly on padding_idx.
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
class NormalizedResidualBlock(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(layer(layer_norm(x)))``.

    If the wrapped layer returns a tuple, only its first element is treated as
    the update; the remaining elements are passed through unchanged after the
    residual output.

    Args:
        layer: Module to wrap.
        embedding_dim: Size of the last dimension normalised by ``LayerNorm``.
        dropout: Dropout probability applied to the layer's update.
    """

    def __init__(self, layer: nn.Module, embedding_dim: int, dropout: float):
        super().__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout_module = nn.Dropout(dropout)

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped layer with pre-normalisation and a skip connection.

        Extra positional and keyword arguments are forwarded to the wrapped layer.

        Returns:
            A tensor when the wrapped layer returns a tensor, otherwise a tuple
            ``(residual_output, *extras)``.
        """
        result = self.layer(self.layer_norm(x), *args, **kwargs)
        if not isinstance(result, tuple):
            return x + self.dropout_module(result)
        update, *extras = result
        return (x + self.dropout_module(update), *extras)
class FeedForwardNetwork(nn.Module):
    """Two-layer position-wise feed-forward network with GELU activation.

    Args:
        embedding_dim: Input and output feature size.
        ffn_embedding_dim: Hidden feature size.
        activation_dropout: Dropout probability applied after the activation.
        max_tokens_per_msa: Stored on the module; not used by ``forward``.
    """

    def __init__(self, embedding_dim: int, ffn_embedding_dim: int,
                 activation_dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.max_tokens_per_msa = max_tokens_per_msa

    def forward(self, x):
        """Return ``fc2(dropout(gelu(fc1(x))))``; the last dimension is preserved."""
        hidden = self.fc1(x)
        hidden = self.activation_fn(hidden)
        hidden = self.activation_dropout_module(hidden)
        return self.fc2(hidden)
class RowSelfAttention(nn.Module):
    """Self-attention across columns (sequence positions), summed over MSA rows.

    Inputs use the layout ``(num_rows, num_cols, batch_size, embed_dim)``.
    Query/key products are summed over the row axis, so one attention map of
    shape ``(num_heads, batch_size, num_cols, num_cols)`` is shared by all rows.

    Args:
        embed_dim: Feature size of the input; split evenly across heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.
        max_tokens_per_msa: When gradients are disabled and
            ``num_rows * num_cols`` exceeds this value, rows are processed in
            chunks to bound memory.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # einsum output layout: heads, batch, query column, key column.
        self.attn_shape = "hnij"
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the query scale, additionally divided by ``sqrt(num_rows)``.

        The extra factor compensates for the logits being summed over rows.
        """
        return self.scaling / math.sqrt(q.size(0))

    def _attention_logits(self, x, scaling, padding_mask=None):
        """Row-summed query/key logits WITHOUT the key-column padding fill."""
        num_rows, num_cols, batch_size, _ = x.size()
        head_shape = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(head_shape)
        k = self.k_proj(x).view(head_shape)
        q = q * scaling
        if padding_mask is not None:
            # Zero the queries at padded positions so they add nothing to the
            # sum over rows. Mask is (batch, rows, cols) -> (rows, cols, batch, 1, 1).
            q = q * (1 - padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q))
        return torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

    @staticmethod
    def _mask_padded_keys(attn_weights, padding_mask):
        """Fill key columns that are padded in the first row with ``-10000``."""
        # (batch, cols) -> (1, batch, 1, cols): broadcast over heads and queries.
        return attn_weights.masked_fill(
            padding_mask[:, 0].unsqueeze(0).unsqueeze(2), -10000.0)

    def compute_attention_weights(self, x, scaling, padding_mask=None):
        """Compute masked attention logits for the rows contained in ``x``.

        Args:
            x: Tensor of shape ``(num_rows, num_cols, batch_size, embed_dim)``.
            scaling: Factor applied to the queries (see ``align_scaling``).
            padding_mask: Optional mask of shape ``(batch, num_rows, num_cols)``
                whose non-zero/True entries mark padding.

        Returns:
            Logits of shape ``(num_heads, batch_size, num_cols, num_cols)``.
        """
        attn_weights = self._attention_logits(x, scaling, padding_mask)
        if padding_mask is not None:
            attn_weights = self._mask_padded_keys(attn_weights, padding_mask)
        return attn_weights

    def compute_attention_update(self, x, attn_probs):
        """Apply shared attention probabilities to the values of every row."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(context)

    def _batched_forward(self, x, padding_mask=None):
        """Memory-bounded forward pass that processes the rows in chunks.

        Raw logits are accumulated over the row chunks and the key-column
        padding fill is applied once afterwards with the full mask. Filling
        inside every chunk would add ``-10000`` once per chunk (losing float
        precision and risking overflow in half precision) and would take the
        mask from each chunk's first row instead of the first row of the MSA,
        so the result could differ from the unchunked path.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        scaling = self.align_scaling(x)
        attns = 0
        for start in range(0, num_rows, max_rows):
            pm = padding_mask[:, start:start + max_rows] if padding_mask is not None else None
            attns = attns + self._attention_logits(x[start:start + max_rows], scaling, pm)
        if padding_mask is not None:
            attns = self._mask_padded_keys(attns, padding_mask)
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)
        outputs = [self.compute_attention_update(x[start:start + max_rows], attn_probs)
                   for start in range(0, num_rows, max_rows)]
        return torch.cat(outputs, 0), attn_probs

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Run row attention.

        Args:
            x: Tensor of shape ``(num_rows, num_cols, batch_size, embed_dim)``.
            self_attn_mask: Accepted for signature compatibility; not used.
            self_attn_padding_mask: Optional ``(batch, num_rows, num_cols)``
                mask whose True entries mark padding.

        Returns:
            Tuple ``(output, attn_probs)`` where ``output`` has the shape of
            ``x`` and ``attn_probs`` has shape
            ``(num_heads, batch_size, num_cols, num_cols)``.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        # Chunking is only used without autograd (i.e. at inference time).
        if num_rows * num_cols > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_padding_mask)
        scaling = self.align_scaling(x)
        attn_weights = self.compute_attention_weights(x, scaling, self_attn_padding_mask)
        attn_probs = attn_weights.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)
        output = self.compute_attention_update(x, attn_probs)
        return output, attn_probs
class ColumnSelfAttention(nn.Module):
    """Self-attention across MSA rows (alignment depth) per sequence position.

    Inputs use the layout ``(num_rows, num_cols, batch_size, embed_dim)``; every
    column attends over its rows independently of the other columns.

    Args:
        embed_dim: Feature size of the input; split evenly across heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.
        max_tokens_per_msa: When gradients are disabled and
            ``num_rows * num_cols`` exceeds this value, columns are processed in
            chunks.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def compute_attention_update(self, x, self_attn_padding_mask=None):
        """Attend over the row axis for the columns contained in ``x``.

        Args:
            x: Tensor of shape ``(num_rows, num_cols, batch_size, embed_dim)``.
            self_attn_padding_mask: Optional ``(batch, num_rows, num_cols)``
                mask whose True entries mark padding.

        Returns:
            Tuple ``(output, attn_probs)``; ``output`` has the shape of ``x`` and
            ``attn_probs`` has shape
            ``(num_heads, num_cols, batch_size, num_rows, num_rows)``.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()

        if num_rows == 1:
            # A single row can only attend to itself: the probabilities are all
            # ones and the update reduces to the value/output projections.
            trivial_probs = torch.ones(self.num_heads, num_cols, batch_size, 1, 1,
                                       device=x.device, dtype=x.dtype)
            return self.out_proj(self.v_proj(x)), trivial_probs

        head_shape = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        queries = self.q_proj(x).view(head_shape)
        keys = self.k_proj(x).view(head_shape)
        values = self.v_proj(x).view(head_shape)
        queries = queries * self.scaling

        scores = torch.einsum("icnhd,jcnhd->hcnij", queries, keys)
        if self_attn_padding_mask is not None:
            # (batch, rows, cols) -> (1, cols, batch, 1, rows): mask padded key rows.
            key_is_pad = self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3)
            scores = scores.masked_fill(key_is_pad, -10000.0)

        probs = self.dropout_module(scores.softmax(-1))
        context = torch.einsum("hcnij,jcnhd->icnhd", probs, values)
        merged = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(merged), probs

    def _batched_forward(self, x, self_attn_padding_mask=None):
        """Process the columns in chunks and concatenate along the column axis."""
        num_rows, num_cols, batch_size, _ = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        results = []
        for lo in range(0, num_cols, max_cols):
            hi = lo + max_cols
            mask_chunk = (None if self_attn_padding_mask is None
                          else self_attn_padding_mask[:, :, lo:hi])
            results.append(self.compute_attention_update(x[:, lo:hi], mask_chunk))
        outputs = [chunk_out for chunk_out, _ in results]
        attns = [chunk_probs for _, chunk_probs in results]
        return torch.cat(outputs, 1), torch.cat(attns, 1)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
        """Run column attention; ``self_attn_mask`` is accepted but not used."""
        num_rows, num_cols, _, _ = x.size()
        exceeds_budget = num_rows * num_cols > self.max_tokens_per_msa
        # Chunking is only used without autograd (i.e. at inference time).
        if exceeds_budget and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_padding_mask)
        return self.compute_attention_update(x, self_attn_padding_mask)
class AxialTransformerLayer(nn.Module):
    """One axial block: row attention, then column attention, then feed-forward.

    Each sub-layer is wrapped in a ``NormalizedResidualBlock`` (pre-LayerNorm,
    dropout and a skip connection).

    Args:
        config: Model configuration providing ``embed_dim``,
            ``num_attention_heads``, ``ffn_embed_dim``, ``attention_dropout``,
            ``activation_dropout``, ``dropout`` and ``max_tokens_per_msa``.
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__()
        dim = config.embed_dim
        heads = config.num_attention_heads
        token_budget = config.max_tokens_per_msa

        def with_residual(sublayer: nn.Module) -> NormalizedResidualBlock:
            # Every sub-layer shares the same width and residual dropout.
            return NormalizedResidualBlock(sublayer, dim, config.dropout)

        self.row_self_attention = with_residual(
            RowSelfAttention(dim, heads, config.attention_dropout, token_budget))
        self.column_self_attention = with_residual(
            ColumnSelfAttention(dim, heads, config.attention_dropout, token_budget))
        self.feed_forward_layer = with_residual(
            FeedForwardNetwork(dim, config.ffn_embed_dim,
                               config.activation_dropout, token_budget))

    def forward(self, x, padding_mask=None, output_attentions=False):
        """Apply the three sub-layers in order.

        Args:
            x: Hidden states, forwarded unchanged in layout to the sub-layers.
            padding_mask: Optional padding mask handed to both attention
                sub-layers as ``self_attn_padding_mask``.
            output_attentions: Accepted but not consulted; both attention maps
                are always returned.

        Returns:
            Tuple ``(hidden_states, row_attention, column_attention)``.
        """
        hidden, row_probs = self.row_self_attention(
            x, self_attn_padding_mask=padding_mask)
        hidden, col_probs = self.column_self_attention(
            hidden, self_attn_padding_mask=padding_mask)
        hidden = self.feed_forward_layer(hidden)
        return hidden, row_probs, col_probs
class RNAMSMPreTrainedModel(PreTrainedModel):
    """Shared base class wiring the RNA-MSM config and weight initialisation."""

    config_class = RNAMSMConfig
    base_model_prefix = "rnamsm"

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        * ``nn.Linear``: weights ~ N(0, 0.02), bias (if any) zero.
        * ``nn.Embedding``: weights ~ N(0, 0.02), padding row (if any) zero.
        * ``nn.LayerNorm``: weight one, bias zero.

        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
            pad_row = module.padding_idx
            if pad_row is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
class RNAMSMModel(RNAMSMPreTrainedModel):
    """
    RNA-MSM backbone: MSA Transformer that processes multiple-sequence-aligned RNA
    sequences and produces per-position embeddings for each alignment row.

    Input: input_ids of shape (batch, num_alignments, seqlen)
    Output: last_hidden_state of shape (batch, num_alignments, seqlen, embed_dim)

    Only the row-attention maps are exposed through ``attentions``; the
    column-attention maps computed by the layers are not returned.
    """

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim,
                                         padding_idx=config.padding_idx)
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_positions, config.embed_dim, config.padding_idx)
        if config.embed_positions_msa:
            # One learned scalar offset per alignment row, broadcast over the
            # sequence and feature dimensions.
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, config.max_alignments, 1, 1))
        else:
            self.register_parameter("msa_position_embedding", None)
        self.dropout_module = nn.Dropout(config.dropout)
        self.emb_layer_norm_before = nn.LayerNorm(config.embed_dim)
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.layers = nn.ModuleList([AxialTransformerLayer(config)
                                     for _ in range(config.num_layers)])
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module (Hugging Face accessor)."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module (Hugging Face accessor)."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a batch of MSAs.

        Args:
            input_ids: Token ids of shape ``(batch, num_alignments, seqlen)``.
            attention_mask: Optional mask with 1 = attend and 0 = padding. When
                omitted, padding is inferred from ``config.padding_idx``.
            output_hidden_states: Return the embedding output and every layer
                output, each shaped ``(batch, num_alignments, seqlen, embed_dim)``.
                The final entry is the layer-normalised last hidden state.
            output_attentions: Return the row-attention map of every layer.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            ``BaseModelOutput`` or a tuple of its non-``None`` fields.

        Raises:
            ValueError: If ``input_ids`` is not 3-dimensional.
            RuntimeError: If MSA position embeddings are enabled and
                ``num_alignments`` exceeds ``config.max_alignments``.
        """
        output_hidden_states = (output_hidden_states if output_hidden_states is not None
                                else self.config.output_hidden_states)
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if input_ids.ndim != 3:
            raise ValueError(
                "RNA-MSM expects 3D input_ids of shape (batch, num_alignments, seqlen). "
                "For single sequences, use tokenizer which produces (batch, 1, seqlen).")
        batch_size, num_alignments, seqlen = input_ids.size()

        # HF convention: attention_mask 1=attend, 0=pad -> padding_mask True=padding
        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
        else:
            padding_mask = input_ids.eq(self.config.padding_idx)
        if not padding_mask.any():
            # Nothing is padded: skip all masking work in the layers.
            padding_mask = None

        # (B, R, C) -> embed: (B, R, C, D)
        x = self.embed_tokens(input_ids)
        # Positions are computed per row, so fold rows into the batch dimension.
        x = x + self.embed_positions(
            input_ids.view(batch_size * num_alignments, seqlen)
        ).view(batch_size, num_alignments, seqlen, self.config.embed_dim)
        if self.msa_position_embedding is not None:
            if num_alignments > self.config.max_alignments:
                raise RuntimeError(
                    f"MSA depth {num_alignments} exceeds max_alignments "
                    f"{self.config.max_alignments}.")
            x = x + self.msa_position_embedding[:, :num_alignments]
        x = self.emb_layer_norm_before(x)
        x = self.dropout_module(x)
        if padding_mask is not None:
            # Zero the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).to(x))

        all_hidden_states = []
        all_row_attentions = []
        if output_hidden_states:
            all_hidden_states.append(x)

        # (B, R, C, D) -> (R, C, B, D) for axial attention
        x = x.permute(1, 2, 0, 3)
        for layer in self.layers:
            # Column attention is computed by the layer but not exposed here.
            x, row_attn, _ = layer(x, padding_mask=padding_mask,
                                   output_attentions=output_attentions)
            if output_hidden_states:
                all_hidden_states.append(x.permute(2, 0, 1, 3))
            if output_attentions:
                all_row_attentions.append(row_attn)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # (R, C, B, D) -> (B, R, C, D)
        if output_hidden_states:
            # The last recorded state is replaced by its layer-normalised version.
            all_hidden_states[-1] = x

        hidden_states = tuple(all_hidden_states) if output_hidden_states else None
        attentions = tuple(all_row_attentions) if output_attentions else None
        if not return_dict:
            return tuple(v for v in (x, hidden_states, attentions) if v is not None)
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class RNAMSMForMaskedLM(RNAMSMPreTrainedModel):
    """RNA-MSM backbone with a masked-language-modelling head.

    The head's output projection reuses the backbone's token-embedding matrix
    (declared through ``_tied_weights_keys``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: RNAMSMConfig):
        super().__init__(config)
        self.rnamsm = RNAMSMModel(config)
        # Pass the embedding Parameter itself so both modules share one tensor.
        self.lm_head = RNAMSMLMHead(config, self.rnamsm.embed_tokens.weight)
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head (Hugging Face accessor)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (Hugging Face accessor)."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Compute vocabulary logits and, when ``labels`` are given, the MLM loss.

        Args:
            input_ids: Token ids of shape ``(batch, num_alignments, seqlen)``.
            attention_mask: Optional mask forwarded to the backbone.
            labels: Optional target ids with the same number of elements as
                ``input_ids``; entries equal to ``-100`` are ignored by the loss.
            output_hidden_states: Forwarded to the backbone.
            output_attentions: Forwarded to the backbone.
            return_dict: Return a ``MaskedLMOutput`` instead of a tuple.

        Returns:
            ``MaskedLMOutput`` or the tuple ``([loss,] logits, *backbone_extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        out = self.rnamsm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        sequence_output = out.last_hidden_state if return_dict else out[0]
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # ``reshape`` also accepts non-contiguous tensors (``view`` raises on
            # them); labels are moved next to the logits before the loss.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.to(logits.device).reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + out[1:]
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )