Instructions to use Taykhoom/UTR-LM-MLMSS with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/UTR-LM-MLMSS with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Taykhoom/UTR-LM-MLMSS", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """UTR-LM ported to Hugging Face PreTrainedModel.""" | |
| import math | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput | |
| from .configuration_utrlm import UtrLmConfig | |
| # --------------------------------------------------------------------------- | |
| # Rotary embeddings | |
| # --------------------------------------------------------------------------- | |
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second one.

    The last axis is split into two chunks ``(a, b)`` and the result is
    ``(-b, a)`` concatenated back along that same axis.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
def _apply_rotary_pos_emb(x, cos, sin):
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` are cropped along their second axis to
    ``x.shape[-2]`` entries and cast to ``x``'s dtype before use.
    """
    seq_len = x.shape[-2]
    cos_table = cos[:, :seq_len, :].to(x.dtype)
    sin_table = sin[:, :seq_len, :].to(x.dtype)
    rotated = _rotate_half(x)
    return (x * cos_table) + (rotated * sin_table)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding for query/key tensors.

    Holds the inverse-frequency buffer ``inv_freq`` and lazily builds
    cos/sin lookup tables, which are cached and reused as long as the
    sequence length and device of the incoming tensor do not change.
    """

    def __init__(self, dim: int):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Cache state; populated on the first call to _update_cos_sin_tables.
        self._seq_len_cached: Optional[int] = None
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 1):
        """Return ``(cos, sin)`` tables sized for ``x.shape[seq_dimension]``.

        The tables have shape ``(1, seq_len, 2 * len(inv_freq))`` and are
        rebuilt only when the sequence length or the device differs from
        the cached ones.
        """
        seq_len = x.shape[seq_dimension]
        # The length test comes first so the device lookup is never reached
        # while the cache is still empty (``_cos_cached`` is None).
        cache_is_valid = (
            seq_len == self._seq_len_cached
            and self._cos_cached.device == x.device
        )
        if not cache_is_valid:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos().unsqueeze(0)
            self._sin_cached = table.sin().unsqueeze(0)
        return self._cos_cached, self._sin_cached

    def forward(self, q, k):
        """Rotate ``q`` and ``k``; the table length is taken from ``k.shape[-2]``."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos, sin
        rotated_q = _apply_rotary_pos_emb(q, cos, sin)
        rotated_k = _apply_rotary_pos_emb(k, cos, sin)
        return rotated_q, rotated_k
