RNA-FM / modeling_rnafm.py
Taykhoom's picture
Upload folder using huggingface_hub
6af52a6 verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
try:
from .configuration_rnafm import RnaFmConfig
except ImportError:
from configuration_rnafm import RnaFmConfig
def gelu(x):
    """Exact (erf-based) GELU activation.

    Evaluates ``x * 0.5 * (1 + erf(x / sqrt(2)))`` element-wise.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    half_x = x * 0.5
    gaussian_cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * gaussian_cdf_term
class RnaFmLearnedPositionalEmbedding(nn.Embedding):
    """Learned position embedding whose indices skip padding tokens.

    Position ids are derived from the token ids themselves: every non-padding
    token receives ``padding_idx + k`` where ``k`` is its 1-based rank among
    the non-padding tokens of its row, and every padding token receives
    ``padding_idx``. The table is therefore allocated with
    ``num_embeddings + padding_idx + 1`` rows.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        """Allocate the table.

        Args:
            num_embeddings: Number of usable positions; stored as
                ``max_positions``.
            embedding_dim: Width of each position vector.
            padding_idx: Token id treated as padding; also the offset added
                to every position id.
        """
        table_rows = num_embeddings + padding_idx + 1
        super().__init__(table_rows, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Look up position vectors for a batch of token ids.

        Args:
            input: Token ids; the cumulative count runs along dim 1.

        Returns:
            Position embeddings, one vector per token id.
        """
        not_pad = input.ne(self.padding_idx).int()
        # The running count of real tokens gives 1-based ranks; multiplying by
        # the mask resets padding slots to 0 before the offset is applied.
        rank = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        positions = (rank * not_pad).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
class RnaFmAttention(nn.Module):
    """Eager multi-head self-attention over time-major inputs.

    The module expects ``x`` laid out as ``(tgt_len, bsz, embed_dim)`` and
    computes scaled dot-product attention with explicit ``bmm`` calls so the
    attention probabilities can be returned on request.
    """

    def __init__(self, config: RnaFmConfig):
        """Build the four projection layers.

        Args:
            config: Provides ``embed_dim`` and ``attention_heads``.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by
                ``attention_heads``.
        """
        super().__init__()
        self.embed_dim = config.embed_dim
        self.num_heads = config.attention_heads
        # Fail fast: without this check a bad config only surfaces later as an
        # opaque shape error from ``view`` inside ``forward``.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"attention_heads ({self.num_heads})"
            )
        self.head_dim = config.embed_dim // config.attention_heads
        self.scaling = self.head_dim ** -0.5
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _project(self, x):
        """Project ``x`` to per-head queries, keys and values.

        Returns:
            ``(q, k, v, tgt_len, bsz)`` where each of ``q``/``k``/``v`` has
            shape ``(bsz * num_heads, tgt_len, head_dim)``. The query is
            already multiplied by ``head_dim ** -0.5``.
        """
        tgt_len, bsz, _ = x.size()
        q = self.q_proj(x) * self.scaling
        k = self.k_proj(x)
        v = self.v_proj(x)
        # Fold heads into the batch axis so a single bmm handles all heads.
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        return q, k, v, tgt_len, bsz

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Apply self-attention.

        Args:
            x: Input of shape ``(tgt_len, bsz, embed_dim)``.
            key_padding_mask: Optional boolean mask, broadcast against
                ``(bsz, num_heads, tgt_len, tgt_len)`` after two unsqueezes;
                ``True`` entries are excluded as keys.
            output_attentions: When true, also return the attention
                probabilities.

        Returns:
            ``(attn, weights)`` where ``attn`` has the shape of ``x`` and
            ``weights`` is ``(bsz, num_heads, tgt_len, tgt_len)`` in float32,
            or ``None`` when ``output_attentions`` is false.
        """
        q, k, v, tgt_len, bsz = self._project(x)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, tgt_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, tgt_len)
        # Softmax is evaluated in float32 and cast back to the working dtype.
        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_probs = attn_weights_float.type_as(attn_weights)
        attn = torch.bmm(attn_probs, v)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)
        if output_attentions:
            weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, tgt_len)
            return attn, weights
        return attn, None
class RnaFmSdpaAttention(RnaFmAttention):
    """Self-attention backed by ``F.scaled_dot_product_attention``.

    Reuses the projection layers of ``RnaFmAttention``. Because the fused call
    only returns the attended values, requests for attention weights are
    delegated to the eager parent implementation.
    """

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Apply self-attention.

        Args:
            x: Input of shape ``(tgt_len, bsz, embed_dim)``.
            key_padding_mask: Optional boolean mask; ``True`` entries are
                excluded as keys.
            output_attentions: When true, fall back to the eager path so the
                attention probabilities can be returned.

        Returns:
            ``(attn, None)`` on the fused path, or the parent's
            ``(attn, weights)`` when ``output_attentions`` is true.
        """
        if output_attentions:
            return super().forward(x, key_padding_mask, output_attentions=True)

        seq_len, batch, _ = x.size()

        def split_heads(projected):
            # (seq, batch, embed) -> (batch, heads, seq, head_dim)
            return projected.view(seq_len, batch, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

        query = split_heads(self.q_proj(x))
        key = split_heads(self.k_proj(x))
        value = split_heads(self.v_proj(x))

        if key_padding_mask is None:
            additive_mask = None
        else:
            # Additive float mask: 0 keeps a key, -inf removes it.
            additive_mask = torch.zeros(batch, 1, 1, seq_len, dtype=query.dtype, device=query.device)
            additive_mask = additive_mask.masked_fill(
                key_padding_mask[:, None, None], float("-inf")
            )

        context = F.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask)
        context = context.permute(2, 0, 1, 3).contiguous().view(seq_len, batch, self.embed_dim)
        return self.out_proj(context), None
class RnaFmFlashAttention2(RnaFmAttention):
    """Self-attention backed by the ``flash_attn`` package.

    Reuses the projection layers of ``RnaFmAttention``. Requests for attention
    weights are delegated to the eager parent implementation. Inputs that are
    not fp16/bf16 are cast to bf16 for the kernel call and the result is cast
    back to the original dtype.
    """

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Apply self-attention.

        Args:
            x: Input of shape ``(tgt_len, bsz, embed_dim)``.
            key_padding_mask: Optional boolean mask; ``True`` entries are
                padding and are removed before the variable-length kernel.
            output_attentions: When true, fall back to the eager path.

        Returns:
            ``(attn, None)`` on the flash path, or the parent's
            ``(attn, weights)`` when ``output_attentions`` is true.

        Raises:
            ImportError: If ``flash_attn`` is not installed.
        """
        if output_attentions:
            return super().forward(x, key_padding_mask, output_attentions=True)
        try:
            # All flash_attn symbols are imported under one guard so that a
            # missing package always produces the same actionable message.
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import pad_input, unpad_input
        except ImportError as e:
            raise ImportError(
                "flash_attn is required for attn_implementation='flash_attention_2'. "
                "Install with: pip install flash-attn --no-build-isolation"
            ) from e
        tgt_len, bsz, _ = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (seq, batch, embed) -> (batch, seq, heads, head_dim)
        q = q.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        k = k.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        v = v.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        orig_dtype = q.dtype
        if q.dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
        softmax_scale = self.head_dim ** -0.5
        if key_padding_mask is not None and key_padding_mask.any():
            attention_mask_bool = ~key_padding_mask
            # ``unpad_input`` returns four values in some flash-attn releases
            # and five in others; the starred target accepts both.
            q_unpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(q, attention_mask_bool)
            k_unpad, *_ = unpad_input(k, attention_mask_bool)
            v_unpad, *_ = unpad_input(v, attention_mask_bool)
            out_unpad = flash_attn_varlen_func(
                q_unpad, k_unpad, v_unpad,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                softmax_scale=softmax_scale,
                causal=False,
            )
            out = pad_input(out_unpad, indices, bsz, tgt_len)
        else:
            out = flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=False)
        out = out.to(orig_dtype).permute(1, 0, 2, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
        return self.out_proj(out), None
# Registry used by RnaFmLayer to pick the attention module from the config's
# ``_attn_implementation`` string.
RNAFM_ATTENTION_CLASSES = {
    "eager": RnaFmAttention,
    "sdpa": RnaFmSdpaAttention,
    "flash_attention_2": RnaFmFlashAttention2,
}
class RnaFmLayer(nn.Module):
    """One pre-LayerNorm transformer encoder block.

    Each sub-layer (self-attention, then a two-layer GELU feed-forward
    network) normalises its input first and adds the result back onto the
    un-normalised residual stream.
    """

    def __init__(self, config: RnaFmConfig):
        """Build the block.

        Args:
            config: Provides ``embed_dim`` and ``ffn_embed_dim``; the attention
                class is chosen from ``config._attn_implementation`` and
                defaults to ``"eager"`` when the attribute is absent.
        """
        super().__init__()
        impl_name = getattr(config, "_attn_implementation", "eager")
        self.self_attn = RNAFM_ATTENTION_CLASSES[impl_name](config)
        self.self_attn_layer_norm = nn.LayerNorm(config.embed_dim)
        self.fc1 = nn.Linear(config.embed_dim, config.ffn_embed_dim)
        self.fc2 = nn.Linear(config.ffn_embed_dim, config.embed_dim)
        self.final_layer_norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x, key_padding_mask=None, output_attentions=False):
        """Run the block.

        Args:
            x: Hidden states, passed unchanged in layout to the attention
                module.
            key_padding_mask: Forwarded to the attention module.
            output_attentions: Forwarded to the attention module.

        Returns:
            ``(hidden_states, attn)`` where ``attn`` is whatever the attention
            module returned as its second element.
        """
        # Attention sub-layer.
        attn_out, attn = self.self_attn(
            self.self_attn_layer_norm(x),
            key_padding_mask=key_padding_mask,
            output_attentions=output_attentions,
        )
        after_attn = x + attn_out

        # Feed-forward sub-layer.
        ffn_hidden = gelu(self.fc1(self.final_layer_norm(after_attn)))
        ffn_out = self.fc2(ffn_hidden)
        return after_attn + ffn_out, attn
class RnaFmPreTrainedModel(PreTrainedModel):
    """Common base for the RNA-FM models: config binding and weight init."""

    config_class = RnaFmConfig
    base_model_prefix = "rnafm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialise one submodule in place.

        ``nn.Linear`` gets Xavier-uniform weights and a zero bias;
        ``nn.Embedding`` gets N(0, 0.02) weights with the padding row zeroed.
        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if not isinstance(module, nn.Embedding):
            return

        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
class RnaFmModel(RnaFmPreTrainedModel):
    """RNA-FM encoder: token + position embeddings followed by a layer stack."""

    def __init__(self, config: RnaFmConfig):
        """Build embeddings, the encoder layers and the final LayerNorm."""
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.padding_idx)
        self.embed_positions = RnaFmLearnedPositionalEmbedding(config.model_max_length, config.embed_dim, config.padding_idx)
        self.emb_layer_norm_before = nn.LayerNorm(config.embed_dim) if config.emb_layer_norm_before else None
        self.layers = nn.ModuleList([RnaFmLayer(config) for _ in range(config.num_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module.

        Exposed on the base model as well as on the LM wrapper, so utilities
        that look up the input embeddings work on the bare encoder too.
        """
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids,
        attention_mask=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
    ):
        """Encode a batch of token ids.

        Args:
            input_ids: Token ids, batch-first.
            attention_mask: Optional mask where 0 marks padding. When omitted,
                padding is inferred from ``input_ids == config.padding_idx``.
            output_hidden_states: Collect the embedding output and every
                layer output; defaults to ``config.output_hidden_states``.
            output_attentions: Collect per-layer attention probabilities;
                defaults to ``config.output_attentions``.
            return_dict: Accepted for signature compatibility; a
                ``BaseModelOutput`` is always returned.

        Returns:
            ``BaseModelOutput`` with batch-first ``last_hidden_state`` and the
            optional ``hidden_states`` / ``attentions`` tuples.
        """
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
        else:
            padding_mask = input_ids.eq(self.config.padding_idx)

        x = self.embed_tokens(input_ids)

        if self.config.token_dropout:
            # Zero the mask-token embeddings, then rescale each sequence by the
            # ratio of the fixed 0.15 * 0.8 mask rate to the mask rate actually
            # present in that sequence.
            x.masked_fill_((input_ids == self.config.mask_idx).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (input_ids == self.config.mask_idx).sum(-1).to(x.dtype) / src_lengths.to(x.dtype)
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(input_ids)

        if self.emb_layer_norm_before is not None:
            x = self.emb_layer_norm_before(x)

        if padding_mask.any():
            # Zero out the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).to(x.dtype))
        else:
            # No padding anywhere: let the attention modules skip masking.
            padding_mask = None

        all_hidden_states = []
        all_attentions = []

        if output_hidden_states:
            all_hidden_states.append(x)

        # The layers operate on time-major tensors.
        x = x.transpose(0, 1)

        for layer in self.layers:
            x, attn = layer(x, key_padding_mask=padding_mask, output_attentions=output_attentions)
            if output_hidden_states:
                all_hidden_states.append(x.transpose(0, 1))
            if output_attentions and attn is not None:
                all_attentions.append(attn)

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)

        if output_hidden_states:
            # The last entry is replaced by its layer-normalised version so it
            # matches ``last_hidden_state``.
            all_hidden_states[-1] = x

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
            attentions=tuple(all_attentions) if output_attentions else None,
        )
class RnaFmLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> LayerNorm -> vocabulary projection."""

    def __init__(self, config: RnaFmConfig):
        """Build the head.

        Args:
            config: Provides ``embed_dim`` and ``vocab_size``.
        """
        super().__init__()
        hidden = config.embed_dim
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden)
        self.decoder = nn.Linear(hidden, config.vocab_size, bias=True)

    def forward(self, features):
        """Map hidden states to vocabulary logits.

        Args:
            features: Hidden states whose last dimension is ``embed_dim``.

        Returns:
            Logits whose last dimension is ``vocab_size``.
        """
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed)
class RnaFmForMaskedLM(RnaFmPreTrainedModel):
    """RNA-FM encoder with a masked-language-modelling head."""

    # Key handed to transformers' weight-tying machinery for the decoder weight.
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config: RnaFmConfig):
        """Build the encoder and the LM head."""
        super().__init__(config)
        self.rnafm = RnaFmModel(config)
        self.lm_head = RnaFmLMHead(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.rnafm.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the encoder's token embedding module."""
        self.rnafm.embed_tokens = value

    def get_output_embeddings(self):
        """Return the vocabulary projection layer of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer of the LM head."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
    ):
        """Run the encoder and LM head, optionally computing the MLM loss.

        Args:
            input_ids: Token ids, batch-first.
            attention_mask: Optional mask forwarded to the encoder.
            labels: Optional target ids with the same leading shape as the
                logits; positions equal to ``-100`` are ignored by the loss.
            output_hidden_states: Forwarded to the encoder.
            output_attentions: Forwarded to the encoder.
            return_dict: Forwarded to the encoder; a ``MaskedLMOutput`` is
                always returned.

        Returns:
            ``MaskedLMOutput`` with ``logits`` and, when ``labels`` is given,
            the mean cross-entropy ``loss``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        out = self.rnafm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        logits = self.lm_head(out.last_hidden_state)

        loss = None
        if labels is not None:
            # ``reshape`` (unlike ``view``) accepts non-contiguous labels, and
            # labels are moved to the logits' device so callers that keep them
            # elsewhere do not hit a device-mismatch error.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.to(logits.device).reshape(-1),
                ignore_index=-100,
            )

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )