Taykhoom committed
Commit 03d9aff · verified · 1 Parent(s): 3e4e994

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,122 @@
+ ---
+ language:
+ - rna
+ library_name: transformers
+ tags:
+ - RNA
+ - language-model
+ - UTR
+ - genomics
+ - biology
+ license: gpl-3.0
+ ---
+
+ # UTR-LM-MLMSS
+
+ UTR-LM is a 5' UTR RNA language model based on ESM2, pretrained on endogenous 5' UTRs from five species and a large synthetic library. This checkpoint (`UTR-LM-MLMSS`) was trained with **MLM + secondary structure prediction** as a supervised auxiliary objective.
+
+ ## Architecture
+
+ | Parameter | Value |
+ |---|---|
+ | Layers | 6 |
+ | Attention heads | 16 |
+ | Embedding dimension | 128 |
+ | Vocabulary size | 10 |
+ | Positional encoding | Rotary (RoPE) |
+ | Architecture | ESM2-style pre-LN Transformer |
+
+ **Vocabulary:** `<pad>` (0), `<eos>` (1), `<unk>` (2), `A` (3), `G` (4), `C` (5), `T` (6), `<cls>` (7), `<mask>` (8), `<sep>` (9)
+
+ ## Pretraining
+
+ - **Objective:** Masked language modeling + per-token secondary structure prediction (3-class: unpaired, stem, loop)
+ - **Data:** Endogenous 5' UTRs from five species (human, mouse, zebrafish, *Drosophila*, yeast) combined with the Cao et al. random 5' UTR synthetic library
+ - **Source checkpoint:** `ESM2SS_FS4.1_fiveSpeciesCao_6layers_16heads_128embedsize_4096batchToks_lr1e-05_structureweight1.0_MLMLossMin_epoch200.pkl`
+
+ Only one `ESM2SS` (secondary structure only, no MFE regression) checkpoint was available; no selection decision was required.
+
+ ## Parity Verification
+
+ Hidden-state representations produced by this HF model are verified to be **exactly identical** (max absolute difference = 0.00) to the original ESM2-based implementation at all 7 representation levels (initial embedding + 6 transformer layers). Verified on GPU with PyTorch 2.8 / CUDA 12.6.
+
+ ## Related Models
+
+ | Model | Pretraining Objective | Notes |
+ |---|---|---|
+ | [UTR-LM-MLM](https://huggingface.co/Taykhoom/UTR-LM-MLM) | MLM | Base model |
+ | [UTR-LM-MLMSI](https://huggingface.co/Taykhoom/UTR-LM-MLMSI) | MLM + MFE regression | Recommended for TE / EL tasks |
+ | **[UTR-LM-MLMSS](https://huggingface.co/Taykhoom/UTR-LM-MLMSS)** | MLM + secondary structure | This model |
+ | [UTR-LM-MLMSISS](https://huggingface.co/Taykhoom/UTR-LM-MLMSISS) | MLM + MFE + secondary structure | Recommended for MRL tasks |
+
+ ## Usage
+
+ ### Embedding generation
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/UTR-LM-MLMSS", trust_remote_code=True)
+ model = AutoModel.from_pretrained("Taykhoom/UTR-LM-MLMSS", trust_remote_code=True)
+ model.eval()
+
+ sequences = ["ATGCATGCATGC", "GCTAGCTAGCTAGCTA"]
+ enc = tokenizer(sequences, return_tensors="pt", padding=True)
+
+ with torch.no_grad():
+     out = model(**enc)
+
+ # CLS token embedding (position 0) - recommended for sequence-level tasks
+ cls_emb = out.last_hidden_state[:, 0, :] # (batch, 128)
+
+ # All-token embeddings
+ token_emb = out.last_hidden_state # (batch, seq_len, 128)
+
+ # Intermediate layer representations
+ out_all = model(**enc, output_hidden_states=True)
+ layer3_emb = out_all.hidden_states[3] # after layer 3, shape (batch, seq_len, 128)
+ ```
+
+ ### MLM logits
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/UTR-LM-MLMSS", trust_remote_code=True)
+ model = AutoModelForMaskedLM.from_pretrained("Taykhoom/UTR-LM-MLMSS", trust_remote_code=True)
+ model.eval()
+
+ enc = tokenizer(["ATGC<mask>ATGC"], return_tensors="pt")
+ with torch.no_grad():
+     logits = model(**enc).logits # (1, seq_len, 10)
+ ```
+
+ ### Fine-tuning
+
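+ The model follows standard HF conventions and can be fine-tuned with a Trainer-compatible setup. For sequence regression tasks, use the CLS token embedding as input to a prediction head (as done in the original UTR-LM paper). Note that `UtrLmForMaskedLM` computes its loss with `ignore_index=config.padding_idx` (0) rather than `-100`, so MLM labels should be set to 0 at positions that must not contribute to the loss.
+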
+ ## Citation
+
+ ```bibtex
+ @article{chu2023utrlm,
+   title = {A 5'UTR Language Model for Decoding Untranslated Regions of mRNA and Function Predictions},
+   author = {Chu, Yanyi and others},
+   journal = {bioRxiv},
+   year = {2023},
+   doi = {10.1101/2023.10.11.561938}
+ }
+ ```
+
+ ## Implementation Notes
+
+ The original UTR-LM implementation uses standard scaled dot-product attention. This HF port adds support for `attn_implementation="sdpa"` (PyTorch `F.scaled_dot_product_attention`) and `attn_implementation="flash_attention_2"` (requires `pip install flash-attn --no-build-isolation`), which were not part of the original codebase.
+
+ ## Credits
+
+ Original model and code by Yanyi Chu et al. (Stanford). Source code: [UTR-LM GitHub repository](https://github.com/a96123155/UTR-LM). The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code) and reviewed manually by Taykhoom Dalal.
+
+ ## License
+
+ GPL-3.0, following the original UTR-LM repository.
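Note on the parity claim in README.md: the verification script itself is not part of this commit. The following is only a minimal sketch of how a per-level max-absolute-difference check could look. The helper name `max_abs_diff_per_level` is hypothetical, and the tensors are random stand-ins rather than real model outputs.

```python
import torch


def max_abs_diff_per_level(hf_states, ref_states):
    """Return the max absolute difference for each pair of hidden-state tensors."""
    if len(hf_states) != len(ref_states):
        raise ValueError("both models must expose the same number of levels")
    return [(a - b).abs().max().item() for a, b in zip(hf_states, ref_states)]


# Stand-in data (assumption): 7 levels = embedding output + 6 layers,
# each of shape (batch, seq_len, 128) as described in the README.
torch.manual_seed(0)
ref_states = tuple(torch.randn(2, 14, 128) for _ in range(7))
hf_states = tuple(t.clone() for t in ref_states)

diffs = max_abs_diff_per_level(hf_states, ref_states)
print(diffs)
```

In a real check, `hf_states` would come from `output_hidden_states=True` on the HF model and `ref_states` from the original implementation on the same tokenized batch.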
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "alphabet_size": 10,
+   "append_eos": true,
+   "attention_heads": 16,
+   "auto_map": {
+     "AutoConfig": "configuration_utrlm.UtrLmConfig",
+     "AutoModel": "modeling_utrlm.UtrLmModel",
+     "AutoModelForMaskedLM": "modeling_utrlm.UtrLmForMaskedLM",
+     "AutoTokenizer": "tokenization_utrlm.UtrLmTokenizer"
+   },
+   "cls_idx": 7,
+   "embed_dim": 128,
+   "eos_idx": 1,
+   "mask_idx": 8,
+   "model_type": "utrlm",
+   "num_layers": 6,
+   "pad_token_id": 0,
+   "padding_idx": 0,
+   "prepend_bos": true,
+   "token_dropout": true,
+   "transformers_version": "4.57.6"
+ }
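Note on config.json: the values above determine the model size. As a rough sanity check, the sketch below counts trainable parameters for the layer layout defined in modeling_utrlm.py (four attention projections, a 4x feed-forward block, two LayerNorms per layer, a final LayerNorm, and an MLM head whose output weights are tied to the embedding). The function name `count_parameters` is an illustrative assumption, not part of the commit.

```python
def count_parameters(num_layers, embed_dim, alphabet_size):
    """Count trainable parameters for the layout in modeling_utrlm.py."""

    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weight + bias

    layer_norm = 2 * embed_dim  # weight + bias

    attention = 4 * linear(embed_dim, embed_dim)  # q, k, v, out projections
    feed_forward = linear(embed_dim, 4 * embed_dim) + linear(4 * embed_dim, embed_dim)
    per_layer = attention + feed_forward + 2 * layer_norm

    embedding = alphabet_size * embed_dim
    backbone = embedding + num_layers * per_layer + layer_norm  # final LayerNorm
    # MLM head: dense + LayerNorm + output bias (output weights reuse the embedding).
    mlm_head = linear(embed_dim, embed_dim) + layer_norm + alphabet_size
    return backbone, mlm_head


backbone, mlm_head = count_parameters(num_layers=6, embed_dim=128, alphabet_size=10)
head_dim = 128 // 16  # embed_dim / attention_heads

print(backbone, mlm_head, head_dim)
```

The number of attention heads does not change the parameter count; it only sets the per-head dimension (8 here). The rotary `inv_freq` tensors are buffers and are not counted. Backbone plus head gives 1,207,946 parameters, about 4.83 MB if stored as 32-bit floats (an assumption), which is in line with the 4,866,715-byte `pytorch_model.bin` listed further down.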
configuration_utrlm.py ADDED
@@ -0,0 +1,53 @@
+ from transformers import PretrainedConfig
+
+
+ class UtrLmConfig(PretrainedConfig):
+     """
+     Configuration for UTR-LM (ESM2-based RNA language model).
+
+     Vocab (10 tokens):
+     <pad>:0 <eos>:1 <unk>:2 A:3 G:4 C:5 T:6 <cls>:7 <mask>:8 <sep>:9
+     """
+
+     model_type = "utrlm"
+
+     def __init__(
+         self,
+         num_layers: int = 6,
+         embed_dim: int = 128,
+         attention_heads: int = 16,
+         alphabet_size: int = 10,
+         padding_idx: int = 0,
+         mask_idx: int = 8,
+         cls_idx: int = 7,
+         eos_idx: int = 1,
+         prepend_bos: bool = True,
+         append_eos: bool = True,
+         token_dropout: bool = True,
+         **kwargs,
+     ):
+         kwargs.setdefault("pad_token_id", padding_idx)
+         super().__init__(**kwargs)
+         # Written into config.json so AutoModel / AutoModelForMaskedLM resolve
+         # the correct classes when loading from the Hub with trust_remote_code=True.
+         self.auto_map = {
+             "AutoConfig": "configuration_utrlm.UtrLmConfig",
+             "AutoTokenizer": "tokenization_utrlm.UtrLmTokenizer",
+             "AutoModel": "modeling_utrlm.UtrLmModel",
+             "AutoModelForMaskedLM": "modeling_utrlm.UtrLmForMaskedLM",
+         }
+         self.num_layers = num_layers
+         self.embed_dim = embed_dim
+         self.attention_heads = attention_heads
+         self.alphabet_size = alphabet_size
+         self.padding_idx = padding_idx
+         self.mask_idx = mask_idx
+         self.cls_idx = cls_idx
+         self.eos_idx = eos_idx
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+         self.token_dropout = token_dropout
+
+     @property
+     def hidden_size(self) -> int:
+         return self.embed_dim
modeling_utrlm.py ADDED
@@ -0,0 +1,413 @@
+ """UTR-LM ported to Hugging Face PreTrainedModel."""
+
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
+
+ from .configuration_utrlm import UtrLmConfig
+
+
+ # ---------------------------------------------------------------------------
+ # Rotary embeddings
+ # ---------------------------------------------------------------------------
+
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def _apply_rotary_pos_emb(x, cos, sin):
+     cos = cos[:, : x.shape[-2], :].to(x.dtype)
+     sin = sin[:, : x.shape[-2], :].to(x.dtype)
+     return (x * cos) + (_rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self._seq_len_cached: Optional[int] = None
+         self._cos_cached: Optional[torch.Tensor] = None
+         self._sin_cached: Optional[torch.Tensor] = None
+
+     def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 1):
+         seq_len = x.shape[seq_dimension]
+         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+             self._seq_len_cached = seq_len
+             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+             self._cos_cached = emb.cos()[None, :, :]
+             self._sin_cached = emb.sin()[None, :, :]
+         return self._cos_cached, self._sin_cached
+
+     def forward(self, q, k):
+         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+         return (
+             _apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+             _apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+         )
+
+
+ # ---------------------------------------------------------------------------
+ # Attention variants
+ # ---------------------------------------------------------------------------
+
+ class UtrLmAttention(nn.Module):
+     """Eager (standard) attention."""
+
+     def __init__(self, embed_dim: int, num_heads: int):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scaling = self.head_dim ** -0.5
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+         self.rot_emb = RotaryEmbedding(dim=self.head_dim)
+
+     def _project(self, x):
+         """Project and reshape x (T, B, E) -> q/k/v in (B*H, T, head_dim)."""
+         tgt_len, bsz, _ = x.size()
+         q = (self.q_proj(x) * self.scaling).contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         k = self.k_proj(x).contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         v = self.v_proj(x).contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         q, k = self.rot_emb(q, k)
+         return q, k, v
+
+     def forward(self, x, key_padding_mask, output_attentions: bool = False):
+         tgt_len, bsz, _ = x.size()
+         q, k, v = self._project(x)
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, tgt_len)
+         if key_padding_mask is not None:
+             attn_weights = attn_weights.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
+             )
+         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, tgt_len)
+
+         attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
+         attn = torch.bmm(attn_probs, v)
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+         out = self.out_proj(attn)
+
+         if output_attentions:
+             return out, attn_probs.view(bsz, self.num_heads, tgt_len, tgt_len)
+         return out, None
+
+
+ class UtrLmSdpaAttention(UtrLmAttention):
+     """SDPA attention via torch.nn.functional.scaled_dot_product_attention."""
+
+     def forward(self, x, key_padding_mask, output_attentions: bool = False):
+         if output_attentions:
+             # SDPA doesn't expose attention weights; fall back to eager.
+             return super().forward(x, key_padding_mask, output_attentions=True)
+
+         tgt_len, bsz, _ = x.size()
+         q, k, v = self._project(x) # (B*H, T, head_dim)
+
+         # Reshape to (B, H, T, head_dim) for SDPA
+         q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         k = k.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         v = v.view(bsz, self.num_heads, tgt_len, self.head_dim)
+
+         # Convert bool padding mask -> additive float mask (B, 1, 1, T)
+         attn_mask = None
+         if key_padding_mask is not None:
+             attn_mask = torch.zeros(bsz, 1, 1, tgt_len, dtype=q.dtype, device=q.device)
+             attn_mask = attn_mask.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
+
+         # scale=1.0 because q is already pre-scaled by self.scaling
+         out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=1.0)
+         out = out.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
+         return self.out_proj(out), None
+
+
+ class UtrLmFlashAttention2(UtrLmAttention):
+     """Flash Attention 2 via flash_attn (must be installed separately)."""
+
+     def forward(self, x, key_padding_mask, output_attentions: bool = False):
+         if output_attentions:
+             # Flash attention doesn't expose attention weights; fall back to eager.
+             return super().forward(x, key_padding_mask, output_attentions=True)
+
+         try:
+             from flash_attn import flash_attn_func
+             from flash_attn.bert_padding import pad_input, unpad_input
+         except ImportError as e:
+             raise ImportError("flash_attn is required for attn_implementation='flash_attention_2'. "
+                               "Install with: pip install flash-attn --no-build-isolation") from e
+
+         tgt_len, bsz, _ = x.size()
+         q, k, v = self._project(x) # (B*H, T, head_dim)
+
+         # Reshape to (B, T, H, head_dim) - flash_attn's expected layout
+         q = q.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
+         k = k.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
+         v = v.view(bsz, self.num_heads, tgt_len, self.head_dim).permute(0, 2, 1, 3)
+
+         # Flash attention requires fp16 or bf16
+         orig_dtype = q.dtype
+         if orig_dtype not in (torch.float16, torch.bfloat16):
+             q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
+
+         if key_padding_mask is not None:
+             # Unpad, run varlen flash attention, repad
+             from flash_attn import flash_attn_varlen_func
+             attention_mask = ~key_padding_mask # True = valid token
+             q_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(q, attention_mask)
+             k_unpad, _, _, _, _ = unpad_input(k, attention_mask)
+             v_unpad, _, _, _, _ = unpad_input(v, attention_mask)
+
+             out_unpad = flash_attn_varlen_func(
+                 q_unpad, k_unpad, v_unpad,
+                 cu_seqlens_q=cu_seqlens,
+                 cu_seqlens_k=cu_seqlens,
+                 max_seqlen_q=max_seqlen,
+                 max_seqlen_k=max_seqlen,
+                 softmax_scale=1.0, # q already pre-scaled
+                 causal=False,
+             )
+             out = pad_input(out_unpad, indices, bsz, tgt_len)
+         else:
+             out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=False)
+
+         out = out.to(orig_dtype).permute(1, 0, 2, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
+         return self.out_proj(out), None
+
+
+ UTRLM_ATTENTION_CLASSES = {
+     "eager": UtrLmAttention,
+     "sdpa": UtrLmSdpaAttention,
+     "flash_attention_2": UtrLmFlashAttention2,
+ }
+
+
+ # ---------------------------------------------------------------------------
+ # Transformer layer (pre-LN)
+ # ---------------------------------------------------------------------------
+
+ def _gelu(x):
+     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+ class UtrLmLayer(nn.Module):
+     def __init__(self, embed_dim: int, attention_heads: int, config: UtrLmConfig):
+         super().__init__()
+         attn_cls = UTRLM_ATTENTION_CLASSES[getattr(config, "_attn_implementation", "eager")]
+         self.self_attn = attn_cls(embed_dim, attention_heads)
+         self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+         self.fc1 = nn.Linear(embed_dim, 4 * embed_dim)
+         self.fc2 = nn.Linear(4 * embed_dim, embed_dim)
+         self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x, padding_mask, output_attentions: bool = False):
+         residual = x
+         x = self.self_attn_layer_norm(x)
+         x, attn_weights = self.self_attn(x, key_padding_mask=padding_mask, output_attentions=output_attentions)
+         x = residual + x
+
+         residual = x
+         x = self.final_layer_norm(x)
+         x = _gelu(self.fc1(x))
+         x = self.fc2(x)
+         return residual + x, attn_weights
+
+
+ # ---------------------------------------------------------------------------
+ # Backbone
+ # ---------------------------------------------------------------------------
+
+ class UtrLmModel(PreTrainedModel):
+     """
+     UTR-LM encoder backbone. Returns last_hidden_state (B, T, E).
+     The [CLS] token sits at position 0 (prepend_bos=True by default).
+     """
+
+     config_class = UtrLmConfig
+     base_model_prefix = "utrlm"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config: UtrLmConfig):
+         super().__init__(config)
+         self.embed_scale = 1
+         self.embed_tokens = nn.Embedding(
+             config.alphabet_size, config.embed_dim, padding_idx=config.padding_idx
+         )
+         self.layers = nn.ModuleList(
+             [UtrLmLayer(config.embed_dim, config.attention_heads, config) for _ in range(config.num_layers)]
+         )
+         self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         output_attentions = (
+             output_attentions if output_attentions is not None else self.config.output_attentions
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         cfg = self.config
+         # HF convention: attention_mask is 1=attend, 0=pad.
+         # Convert to bool padding_mask (True = ignore) or derive from input_ids.
+         if attention_mask is not None:
+             padding_mask = attention_mask.eq(0)
+         else:
+             padding_mask = input_ids.eq(cfg.padding_idx)
+
+         x = self.embed_scale * self.embed_tokens(input_ids)
+
+         if cfg.token_dropout:
+             x.masked_fill_((input_ids == cfg.mask_idx).unsqueeze(-1), 0.0)
+             mask_ratio_train = 0.15 * 0.8
+             src_lengths = (~padding_mask).sum(-1)
+             mask_ratio_observed = (input_ids == cfg.mask_idx).sum(-1).to(x.dtype) / src_lengths.to(x.dtype)
+             x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
+
+         if padding_mask is not None:
+             x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
+
+         all_hidden_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+         if output_hidden_states:
+             all_hidden_states += (x,)
+
+         x = x.transpose(0, 1) # (B, T, E) -> (T, B, E)
+         effective_padding = padding_mask if padding_mask.any() else None
+
+         for layer in self.layers:
+             x, attn_weights = layer(x, padding_mask=effective_padding, output_attentions=output_attentions)
+             if output_hidden_states:
+                 all_hidden_states += (x.transpose(0, 1),)
+             if output_attentions:
+                 all_attentions += (attn_weights,)
+
+         x = self.emb_layer_norm_after(x)
+         x = x.transpose(0, 1) # (T, B, E) -> (B, T, E)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states[:-1] + (x,)
+
+         if not return_dict:
+             return tuple(v for v in [x, all_hidden_states, all_attentions] if v is not None)
+
+         return BaseModelOutput(
+             last_hidden_state=x,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+         )
+
+
+ # ---------------------------------------------------------------------------
+ # MLM head
+ # ---------------------------------------------------------------------------
+
+ class UtrLmForMaskedLM(PreTrainedModel):
+     """
+     UTR-LM with a masked-language-modelling head.
+     Returns MaskedLMOutput with logits (B, T, vocab_size).
+     """
+
+     config_class = UtrLmConfig
+     base_model_prefix = "utrlm"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config: UtrLmConfig):
+         super().__init__(config)
+         self.utrlm = UtrLmModel(config)
+
+         embed_dim = config.embed_dim
+         vocab_size = config.alphabet_size
+         self.lm_head = nn.ModuleDict({
+             "dense": nn.Linear(embed_dim, embed_dim),
+             "layer_norm": nn.LayerNorm(embed_dim),
+         })
+         self.lm_head_bias = nn.Parameter(torch.zeros(vocab_size))
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.utrlm.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.utrlm.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.utrlm.embed_tokens
+
+     def set_output_embeddings(self, new_embeddings):
+         self.utrlm.embed_tokens = new_embeddings
+
+     def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.lm_head["dense"](x)
+         x = _gelu(x)
+         x = self.lm_head["layer_norm"](x)
+         return F.linear(x, self.utrlm.embed_tokens.weight) + self.lm_head_bias
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.utrlm(
+             input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+             return_dict=True,
+         )
+         logits = self._lm_head_forward(outputs.last_hidden_state)
+
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, self.config.alphabet_size),
+                 labels.view(-1),
+                 ignore_index=self.config.padding_idx,
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
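Note on the `token_dropout` branch in `UtrLmModel.forward`: embeddings of `<mask>` tokens are zeroed, and the remaining embeddings are rescaled by `(1 - 0.15 * 0.8) / (1 - observed_mask_ratio)`, where the observed ratio is the number of `<mask>` tokens divided by the number of non-padding tokens. The sketch below reproduces only that arithmetic in plain Python; the function name `token_dropout_scale` is a hypothetical helper, not part of the commit.

```python
MASK_RATIO_TRAIN = 0.15 * 0.8  # same constant as in UtrLmModel.forward


def token_dropout_scale(num_tokens, num_masked):
    """Scale factor applied to embeddings when token_dropout is enabled.

    num_tokens counts all non-padding positions (including <cls> and <eos>);
    num_masked counts the <mask> tokens among them.
    """
    if num_masked >= num_tokens:
        raise ValueError("at least one non-masked token is required")
    observed = num_masked / num_tokens
    return (1 - MASK_RATIO_TRAIN) / (1 - observed)


no_mask = token_dropout_scale(num_tokens=10, num_masked=0)
two_masks = token_dropout_scale(num_tokens=10, num_masked=2)
print(round(no_mask, 4), round(two_masks, 4))
```

One consequence worth keeping in mind: for inputs without any `<mask>` token, the factor is about 0.88 rather than 1, so unmasked inference still passes through this rescaling whenever `token_dropout` is true in the config.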
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33e7cfeb0d8b44636ee45e87de2c8af59f114abc963378eabe40c41073654e63
+ size 4866715
special_tokens_map.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "cls_token": "<cls>",
+   "eos_token": "<eos>",
+   "mask_token": "<mask>",
+   "pad_token": "<pad>",
+   "sep_token": "<sep>",
+   "unk_token": "<unk>"
+ }
tokenization_utrlm.py ADDED
@@ -0,0 +1,128 @@
+ """Character-level RNA tokenizer for UTR-LM."""
+
+ import json
+ import os
+ from typing import Dict, List, Optional, Tuple
+
+ from transformers import PreTrainedTokenizer
+
+ # Canonical vocab - fixed; never changes across checkpoints.
+ _VOCAB: Dict[str, int] = {
+     "<pad>": 0,
+     "<eos>": 1,
+     "<unk>": 2,
+     "A": 3,
+     "G": 4,
+     "C": 5,
+     "T": 6,
+     "<cls>": 7,
+     "<mask>": 8,
+     "<sep>": 9,
+ }
+ _IDS_TO_TOKENS: Dict[int, str] = {v: k for k, v in _VOCAB.items()}
+
+
+ class UtrLmTokenizer(PreTrainedTokenizer):
+     """
+     Character-level tokenizer for UTR-LM RNA sequences.
+
+     Each nucleotide (A / G / C / T) maps to a single token.
+     Sequences are automatically wrapped with <cls> ... <eos> on encoding.
+
+     Example::
+
+         tok = UtrLmTokenizer()
+         enc = tok("ATGCATG", return_tensors="pt")
+         # enc.input_ids: [[7, 3, 6, 4, 5, 3, 6, 4, 1]]
+         #                 CLS A  T  G  C  A  T  G  EOS
+     """
+
+     vocab_files_names = {"vocab_file": "vocab.json"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file: Optional[str] = None,
+         cls_token: str = "<cls>",
+         pad_token: str = "<pad>",
+         mask_token: str = "<mask>",
+         eos_token: str = "<eos>",
+         unk_token: str = "<unk>",
+         sep_token: str = "<sep>",
+         **kwargs,
+     ):
+         # Build vocab from file if provided (allows future extension), else use default
+         if vocab_file is not None and os.path.isfile(vocab_file):
+             with open(vocab_file) as f:
+                 self._vocab = json.load(f)
+         else:
+             self._vocab = dict(_VOCAB)
+         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
+
+         super().__init__(
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             **kwargs,
+         )
+
+     # ------------------------------------------------------------------
+     # Required overrides
+     # ------------------------------------------------------------------
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self._vocab)
+
+     def get_vocab(self) -> Dict[str, int]:
+         return dict(self._vocab)
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Split sequence into individual characters."""
+         return list(text)
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self._vocab.get(token, self._vocab["<unk>"])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._ids_to_tokens.get(index, "<unk>")
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         os.makedirs(save_directory, exist_ok=True)
+         fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
+         path = os.path.join(save_directory, fname)
+         with open(path, "w") as f:
+             json.dump(self._vocab, f, indent=2)
+         return (path,)
+
+     # ------------------------------------------------------------------
+     # Special-token wrapping: prepend [CLS], append [EOS]
+     # ------------------------------------------------------------------
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         cls = [self.cls_token_id]
+         eos = [self.eos_token_id]
+         if token_ids_1 is None:
+             return cls + token_ids_0 + eos
+         return cls + token_ids_0 + eos + cls + token_ids_1 + eos
+
+     def get_special_tokens_mask(self, token_ids_0, token_ids_1=None,
+                                 already_has_special_tokens=False):
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0, token_ids_1, already_has_special_tokens=True
+             )
+         mask = [1] + [0] * len(token_ids_0) + [1]
+         if token_ids_1 is not None:
+             mask += [1] + [0] * len(token_ids_1) + [1]
+         return mask
+
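+     def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
+         # UTR-LM has no segment embeddings, so every position gets type id 0.
+         # The lengths match build_inputs_with_special_tokens.
+         if token_ids_1 is None:
+             return [0] * (len(token_ids_0) + 2)
+         return [0] * (len(token_ids_0) + len(token_ids_1) + 4)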
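Note on tokenization_utrlm.py: the mapping the tokenizer applies to a plain nucleotide string can be reproduced without `transformers`. The sketch below is a simplified stand-in that copies the vocabulary from the diff and wraps each sequence in `<cls>` ... `<eos>`. The `encode` function is hypothetical, and it does not parse literal special tokens such as `<mask>` inside the input, which the real `PreTrainedTokenizer` machinery handles.

```python
VOCAB = {
    "<pad>": 0, "<eos>": 1, "<unk>": 2,
    "A": 3, "G": 4, "C": 5, "T": 6,
    "<cls>": 7, "<mask>": 8, "<sep>": 9,
}


def encode(sequence):
    """Map each character to its id and wrap the result in <cls> ... <eos>."""
    ids = [VOCAB.get(ch, VOCAB["<unk>"]) for ch in sequence]
    return [VOCAB["<cls>"]] + ids + [VOCAB["<eos>"]]


dna_style = encode("ATGCATG")
rna_style = encode("AUG")  # "U" is not in the vocabulary

print(dna_style)
print(rna_style)
```

Because the vocabulary contains `T` but not `U`, and only upper-case letters, sequences written with `U` or in lower case map to `<unk>` (id 2). Converting such input to upper-case `T` notation before tokenizing avoids that.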
tokenizer_config.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<eos>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "7": {
+       "content": "<cls>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "8": {
+       "content": "<mask>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "9": {
+       "content": "<sep>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "<cls>",
+   "eos_token": "<eos>",
+   "extra_special_tokens": {},
+   "mask_token": "<mask>",
+   "model_max_length": 1024,
+   "pad_token": "<pad>",
+   "sep_token": "<sep>",
+   "tokenizer_class": "UtrLmTokenizer",
+   "unk_token": "<unk>"
+ }
vocab.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "<pad>": 0,
+   "<eos>": 1,
+   "<unk>": 2,
+   "A": 3,
+   "G": 4,
+   "C": 5,
+   "T": 6,
+   "<cls>": 7,
+   "<mask>": 8,
+   "<sep>": 9
+ }