| # --------------------------------------------------------------------------- | |
| # Attention variants | |
| # --------------------------------------------------------------------------- | |
class UtrLmAttention(nn.Module):
    """Eager (standard) multi-head self-attention with rotary embeddings.

    Inputs and outputs use the time-major layout ``(T, B, E)``. Queries are
    pre-multiplied by ``head_dim ** -0.5`` inside ``_project``, so the raw
    ``q @ k^T`` product is already scaled.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        # Construction order (k, v, q, out) is kept so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.rot_emb = RotaryEmbedding(dim=self.head_dim)

    def _split_heads(self, t, tgt_len, bsz):
        """Reshape a projected ``(T, B, E)`` tensor to ``(B*H, T, head_dim)``."""
        flat = t.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim)
        return flat.transpose(0, 1)

    def _project(self, x):
        """Project and reshape x (T, B, E) -> q/k/v in (B*H, T, head_dim)."""
        tgt_len, bsz, _ = x.size()
        q = self._split_heads(self.q_proj(x) * self.scaling, tgt_len, bsz)
        k = self._split_heads(self.k_proj(x), tgt_len, bsz)
        v = self._split_heads(self.v_proj(x), tgt_len, bsz)
        q, k = self.rot_emb(q, k)
        return q, k, v

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        """Run self-attention over ``x``.

        Args:
            x: input of shape ``(T, B, E)``.
            key_padding_mask: ``(B, T)`` mask whose truthy entries mark keys
                to ignore, or ``None`` for no masking.
            output_attentions: when true, also return the attention
                probabilities shaped ``(B, H, T, T)``.

        Returns:
            ``(out, attn)`` where ``out`` is ``(T, B, E)`` and ``attn`` is
            the probabilities or ``None``.
        """
        tgt_len, bsz, _ = x.size()
        q, k, v = self._project(x)

        scores = torch.bmm(q, k.transpose(1, 2))
        if key_padding_mask is not None:
            # Broadcast the (B, T) mask over heads and query positions.
            ignore = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
            scores = scores.view(bsz, self.num_heads, tgt_len, tgt_len)
            scores = scores.masked_fill(ignore, float("-inf"))
        scores = scores.view(bsz * self.num_heads, tgt_len, tgt_len)

        # Softmax is computed in float32 and cast back to the scores' dtype.
        attn_probs = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)

        context = torch.bmm(attn_probs, v)
        context = context.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        out = self.out_proj(context)

        if not output_attentions:
            return out, None
        return out, attn_probs.view(bsz, self.num_heads, tgt_len, tgt_len)
class UtrLmSdpaAttention(UtrLmAttention):
    """SDPA attention via torch.nn.functional.scaled_dot_product_attention."""

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        """Run self-attention through the fused SDPA kernel.

        Args:
            x: input of shape ``(T, B, E)``.
            key_padding_mask: ``(B, T)`` mask whose truthy entries mark keys
                to ignore, or ``None`` for no masking. It is coerced to bool,
                matching the eager implementation.
            output_attentions: when true, the eager implementation is used
                instead because SDPA does not return attention weights.

        Returns:
            ``(out, None)`` with ``out`` shaped ``(T, B, E)``, or the eager
            result when ``output_attentions`` is true.
        """
        if output_attentions:
            # SDPA doesn't expose attention weights; fall back to eager.
            return super().forward(x, key_padding_mask, output_attentions=True)

        tgt_len, bsz, _ = x.size()
        q, k, v = self._project(x)  # (B*H, T, head_dim)

        # Reshape to (B, H, T, head_dim) for SDPA
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, tgt_len, self.head_dim)
        v = v.view(bsz, self.num_heads, tgt_len, self.head_dim)

        # Convert padding mask -> additive float mask (B, 1, 1, T)
        attn_mask = None
        if key_padding_mask is not None:
            # masked_fill only accepts a bool mask; coerce like the eager path
            # does so an integer/float padding mask behaves the same here.
            ignore = key_padding_mask.to(torch.bool)[:, None, None, :]
            attn_mask = torch.zeros(bsz, 1, 1, tgt_len, dtype=q.dtype, device=q.device)
            attn_mask = attn_mask.masked_fill(ignore, float("-inf"))

        # scale=1.0 because q is already pre-scaled by self.scaling
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=1.0)

        # (B, H, T, head_dim) -> (T, B, E)
        out = out.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
        return self.out_proj(out), None
class UtrLmFlashAttention2(UtrLmAttention):
    """Flash Attention 2 via flash_attn (must be installed separately)."""

    def forward(self, x, key_padding_mask, output_attentions: bool = False):
        """Run self-attention through the flash_attn kernels.

        Args:
            x: input of shape ``(T, B, E)``.
            key_padding_mask: ``(B, T)`` mask whose truthy entries mark keys
                to ignore, or ``None`` for no masking. When given, the batch
                is unpadded, run through the variable-length kernel and
                re-padded.
            output_attentions: when true, the eager implementation is used
                instead because flash attention does not return weights.

        Returns:
            ``(out, None)`` with ``out`` shaped ``(T, B, E)``, or the eager
            result when ``output_attentions`` is true.

        Raises:
            ImportError: if ``flash_attn`` is not installed.
        """
        if output_attentions:
            # Flash attention doesn't expose attention weights; fall back to eager.
            return super().forward(x, key_padding_mask, output_attentions=True)

        try:
            # All flash_attn symbols are imported under the same guard so a
            # missing package always yields the install hint below.
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import pad_input, unpad_input
        except ImportError as e:
            raise ImportError("flash_attn is required for attn_implementation='flash_attention_2'. "
                              "Install with: pip install flash-attn --no-build-isolation") from e

        tgt_len, bsz, _ = x.size()
        q, k, v = self._project(x)  # (B*H, T, head_dim)

        # Reshape to (B, T, H, head_dim) - flash_attn's expected layout
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)

        # Flash attention requires fp16 or bf16
        orig_dtype = q.dtype
        if orig_dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)

        if key_padding_mask is not None:
            # Unpad, run varlen flash attention, repad
            attention_mask = ~key_padding_mask.to(torch.bool)  # True = valid token
            # unpad_input returns 4 or 5 values depending on the flash_attn
            # release; only the first four are needed, so tolerate extras.
            q_unpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(q, attention_mask)
            k_unpad = unpad_input(k, attention_mask)[0]
            v_unpad = unpad_input(v, attention_mask)[0]
            out_unpad = flash_attn_varlen_func(
                q_unpad, k_unpad, v_unpad,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                softmax_scale=1.0,  # q already pre-scaled
                causal=False,
            )
            out = pad_input(out_unpad, indices, bsz, tgt_len)
        else:
            out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=False)

        # (B, T, H, head_dim) -> (T, B, E), restoring the caller's dtype.
        out = out.to(orig_dtype).permute(1, 0, 2, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
        return self.out_proj(out), None
# Registry mapping an attention-implementation name to the class that
# implements it. UtrLmLayer looks up the config's `_attn_implementation`
# value here (defaulting to "eager" when the attribute is absent).
UTRLM_ATTENTION_CLASSES = {
    "eager": UtrLmAttention,
    "sdpa": UtrLmSdpaAttention,
    "flash_attention_2": UtrLmFlashAttention2,
}
| # --------------------------------------------------------------------------- | |
| # Transformer layer (pre-LN) | |
| # --------------------------------------------------------------------------- | |
def _gelu(x):
    """Exact (erf-based) GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
class UtrLmLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a 4x-wide GELU MLP.

    The attention class is chosen from ``UTRLM_ATTENTION_CLASSES`` using the
    config's ``_attn_implementation`` attribute ("eager" if it is absent).
    """

    def __init__(self, embed_dim: int, attention_heads: int, config: UtrLmConfig):
        super().__init__()
        implementation = getattr(config, "_attn_implementation", "eager")
        attention_class = UTRLM_ATTENTION_CLASSES[implementation]
        ffn_dim = 4 * embed_dim

        self.self_attn = attention_class(embed_dim, attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, padding_mask, output_attentions: bool = False):
        """Apply the block to ``x`` and return ``(output, attn_weights)``.

        ``padding_mask`` is forwarded to the attention module as
        ``key_padding_mask``; ``attn_weights`` is whatever that module
        returns as its second value.
        """
        # Attention sub-layer: normalise, attend, add the residual.
        attn_out, attn_weights = self.self_attn(
            self.self_attn_layer_norm(x),
            key_padding_mask=padding_mask,
            output_attentions=output_attentions,
        )
        x = x + attn_out

        # Feed-forward sub-layer: normalise, expand, activate, project back.
        hidden = _gelu(self.fc1(self.final_layer_norm(x)))
        ffn_out = self.fc2(hidden)
        return x + ffn_out, attn_weights
| # --------------------------------------------------------------------------- | |
| # Backbone | |
| # --------------------------------------------------------------------------- | |
class UtrLmModel(PreTrainedModel):
    """
    UTR-LM encoder backbone. Returns last_hidden_state (B, T, E).

    The [CLS] token sits at position 0 (prepend_bos=True by default).
    """

    config_class = UtrLmConfig
    base_model_prefix = "utrlm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def __init__(self, config: UtrLmConfig):
        super().__init__(config)
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            config.alphabet_size,
            config.embed_dim,
            padding_idx=config.padding_idx,
        )
        blocks = [
            UtrLmLayer(config.embed_dim, config.attention_heads, config)
            for _ in range(config.num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode ``input_ids`` and return the final hidden states.

        Args:
            input_ids: token ids, batch-first.
            attention_mask: 1 = attend, 0 = pad. When omitted, padding is
                derived from positions equal to ``config.padding_idx``.
            output_hidden_states: also return the embedding output plus one
                entry per layer; the last entry is the layer-normed output.
            output_attentions: also return each layer's attention weights.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Flags left as ``None`` fall back to the values on ``self.config``.
        """
        cfg = self.config
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # HF convention: attention_mask is 1=attend, 0=pad.
        # Internally a bool padding mask is used instead (True = ignore).
        if attention_mask is None:
            padding_mask = input_ids.eq(cfg.padding_idx)
        else:
            padding_mask = attention_mask.eq(0)

        x = self.embed_scale * self.embed_tokens(input_ids)

        if cfg.token_dropout:
            # Zero the embeddings at mask-token positions, then rescale by
            # (1 - 0.15 * 0.8) / (1 - observed mask ratio) per sequence.
            is_mask_token = input_ids == cfg.mask_idx
            x.masked_fill_(is_mask_token.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = is_mask_token.sum(-1).to(x.dtype) / src_lengths.to(x.dtype)
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero out the embeddings at padded positions.
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        hidden_trace = [x] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        x = x.transpose(0, 1)  # (B, T, E) -> (T, B, E)

        # Skip masking entirely when the batch contains no padding.
        effective_padding = padding_mask if padding_mask.any() else None

        for layer in self.layers:
            x, attn_weights = layer(
                x,
                padding_mask=effective_padding,
                output_attentions=output_attentions,
            )
            if hidden_trace is not None:
                hidden_trace.append(x.transpose(0, 1))
            if attention_trace is not None:
                attention_trace.append(attn_weights)

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) -> (B, T, E)

        if hidden_trace is not None:
            # The final entry is replaced by the layer-normed output.
            hidden_trace[-1] = x

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_attentions = None if attention_trace is None else tuple(attention_trace)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=x,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
            )
        return tuple(v for v in [x, all_hidden_states, all_attentions] if v is not None)
| # --------------------------------------------------------------------------- | |
| # MLM head | |
| # --------------------------------------------------------------------------- | |
class UtrLmForMaskedLM(PreTrainedModel):
    """
    UTR-LM with a masked-language-modelling head.

    Returns MaskedLMOutput with logits (B, T, vocab_size). The output
    projection reuses the input embedding matrix plus a separate bias.
    """

    config_class = UtrLmConfig
    base_model_prefix = "utrlm"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def __init__(self, config: UtrLmConfig):
        super().__init__(config)
        self.utrlm = UtrLmModel(config)
        embed_dim = config.embed_dim
        vocab_size = config.alphabet_size
        self.lm_head = nn.ModuleDict({
            "dense": nn.Linear(embed_dim, embed_dim),
            "layer_norm": nn.LayerNorm(embed_dim),
        })
        self.lm_head_bias = nn.Parameter(torch.zeros(vocab_size))
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.utrlm.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.utrlm.embed_tokens = value

    def get_output_embeddings(self):
        """Return the module whose weight serves as the output projection."""
        return self.utrlm.embed_tokens

    def set_output_embeddings(self, new_embeddings):
        """Replace the module whose weight serves as the output projection."""
        self.utrlm.embed_tokens = new_embeddings

    def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map hidden states to vocabulary logits: dense -> GELU -> LayerNorm -> tied projection."""
        x = self.lm_head["dense"](x)
        x = _gelu(x)
        x = self.lm_head["layer_norm"](x)
        return F.linear(x, self.utrlm.embed_tokens.weight) + self.lm_head_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Compute MLM logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: token ids, batch-first.
            attention_mask: 1 = attend, 0 = pad; forwarded to the backbone.
            labels: target token ids with the same shape as ``input_ids``.
                Positions equal to ``-100`` (the Hugging Face convention) or
                to ``config.padding_idx`` do not contribute to the loss.
            output_hidden_states: forwarded to the backbone.
            output_attentions: forwarded to the backbone.
            return_dict: return a ``MaskedLMOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            ``MaskedLMOutput`` or a tuple ``(loss?, logits, *extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone is always asked for a ModelOutput so fields can be
        # accessed by name; tuple conversion happens below if requested.
        outputs = self.utrlm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )
        logits = self._lm_head_forward(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            targets = labels.to(logits.device).reshape(-1)
            padding_idx = self.config.padding_idx
            if padding_idx is not None:
                # Fold the padding id into the -100 sentinel so both the
                # original convention (padding_idx) and the Hugging Face one
                # (-100, e.g. from DataCollatorForLanguageModeling) are ignored.
                targets = targets.masked_fill(targets == padding_idx, -100)
            loss = F.cross_entropy(
                logits.view(-1, self.config.alphabet_size),
                targets,
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )