# ERNIE-RNA-SS / modeling_ernierna.py
# Uploaded to the Hugging Face Hub by Taykhoom using huggingface_hub (commit 9910b2d).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
try:
from .configuration_ernierna import ErnieRNAConfig
except ImportError:
from configuration_ernierna import ErnieRNAConfig
class ErnieRNASinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional embeddings.

    Position indices are derived from ``input_ids``: the k-th non-padding token of a
    row (1-based count) gets index ``padding_idx + k``, and padding tokens get index
    ``padding_idx``, whose table row is all zeros.

    Args:
        num_positions: Maximum number of non-padding positions the table covers.
        embed_dim: Embedding width.
        padding_idx: Token id treated as padding; its table row is zeroed.
    """

    def __init__(self, num_positions, embed_dim, padding_idx):
        super().__init__()
        self.embedding_dim = embed_dim
        self.padding_idx = padding_idx
        # forward() produces indices up to padding_idx + num_positions, so the table
        # needs padding_idx + 1 + num_positions rows.
        table_size = padding_idx + 1 + num_positions
        self.register_buffer("weights", self._get_embedding(table_size, embed_dim, padding_idx))

    @staticmethod
    def _get_embedding(num_embeddings, embedding_dim, padding_idx):
        """Build a ``[num_embeddings, embedding_dim]`` float32 sinusoidal table.

        The first ``embedding_dim // 2`` columns hold ``sin(pos * f_k)`` and the next
        ones ``cos(pos * f_k)``, with frequencies ``f_k`` spaced geometrically from 1
        down to 1/10000. An odd ``embedding_dim`` gets one trailing zero column, and
        the ``padding_idx`` row (if not None) is zeroed.
        """
        half_dim = embedding_dim // 2
        # With a single frequency (half_dim == 1) the spacing is irrelevant because the
        # only frequency is exp(0) == 1. Guarding the denominator avoids a
        # ZeroDivisionError for embedding_dim in {2, 3} and leaves larger widths unchanged.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        # Concatenating along dim=1 already yields [num_embeddings, 2 * half_dim]; no
        # reshape is needed (and view(-1) would fail for a zero-width table).
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input_ids):
        """Look up positional embeddings for ``input_ids`` of shape ``[B, T]``.

        Returns a detached ``[B, T, embedding_dim]`` tensor. Padding positions map to
        the zero row.
        """
        mask = input_ids.ne(self.padding_idx).int()
        # Running count of non-padding tokens, zeroed at padding, then offset past padding_idx.
        positions = (torch.cumsum(mask, dim=1) * mask).long() + self.padding_idx
        return self.weights.index_select(0, positions.view(-1)).view(
            input_ids.shape[0], input_ids.shape[1], -1
        ).detach()
class ErnieRNATwodProj(nn.Module):
    """Small MLP that maps a scalar pairing score to one additive bias per attention head.

    ``[..., 1]`` -> Linear(1, 6) -> activation -> Linear(6, attention_heads).
    Submodule names are part of the checkpoint keys and must stay stable.
    """

    def __init__(self, config):
        super().__init__()
        n_heads = config.attention_heads
        self.linear1 = nn.Linear(1, 6)
        self.linear2 = nn.Linear(6, n_heads)
        self.activation_fn = ACT2FN[config.activation_fn]

    def forward(self, x):
        hidden = self.activation_fn(self.linear1(x))
        return self.linear2(hidden)
def _compute_pairing_bias(input_ids):
B, T = input_ids.shape
xi = input_ids.unsqueeze(2).expand(B, T, T)
xj = input_ids.unsqueeze(1).expand(B, T, T)
score = torch.zeros(B, T, T, dtype=torch.float32, device=input_ids.device)
score[(xi == 5) & (xj == 6)] = 2.0
score[(xi == 6) & (xj == 5)] = 2.0
score[(xi == 4) & (xj == 7)] = 3.0
score[(xi == 7) & (xj == 4)] = 3.0
score[(xi == 4) & (xj == 6)] = 0.8
score[(xi == 6) & (xj == 4)] = 0.8
return score.unsqueeze(-1) # [B, T, T, 1]
class ErnieRNAAttention(nn.Module):
    """Multi-head self-attention with an additive 2D bias that is chained across layers.

    Operates on time-major input ``x`` of shape ``[T, B, C]`` with ``C == embed_dim``.
    The layer's pre-softmax attention logits are the scaled ``q @ k^T`` plus the
    incoming bias, with padded keys set to ``-inf``. They are returned so the caller
    can feed them to the next layer as its ``twod_bias``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.num_heads = config.attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"attention_heads ({self.num_heads})"
            )
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.dropout = nn.Dropout(config.attention_dropout)

    def _to_bh_t_hd(self, tensor, tgt_len, bsz):
        """Reshape ``[T, B, C]`` to ``[B * H, T, head_dim]`` (flattened index is ``b * H + h``)."""
        return tensor.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(self, x, key_padding_mask=None, twod_bias=None, output_attentions=False):
        """Apply biased self-attention.

        Args:
            x: ``[T, B, C]`` hidden states.
            key_padding_mask: optional bool ``[B, T]``; ``True`` marks keys to ignore.
            twod_bias: optional additive bias reshapeable to ``[B * H, T, T]``
                (e.g. ``[B, H, T, T]``), added to the scaled logits before softmax.
            output_attentions: if true, also return the pre-softmax logits.

        Returns:
            Tuple ``(out, attn_weights_out, twod_bias_new)`` with the following parts:

            - ``out``: ``[T, B, C]`` attention output.
            - ``twod_bias_new``: ``[B, H, T, T]`` pre-softmax logits, with ``-inf`` at
              masked keys.
            - ``attn_weights_out``: that same tensor when ``output_attentions`` is true,
              else ``None``. These are logits, not softmax probabilities.
        """
        tgt_len, bsz, _ = x.size()
        q = self._to_bh_t_hd(self.q_proj(x), tgt_len, bsz)
        k = self._to_bh_t_hd(self.k_proj(x), tgt_len, bsz)
        v = self._to_bh_t_hd(self.v_proj(x), tgt_len, bsz)
        q = q * (self.head_dim ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(-2, -1))  # [B*H, T, T]
        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, tgt_len)
            # [B, T] -> [B, 1, 1, T]: mask the same key columns for every head and query.
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, tgt_len)
        if twod_bias is not None:
            attn_weights = attn_weights + twod_bias.reshape(bsz * self.num_heads, tgt_len, tgt_len)
        # The pre-softmax logits (including the incoming bias) become the next layer's 2D bias.
        twod_bias_new = attn_weights.view(bsz, self.num_heads, tgt_len, tgt_len)
        attn_probs = F.softmax(attn_weights, dim=-1)
        attn_probs = self.dropout(attn_probs)
        out = torch.bmm(attn_probs, v)
        out = out.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        out = self.out_proj(out)
        attn_weights_out = twod_bias_new if output_attentions else None
        return out, attn_weights_out, twod_bias_new
class ErnieRNALayer(nn.Module):
    """Post-LayerNorm transformer encoder layer built around ``ErnieRNAAttention``.

    Both sub-blocks (attention, then feed-forward) compute
    ``LayerNorm(residual + dropout(f(x)))``. Hidden states are time-major
    ``[T, B, C]``. The attention's returned weights and chained 2D bias are passed
    through unchanged.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        self.self_attn = ErnieRNAAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.ffn_embed_dim)
        self.fc2 = nn.Linear(config.ffn_embed_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.dropout)
        self.activation_dropout = nn.Dropout(config.activation_dropout)
        self.activation_fn = ACT2FN[config.activation_fn]

    def forward(self, x, key_padding_mask=None, twod_bias=None, output_attentions=False):
        attended, attn_weights, twod_bias_new = self.self_attn(
            x,
            key_padding_mask=key_padding_mask,
            twod_bias=twod_bias,
            output_attentions=output_attentions,
        )
        # Attention sub-block: residual connection, then LayerNorm.
        hidden = self.self_attn_layer_norm(x + self.dropout(attended))

        # Feed-forward sub-block: fc1 -> activation -> dropout -> fc2 -> dropout, then residual + LayerNorm.
        ffn = self.activation_dropout(self.activation_fn(self.fc1(hidden)))
        ffn = self.dropout(self.fc2(ffn))
        return self.final_layer_norm(hidden + ffn), attn_weights, twod_bias_new
class ErnieRNAModel(PreTrainedModel):
    """ERNIE-RNA encoder with a chained 2D attention bias.

    Embeddings are token + sinusoidal position + optional segment embeddings. They are
    layer-normed, padded positions are zeroed, and the result is fed through
    ``config.num_layers`` ``ErnieRNALayer`` blocks.

    The first layer's attention bias comes from token-pair scores
    (``_compute_pairing_bias``) projected per head by ``twod_proj``. Each later layer
    receives the previous layer's pre-softmax attention logits as its bias.
    """

    config_class = ErnieRNAConfig
    base_model_prefix = "model"
    _supports_sdpa = False
    _supports_flash_attn_2 = False

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.padding_idx
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.padding_idx)
        self.embed_positions = ErnieRNASinusoidalPositionalEmbedding(
            config.max_positions, config.embed_dim, config.padding_idx
        )
        self.segment_embeddings = nn.Embedding(config.num_segments, config.embed_dim)
        self.emb_layer_norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([ErnieRNALayer(config) for _ in range(config.num_layers)])
        self.twod_proj = ErnieRNATwodProj(config)
        self.post_init()

    def _initial_twod_bias(self, input_ids, dtype):
        """Compute the first layer's 2D attention bias, shape ``[B, H, T, T]``, cast to ``dtype``.

        The pairing MLP runs on float32 copies of ``twod_proj``'s parameters. The
        previous implementation called ``self.twod_proj.float()``, which converts the
        submodule in place and silently changed a bf16/fp16 model's parameters on every
        forward. This helper leaves the submodule's dtype untouched.
        """
        proj = self.twod_proj
        pairing = _compute_pairing_bias(input_ids)  # [B, T, T, 1], float32
        hidden = F.linear(pairing, proj.linear1.weight.float(), proj.linear1.bias.float())
        hidden = proj.activation_fn(hidden)
        bias = F.linear(hidden, proj.linear2.weight.float(), proj.linear2.bias.float())  # [B, T, T, H]
        return bias.permute(0, 3, 1, 2).contiguous().to(dtype)  # [B, H, T, T]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a batch of token ids.

        Args:
            input_ids: ``[B, T]`` token ids (required).
            attention_mask: optional ``[B, T]``; 1 = attend, 0 = padding. When omitted,
                padding is inferred from ``input_ids == config.padding_idx``.
            token_type_ids: optional segment ids added via ``segment_embeddings``.
            output_attentions: also return each layer's pre-softmax attention logits
                ``[B, H, T, T]`` (masked keys are ``-inf``).
            output_hidden_states: also return the embedding output and every layer output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Raises:
            ValueError: if ``input_ids`` is None (embeddings-only input is not supported).
        """
        if input_ids is None:
            raise ValueError("ErnieRNAModel.forward requires input_ids")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # HF convention 1=attend, 0=pad  ->  internal convention True=padding.
        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
        else:
            padding_mask = input_ids.eq(self.padding_idx)
        # Evaluate once: .any() forces a device-to-host sync, and the result is used twice.
        has_padding = bool(padding_mask.any())

        x = self.embed_tokens(input_ids)
        # Sinusoidal PE buffer may differ from the activation dtype (e.g. bfloat16); cast to match.
        x = x + self.embed_positions(input_ids).to(x.dtype)
        if token_type_ids is not None:
            x = x + self.segment_embeddings(token_type_ids)
        x = self.emb_layer_norm(x)
        if has_padding:
            # Zero out padding positions after the embedding LayerNorm (matches fairseq behavior).
            x = x * (~padding_mask).unsqueeze(-1).to(x.dtype)
        x = self.dropout(x)

        twod_bias = self._initial_twod_bias(input_ids, x.dtype)

        x = x.transpose(0, 1)  # [B, T, C] -> [T, B, C] for the attention layers
        all_hidden_states = []
        all_attentions = []
        if output_hidden_states:
            all_hidden_states.append(x.transpose(0, 1))
        key_padding_mask = padding_mask if has_padding else None
        for layer in self.layers:
            x, attn_weights, twod_bias = layer(
                x,
                key_padding_mask=key_padding_mask,
                twod_bias=twod_bias,
                output_attentions=output_attentions,
            )
            if output_hidden_states:
                all_hidden_states.append(x.transpose(0, 1))
            if output_attentions:
                all_attentions.append(attn_weights)
        x = x.transpose(0, 1)  # [B, T, C]

        if not return_dict:
            return tuple(v for v in [x, tuple(all_hidden_states) or None, tuple(all_attentions) or None] if v is not None)
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
            attentions=tuple(all_attentions) if output_attentions else None,
        )
class ErnieRNALMHead(nn.Module):
    """Masked-LM prediction head: dense -> activation -> LayerNorm -> vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width)
        self.activation_fn = ACT2FN[config.activation_fn]
        self.decoder = nn.Linear(width, config.vocab_size)

    def forward(self, x):
        hidden = self.dense(x)
        hidden = self.activation_fn(hidden)
        hidden = self.layer_norm(hidden)
        return self.decoder(hidden)
class ErnieRNAForMaskedLM(PreTrainedModel):
    """``ErnieRNAModel`` encoder topped with ``ErnieRNALMHead`` for masked-token prediction."""

    config_class = ErnieRNAConfig
    base_model_prefix = "model"
    _supports_sdpa = False
    _supports_flash_attn_2 = False

    def __init__(self, config):
        super().__init__(config)
        self.model = ErnieRNAModel(config)
        self.lm_head = ErnieRNALMHead(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute MLM logits and, when ``labels`` is given, the mean cross-entropy loss.

        ``labels`` must hold one target id per input position (same number of elements
        as ``input_ids``). Positions set to -100 are ignored. Other arguments are
        forwarded to ``ErnieRNAModel``.

        Returns ``MaskedLMOutput``, or the tuple ``([loss,] logits, *encoder_extras)``
        when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        out = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(out[0])
        loss = None
        if labels is not None:
            # Keep labels next to the logits (models can be split across devices), and use
            # reshape rather than view so non-contiguous labels (e.g. slices) are accepted.
            labels = labels.to(logits.device)
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,) + out[1:]
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )