# UTR-LM-MLM / modeling_utrlm.py
# Uploaded by Taykhoom via huggingface_hub (commit 0a535de, verified).
"""UTR-LM ported to Hugging Face PreTrainedModel."""
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
from .configuration_utrlm import UtrLmConfig
# ---------------------------------------------------------------------------
# Rotary embeddings
# ---------------------------------------------------------------------------
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(x, cos, sin):
    """Apply a rotary position embedding to ``x``.

    ``cos`` and ``sin`` are rank-3 tables indexed as ``[:, position, :]``; only
    the first ``x.shape[-2]`` positions are used, and both are cast to ``x``'s
    dtype before being combined with it.
    """
    seq_len = x.shape[-2]
    cos_part = cos[:, :seq_len, :].to(x.dtype)
    sin_part = sin[:, :seq_len, :].to(x.dtype)
    rotated = _rotate_half(x)
    return (x * cos_part) + (rotated * sin_part)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    The sequence axis is dim -2 of the key tensor passed to ``forward``. The
    cos/sin tables are cached on the module and rebuilt only when the sequence
    length or the key's device differs from the cached one.
    """

    def __init__(self, dim: int):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        self._seq_len_cached: Optional[int] = None
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 1):
        """Return cached ``(cos, sin)`` tables of shape ``(1, T, 2 * len(inv_freq))``.

        ``T`` is ``x.shape[seq_dimension]``. The tables are recomputed when ``T``
        or ``x.device`` changed since the last call.
        """
        seq_len = x.shape[seq_dimension]
        # First call: _seq_len_cached is None, so the length test is True and the
        # device check on the (still None) cache is short-circuited.
        stale = seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        if stale:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos()[None, :, :]
            self._sin_cached = table.sin()[None, :, :]
        return self._cos_cached, self._sin_cached

    def forward(self, q, k):
        """Return ``(q, k)`` with rotary position information applied."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos, sin
        rotated_q = _apply_rotary_pos_emb(q, cos, sin)
        rotated_k = _apply_rotary_pos_emb(k, cos, sin)
        return rotated_q, rotated_k
# ---------------------------------------------------------------------------
# Attention variants
# ---------------------------------------------------------------------------
class UtrLmAttention(nn.Module):
    """Eager (standard) multi-head self-attention with rotary embeddings.

    Inputs are time-major ``(T, B, E)``. The projection step (``_project``) is
    shared with the SDPA and Flash Attention subclasses.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Queries are multiplied by 1/sqrt(head_dim) at projection time, so the
        # kernels downstream use an extra scale of 1.
        self.scaling = self.head_dim ** -0.5
        # Creation order (k, v, q, out) is kept so seeded initialisation and
        # state-dict key order are unchanged.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.rot_emb = RotaryEmbedding(dim=self.head_dim)

    def _split_heads(self, t, seq_len, batch):
        """Reshape a projected ``(T, B, E)`` tensor to ``(B*H, T, head_dim)``."""
        per_head = t.contiguous().view(seq_len, batch * self.num_heads, self.head_dim)
        return per_head.transpose(0, 1)

    def _project(self, x):
        """Project and reshape x (T, B, E) -> q/k/v in (B*H, T, head_dim)."""
        seq_len, batch, _ = x.size()
        q = self._split_heads(self.q_proj(x) * self.scaling, seq_len, batch)
        k = self._split_heads(self.k_proj(x), seq_len, batch)
        v = self._split_heads(self.v_proj(x), seq_len, batch)
        q, k = self.rot_emb(q, k)
        return q, k, v

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        """Attend over ``x``; ``key_padding_mask`` is ``(B, T)`` with True = ignore.

        Returns ``(out, weights)`` where ``out`` is ``(T, B, E)`` and ``weights``
        is ``(B, H, T, T)`` when ``output_attentions`` is set, else None.
        """
        seq_len, batch, _ = x.size()
        heads = self.num_heads
        q, k, v = self._project(x)

        scores = torch.bmm(q, k.transpose(1, 2)).view(batch, heads, seq_len, seq_len)
        if key_padding_mask is not None:
            # (B, T) -> (B, 1, 1, T): every head and query ignores padded keys.
            ignore = key_padding_mask[:, None, None, :].to(torch.bool)
            scores = scores.masked_fill(ignore, float("-inf"))
        scores = scores.view(batch * heads, seq_len, seq_len)

        # Softmax is computed in fp32 for stability, then cast back.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
        context = torch.bmm(probs, v)
        context = context.transpose(0, 1).contiguous().view(seq_len, batch, self.embed_dim)
        out = self.out_proj(context)

        weights = probs.view(batch, heads, seq_len, seq_len) if output_attentions else None
        return out, weights
class UtrLmSdpaAttention(UtrLmAttention):
    """Attention computed with ``torch.nn.functional.scaled_dot_product_attention``.

    SDPA does not materialise the probability matrix, so requests for attention
    weights are delegated to the eager parent implementation.
    """

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        if output_attentions:
            return super().forward(x, key_padding_mask, output_attentions=True)

        seq_len, batch, _ = x.size()
        heads, dim = self.num_heads, self.head_dim
        # _project yields (B*H, T, head_dim); SDPA wants (B, H, T, head_dim).
        query, key, value = (
            t.view(batch, heads, seq_len, dim) for t in self._project(x)
        )

        additive_mask = None
        if key_padding_mask is not None:
            # Additive float mask (B, 1, 1, T): 0 for real keys, -inf for padding.
            additive_mask = torch.zeros(
                batch, 1, 1, seq_len, dtype=query.dtype, device=query.device
            ).masked_fill(key_padding_mask[:, None, None, :], float("-inf"))

        # scale=1.0 because the query was already multiplied by self.scaling.
        context = F.scaled_dot_product_attention(
            query, key, value, attn_mask=additive_mask, scale=1.0
        )
        # (B, H, T, D) -> (T, B, H, D) -> (T, B, E)
        context = context.permute(2, 0, 1, 3).contiguous().view(seq_len, batch, self.embed_dim)
        return self.out_proj(context), None
class UtrLmFlashAttention2(UtrLmAttention):
    """Flash Attention 2 via flash_attn (must be installed separately).

    Projections and weights are inherited from UtrLmAttention; only the
    attention kernel differs. Inputs that are not fp16/bf16 are computed in
    bf16 and cast back to their original dtype afterwards.
    """

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        """Attend over ``x`` (T, B, E); ``key_padding_mask`` is (B, T), True = pad.

        Returns ``(out, None)`` with ``out`` of shape (T, B, E), or the eager
        ``(out, weights)`` pair when ``output_attentions`` is requested.

        Raises:
            ImportError: if ``flash_attn`` is not installed.
        """
        if output_attentions:
            # Flash attention doesn't expose attention weights; fall back to eager.
            return super().forward(x, key_padding_mask, output_attentions=True)
        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import pad_input, unpad_input
        except ImportError as e:
            raise ImportError("flash_attn is required for attn_implementation='flash_attention_2'. "
                              "Install with: pip install flash-attn --no-build-isolation") from e

        tgt_len, bsz, _ = x.size()
        q, k, v = self._project(x)  # (B*H, T, head_dim)
        # Reshape to (B, T, H, head_dim) - flash_attn's expected layout.
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)

        # Flash attention requires fp16 or bf16.
        orig_dtype = q.dtype
        if orig_dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)

        if key_padding_mask is not None:
            # Unpad, run varlen flash attention, repad.
            attention_mask = ~key_padding_mask  # True = valid token
            # unpad_input returns 4 values in flash-attn < 2.7 and 5 from 2.7 on;
            # slicing keeps this working with both.
            q_unpad, indices, cu_seqlens, max_seqlen = unpad_input(q, attention_mask)[:4]
            # k and v share q's padding pattern, so gather them with the same
            # flat (B*T) indices instead of recomputing them twice more.
            k_unpad = k.reshape(bsz * tgt_len, self.num_heads, self.head_dim)[indices]
            v_unpad = v.reshape(bsz * tgt_len, self.num_heads, self.head_dim)[indices]
            out_unpad = flash_attn_varlen_func(
                q_unpad, k_unpad, v_unpad,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                softmax_scale=1.0,  # q already pre-scaled
                causal=False,
            )
            out = pad_input(out_unpad, indices, bsz, tgt_len)
        else:
            out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=False)

        # (B, T, H, D) -> (T, B, H, D) -> (T, B, E)
        out = out.to(orig_dtype).permute(1, 0, 2, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
        return self.out_proj(out), None
# Maps a config's ``_attn_implementation`` string to the attention class that
# UtrLmLayer instantiates.
UTRLM_ATTENTION_CLASSES = dict(
    eager=UtrLmAttention,
    sdpa=UtrLmSdpaAttention,
    flash_attention_2=UtrLmFlashAttention2,
)
# ---------------------------------------------------------------------------
# Transformer layer (pre-LN)
# ---------------------------------------------------------------------------
def _gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class UtrLmLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Self-attention, then a two-layer GELU MLP with a 4x hidden width. Each
    sub-block is applied to a LayerNorm'd input and added back residually.
    Operates on time-major ``(T, B, E)`` tensors.
    """

    def __init__(self, embed_dim: int, attention_heads: int, config: UtrLmConfig):
        super().__init__()
        impl_name = getattr(config, "_attn_implementation", "eager")
        attention_cls = UTRLM_ATTENTION_CLASSES[impl_name]
        # Submodule creation order is preserved (affects seeded initialisation).
        self.self_attn = attention_cls(embed_dim, attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        ffn_dim = 4 * embed_dim
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, padding_mask, output_attentions: bool = False):
        """Return ``(hidden, attn_weights)``; weights are None unless requested."""
        normed = self.self_attn_layer_norm(x)
        attn_out, attn_weights = self.self_attn(
            normed, key_padding_mask=padding_mask, output_attentions=output_attentions
        )
        x = x + attn_out

        ffn_out = self.fc2(_gelu(self.fc1(self.final_layer_norm(x))))
        return x + ffn_out, attn_weights
# ---------------------------------------------------------------------------
# Backbone
# ---------------------------------------------------------------------------
class UtrLmModel(PreTrainedModel):
    """
    UTR-LM encoder backbone. Returns last_hidden_state (B, T, E).

    Per the original port, the [CLS] token is expected at position 0
    (prepend_bos=True by default); this class does not enforce that.
    """

    config_class = UtrLmConfig
    base_model_prefix = "utrlm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def __init__(self, config: UtrLmConfig):
        super().__init__(config)
        # A scale of 1 leaves token embeddings unchanged.
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            config.alphabet_size, config.embed_dim, padding_idx=config.padding_idx
        )
        self.layers = nn.ModuleList(
            UtrLmLayer(config.embed_dim, config.attention_heads, config)
            for _ in range(config.num_layers)
        )
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _embed(self, input_ids, padding_mask):
        """Embed tokens, apply token-dropout rescaling if enabled, zero padding.

        ``padding_mask`` is True at positions to ignore. Returns (B, T, E).
        """
        cfg = self.config
        emb = self.embed_scale * self.embed_tokens(input_ids)
        if cfg.token_dropout:
            is_masked = input_ids == cfg.mask_idx
            emb.masked_fill_(is_masked.unsqueeze(-1), 0.0)
            # 15% of tokens selected for MLM, 80% of those replaced by <mask>.
            train_ratio = 0.15 * 0.8
            n_real = (~padding_mask).sum(-1)
            observed_ratio = is_masked.sum(-1).to(emb.dtype) / n_real.to(emb.dtype)
            emb = emb * (1 - train_ratio) / (1 - observed_ratio)[:, None, None]
        return emb * (1 - padding_mask.unsqueeze(-1).type_as(emb))

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode ``input_ids`` (B, T).

        ``attention_mask`` follows the HF convention (1 = attend, 0 = pad); when
        omitted, padding is inferred from ``config.padding_idx``. When hidden
        states are requested, the first entry is the embedding output and the
        last entry is the final, layer-normed output.
        """
        cfg = self.config
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Boolean padding mask, True = position to ignore.
        if attention_mask is None:
            padding_mask = input_ids.eq(cfg.padding_idx)
        else:
            padding_mask = attention_mask.eq(0)

        hidden = self._embed(input_ids, padding_mask)

        hidden_states = [hidden] if output_hidden_states else None
        attentions = [] if output_attentions else None

        hidden = hidden.transpose(0, 1)  # (B, T, E) -> (T, B, E)
        # Pass no mask at all when nothing in the batch is padded.
        layer_mask = padding_mask if padding_mask.any() else None
        for layer in self.layers:
            hidden, weights = layer(hidden, padding_mask=layer_mask, output_attentions=output_attentions)
            if hidden_states is not None:
                hidden_states.append(hidden.transpose(0, 1))
            if attentions is not None:
                attentions.append(weights)

        hidden = self.emb_layer_norm_after(hidden).transpose(0, 1)  # (T, B, E) -> (B, T, E)

        if hidden_states is not None:
            # Replace the last layer's raw output with the layer-normed one.
            hidden_states[-1] = hidden
            hidden_states = tuple(hidden_states)
        if attentions is not None:
            attentions = tuple(attentions)

        if not return_dict:
            return tuple(item for item in (hidden, hidden_states, attentions) if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden,
            hidden_states=hidden_states,
            attentions=attentions,
        )
# ---------------------------------------------------------------------------
# MLM head
# ---------------------------------------------------------------------------
class UtrLmForMaskedLM(PreTrainedModel):
    """
    UTR-LM with a masked-language-modelling head.

    The output projection reuses ``utrlm.embed_tokens.weight`` plus a separate
    bias (``lm_head_bias``). Returns MaskedLMOutput with logits
    (B, T, config.alphabet_size).
    """

    config_class = UtrLmConfig
    base_model_prefix = "utrlm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def __init__(self, config: UtrLmConfig):
        super().__init__(config)
        self.utrlm = UtrLmModel(config)
        embed_dim = config.embed_dim
        vocab_size = config.alphabet_size
        self.lm_head = nn.ModuleDict({
            "dense": nn.Linear(embed_dim, embed_dim),
            "layer_norm": nn.LayerNorm(embed_dim),
        })
        self.lm_head_bias = nn.Parameter(torch.zeros(vocab_size))
        self.post_init()

    def get_input_embeddings(self):
        return self.utrlm.embed_tokens

    def set_input_embeddings(self, value):
        self.utrlm.embed_tokens = value

    def get_output_embeddings(self):
        # The LM head projects with embed_tokens.weight (see _lm_head_forward),
        # so the input embedding also serves as the output embedding.
        return self.utrlm.embed_tokens

    def set_output_embeddings(self, new_embeddings):
        self.utrlm.embed_tokens = new_embeddings

    def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor:
        """dense -> GELU -> LayerNorm -> tied projection onto the vocabulary + bias."""
        x = self.lm_head["dense"](x)
        x = _gelu(x)
        x = self.lm_head["layer_norm"](x)
        return F.linear(x, self.utrlm.embed_tokens.weight) + self.lm_head_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the encoder and the MLM head.

        Args:
            labels: optional target token ids shaped like ``input_ids``.
                Positions whose label is ``config.padding_idx`` or ``-100`` (the
                Hugging Face convention, e.g. from DataCollatorForLanguageModeling)
                are excluded from the loss.

        Returns:
            MaskedLMOutput (or a tuple when ``return_dict`` is False) with the
            mean cross-entropy ``loss`` when ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.utrlm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )
        logits = self._lm_head_forward(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            hf_ignore_index = -100
            # reshape (not view) so non-contiguous labels work; move labels to
            # the logits' device in case the model is split across devices.
            flat_labels = labels.reshape(-1).to(logits.device)
            # Treat padding_idx (this model's original ignore value) the same as
            # -100 so both labelling conventions are accepted.
            flat_labels = flat_labels.masked_fill(
                flat_labels == self.config.padding_idx, hf_ignore_index
            )
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.alphabet_size),
                flat_labels,
                ignore_index=hf_ignore_index,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )