McClain committed on
Commit c8bd3a2 · verified · 1 Parent(s): 895003a

Upload PlasmidLM pretrained checkpoint (v4, step 15000)

README.md ADDED
---
language:
- en
license: apache-2.0
library_name: transformers
tags:
- biology
- genomics
- dna
- plasmid
- synthetic-biology
- causal-lm
- protein-engineering
datasets:
- custom
pipeline_tag: text-generation
model-index:
- name: PlasmidLM
  results:
  - task:
      type: text-generation
      name: Plasmid DNA Generation
    metrics:
    - name: Eval Loss
      type: loss
      value: 0.093
    - name: Token Accuracy
      type: accuracy
      value: 0.961
---

# PlasmidLM

A 17M-parameter transformer language model for conditional generation of synthetic plasmid DNA sequences.

## Model Description

PlasmidLM generates plasmid DNA sequences conditioned on functional component specifications. Given a prompt specifying desired elements (antibiotic resistance genes, origins of replication, promoters, reporters, etc.), it autoregressively generates a complete DNA sequence containing those elements.

**Architecture**: LLaMA-style transformer decoder with RoPE, RMSNorm, and GELU activations.

| Parameter | Value |
|-----------|-------|
| Parameters | 17M |
| Hidden size | 384 |
| Layers | 10 |
| Attention heads | 8 |
| Context length | 16,384 tokens |
| Vocabulary | 120 tokens |
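As a sanity check, the headline parameter count can be reproduced from the table alone. This is a back-of-envelope estimate assuming bias-free projections, tied input/output embeddings, and two RMSNorm weight vectors per layer, matching the modeling code shipped in this commit:

```python
# Back-of-envelope parameter count from the architecture table
# (assumes no biases, tied embeddings, two RMSNorms per layer).
hidden, layers, inter, vocab = 384, 10, 1536, 120

embed = vocab * hidden               # token embeddings (tied with the LM head)
attn = 4 * hidden * hidden           # q/k/v/o projections per layer
mlp = 2 * hidden * inter             # up_proj + down_proj per layer
norms = 2 * hidden                   # input + post-attention RMSNorm per layer
total = embed + layers * (attn + mlp + norms) + hidden  # + final RMSNorm

print(f"{total:,} parameters")  # 17,748,864 -> the "17M" headline figure
```

This also squares with the fp32 `model.safetensors` size of ~71 MB, i.e. roughly 17.75M four-byte weights plus header overhead.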

The vocabulary consists of DNA base tokens (A, T, C, G, N, in upper- and lowercase), control tokens (BOS, EOS, SEP, PAD, UNK), and ~100 categorical tokens representing functional plasmid components (e.g., `<AMR_KANAMYCIN>`, `<ORI_COLE1>`, `<PROM_T7>`).
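Because the vocabulary mixes multi-character `<...>` tokens with single bases, tokenization reduces to a regex split: each `<...>` span becomes one token and everything else is split per character. A minimal illustrative sketch of that scheme (not the repo's tokenizer class itself):

```python
import re

def split_plasmid_prompt(text: str) -> list:
    """Split a prompt into <...> component tokens and single-character bases."""
    tokens = []
    for part in re.split(r"(<[^>]+>)", text):
        if part.startswith("<") and part.endswith(">"):
            tokens.append(part)  # one categorical component token
        else:
            tokens.extend(c for c in part if not c.isspace())  # per-base
    return tokens

print(split_plasmid_prompt("<BOS><AMR_KANAMYCIN><SEP>ATCG"))
# ['<BOS>', '<AMR_KANAMYCIN>', '<SEP>', 'A', 'T', 'C', 'G']
```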

## Training

Pretrained with causal language modeling on ~108K plasmid sequences derived from the [Addgene](https://www.addgene.org/) repository, annotated with functional components via [pLannotate](https://github.com/barricklab/pLannotate).

- **Steps**: 15,000
- **Epochs**: ~2.3
- **Eval loss**: 0.093
- **Token accuracy**: 96.1%
- **Optimizer**: AdamW
- **Precision**: bf16
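For intuition, the eval loss converts to a per-token perplexity of about 1.10 via the standard `exp(loss)` relationship:

```python
import math

eval_loss = 0.093  # from the training summary above
perplexity = math.exp(eval_loss)
print(round(perplexity, 3))  # 1.097
```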

## Intended Use

This is a **base pretrained model**. It has learned the statistical patterns of plasmid DNA sequences and their relationship to categorical component tokens. It can be used for:

- **Direct generation**: Prompt with component tokens to generate plasmid sequences
- **Fine-tuning**: Post-train with reinforcement learning (GRPO/PPO) to improve motif placement accuracy
- **Embeddings**: Use hidden states as learned representations of plasmid sequences
- **Research**: Study the learned structure of synthetic DNA

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)

# Generate a plasmid with kanamycin resistance and ColE1 origin
prompt = "<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=4096, temperature=0.8, do_sample=True)
sequence = tokenizer.decode(outputs[0], skip_special_tokens=False)
print(sequence)
```

## Input Format

```
<BOS><TOKEN1><TOKEN2>...<SEP>
```

The model generates DNA bases (A/T/C/G) after the `<SEP>` token until it produces `<EOS>` or hits the maximum length.
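Since the usage example decodes with `skip_special_tokens=False`, the generated bases have to be sliced out of the decoded string. A small hypothetical helper, assuming the `<SEP>`...`<EOS>` layout described here:

```python
def extract_dna(decoded: str) -> str:
    """Pull the generated bases out of a string laid out as
    <BOS>...<SEP>BASES<EOS>, stopping at <EOS> if present."""
    body = decoded.split("<SEP>", 1)[1] if "<SEP>" in decoded else decoded
    body = body.split("<EOS>", 1)[0]
    return "".join(c for c in body if c in "ATCGNatcgn")

print(extract_dna("<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>ATGCCGT<EOS>"))
# ATGCCGT
```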

## Component Categories

| Category | Examples | Count |
|----------|----------|-------|
| Antibiotic Resistance (AMR) | Kanamycin, Ampicillin, Chloramphenicol, ... | 11 |
| Origin of Replication (ORI) | ColE1, F1, P15A, pSC101, SV40, ... | 7 |
| Promoter (PROM) | CMV, T7, U6, EF1a, CAG, ... | 11 |
| Reporter | EGFP, mCherry, YFP, NanoLuc, ... | 6 |
| Vector Type (VEC) | Lentiviral, CRISPR, Bacterial, AAV, ... | 10 |
| Other | Tags, elements, species, backbones | ~55 |

## Limitations

- This is a **pretrained base model** -- it learns sequence statistics but has not been optimized for motif placement accuracy. Post-training with RL significantly improves functional element fidelity.
- Generated sequences are **not experimentally validated**. Always verify computationally (e.g., with pLannotate) and experimentally before synthesis.
- The model was trained on Addgene plasmids, which are biased toward commonly deposited vectors (mammalian expression, bacterial cloning, CRISPR).
- Maximum context of 16K tokens (~16 kbp), which covers most but not all plasmids.
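Because one base maps to one token, the usable sequence budget is the 16,384-token context minus the prompt overhead. A quick illustration (the three-token overhead for `<BOS>`, `<SEP>`, and `<EOS>` is an assumption about the prompt layout described above):

```python
CONTEXT_LENGTH = 16_384  # model max positions; one token per DNA base

def max_generated_bases(num_component_tokens: int) -> int:
    """Bases that fit after <BOS>, the component tokens, <SEP>, and <EOS>."""
    return CONTEXT_LENGTH - num_component_tokens - 3

print(max_generated_bases(2))  # 16379
```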

## Citation

```bibtex
@misc{thiel2026plasmidlm,
  title={PlasmidLM: Language Models for Conditional Plasmid DNA Generation},
  author={Thiel, McClain},
  year={2026},
  url={https://huggingface.co/McClain/PlasmidLM}
}
```
config.json ADDED
{
  "architectures": [
    "PlasmidLMForCausalLM"
  ],
  "bos_token_id": 0,
  "dtype": "float32",
  "eos_token_id": 1,
  "hidden_act": "gelu",
  "hidden_size": 384,
  "intermediate_size": 1536,
  "max_position_embeddings": 16384,
  "model_type": "plasmid_lm",
  "num_attention_heads": 8,
  "num_hidden_layers": 10,
  "pad_token_id": 3,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "transformers_version": "4.57.6",
  "vocab_size": 120,
  "auto_map": {
    "AutoConfig": "configuration_plasmid_lm.PlasmidLMConfig",
    "AutoModel": "modeling_plasmid_lm.PlasmidLMModel",
    "AutoModelForCausalLM": "modeling_plasmid_lm.PlasmidLMForCausalLM",
    "AutoTokenizer": [
      "tokenization_plasmid_lm.PlasmidLMTokenizer",
      null
    ]
  }
}
configuration_plasmid_lm.py ADDED
"""HuggingFace configuration for PlasmidLM."""

from transformers import PretrainedConfig


class PlasmidLMConfig(PretrainedConfig):
    model_type = "plasmid_lm"

    def __init__(
        self,
        vocab_size: int = 112,
        hidden_size: int = 384,
        num_hidden_layers: int = 10,
        num_attention_heads: int = 8,
        intermediate_size: int = 1536,
        hidden_act: str = "gelu",
        rms_norm_eps: float = 1e-5,
        max_position_embeddings: int = 16384,
        rope_theta: float = 10000.0,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
generation_config.json ADDED
{
  "_from_model_config": true,
  "bos_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 3,
  "transformers_version": "4.57.6"
}
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c4429ce1399e6bab5ac053d7aae115daa7870032b0f54769859d24acb664ba91
size 71004376
modeling_plasmid_lm.py ADDED
"""HuggingFace-compatible PlasmidLM model for use with AutoModelForCausalLM and vLLM."""

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import; not guaranteed by `import torch` alone
from transformers import PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_plasmid_lm import PlasmidLMConfig


def _rope_freqs(dim: int, max_len: int, base: float) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len).float()
    angles = torch.outer(t, freqs)
    return torch.cos(angles), torch.sin(angles)


def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, offset: int = 0) -> torch.Tensor:
    S = x.shape[2]
    cos = cos[offset:offset + S].unsqueeze(0).unsqueeze(0)
    sin = sin[offset:offset + S].unsqueeze(0).unsqueeze(0)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)


class PlasmidLMAttention(nn.Module):
    def __init__(self, config: PlasmidLMConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        B, S, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        dtype = q.dtype
        q = _apply_rope(q, rope_cos, rope_sin, offset=position_offset).to(dtype)
        k = _apply_rope(k, rope_cos, rope_sin, offset=position_offset).to(dtype)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        new_kv = (k, v)

        use_causal = past_key_value is None
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=use_causal)
        out = attn.transpose(1, 2).reshape(B, S, -1)
        return self.o_proj(out), new_kv


class PlasmidLMMLP(nn.Module):
    def __init__(self, config: PlasmidLMConfig):
        super().__init__()
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


class PlasmidLMDecoderLayer(nn.Module):
    def __init__(self, config: PlasmidLMConfig):
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = PlasmidLMAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = PlasmidLMMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out, new_kv = self.self_attn(hidden_states, rope_cos, rope_sin, past_key_value, position_offset)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, new_kv


class PlasmidLMPreTrainedModel(PreTrainedModel):
    config_class = PlasmidLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, PlasmidLMModel):
            module.gradient_checkpointing = value


class PlasmidLMModel(PlasmidLMPreTrainedModel):
    """Base model (backbone) — returned by AutoModel."""

    def __init__(self, config: PlasmidLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([PlasmidLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        head_dim = config.hidden_size // config.num_attention_heads
        cos, sin = _rope_freqs(head_dim, config.max_position_embeddings, config.rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[list] = None,
        position_offset: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, list]:
        hidden_states = self.embed_tokens(input_ids)
        new_kv_caches = []
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None
            if self.gradient_checkpointing and self.training:
                # Gradient checkpointing recomputes activations on backward — no past_kv during training
                def make_ckpt_fn(l):
                    def fn(h, cos, sin):
                        out, kv = l(h, cos, sin, None, 0)
                        return out, kv[0], kv[1]
                    return fn
                hidden_states, k, v = torch.utils.checkpoint.checkpoint(
                    make_ckpt_fn(layer), hidden_states, self.rope_cos, self.rope_sin,
                    use_reentrant=False,
                )
                new_kv = (k, v)
            else:
                hidden_states, new_kv = layer(hidden_states, self.rope_cos, self.rope_sin, past_kv, position_offset)
            new_kv_caches.append(new_kv)
        hidden_states = self.norm(hidden_states)
        return hidden_states, new_kv_caches


class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: PlasmidLMConfig):
        super().__init__(config)
        self.model = PlasmidLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        has_cache = False
        if past_key_values is not None:
            if isinstance(past_key_values, DynamicCache):
                has_cache = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, list):
                has_cache = len(past_key_values) > 0 and past_key_values[0] is not None
        if has_cache:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }

    def _convert_cache_to_list(self, past_key_values) -> Optional[list]:
        """Convert DynamicCache to list of (k, v) tuples for our model."""
        if past_key_values is None:
            return None
        if isinstance(past_key_values, list):
            return past_key_values
        if isinstance(past_key_values, DynamicCache):
            if past_key_values.get_seq_length() == 0:
                return None
            return [(layer.keys, layer.values) for layer in past_key_values.layers]
        return None

    def _convert_list_to_cache(self, kv_list: list) -> DynamicCache:
        """Convert list of (k, v) tuples to DynamicCache."""
        cache = DynamicCache()
        for i, (k, v) in enumerate(kv_list):
            cache.update(k, v, layer_idx=i)
        return cache

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values=None,
        use_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        kv_list = self._convert_cache_to_list(past_key_values)

        position_offset = 0
        if kv_list is not None:
            position_offset = kv_list[0][0].shape[2]

        hidden_states, new_kv_list = self.model(input_ids, kv_list, position_offset)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        new_cache = None
        if use_cache:
            new_cache = self._convert_list_to_cache(new_kv_list)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_cache,
        )

    @torch.no_grad()
    def generate_simple(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 512,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> torch.Tensor:
        """Simple autoregressive generation with KV cache."""
        # Prefill
        hidden_states, kv_caches = self.model(input_ids)
        logits = self.lm_head(hidden_states[:, -1:, :]).squeeze(1)
        cur_len = input_ids.shape[1]

        for _ in range(max_new_tokens):
            scaled = logits.float() / temperature
            scaled = torch.nan_to_num(scaled, nan=0.0, posinf=1e4, neginf=-1e4)
            if top_k > 0:
                k = min(top_k, scaled.size(-1))
                v, _ = torch.topk(scaled, k)
                scaled[scaled < v[:, [-1]]] = float("-inf")
            probs = F.softmax(scaled, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            hidden_states, kv_caches = self.model(next_token, kv_caches, cur_len)
            logits = self.lm_head(hidden_states).squeeze(1)
            cur_len += 1

        return input_ids
special_tokens.txt ADDED
<BOS>
<EOS>
<SEP>
<PAD>
<UNK>
<SEQ>
<AMR_AMPICILLIN>
<AMR_BLASTICIDIN>
<AMR_CHLORAMPHENICOL>
<AMR_GENTAMICIN>
<AMR_HYGROMYCIN>
<AMR_KANAMYCIN>
<AMR_NEOMYCIN>
<AMR_PUROMYCIN>
<AMR_SPECTINOMYCIN>
<AMR_TETRACYCLINE>
<AMR_ZEOCIN>
<BB_LENTIGUIDE_PURO>
<BB_P1316_IGG2A>
<BB_PAAV2>
<BB_PAAV>
<BB_PCDNA31+>
<BB_PCDNA31>
<BB_PCDNA3>
<BB_PCMV>
<BB_PCRII_TOPO>
<BB_PD649>
<BB_PDONR221>
<BB_PDONR223>
<BB_PEGFP_C1>
<BB_PEGFP_N1>
<BB_PET28A>
<BB_PHAGE>
<BB_PLX_TRC317>
<BB_PTT3>
<BB_PUC19>
<BB_UNKNOWN>
<COPY_HIGH>
<COPY_LOW>
<ELEM_AAV_ITR>
<ELEM_CMV_ENHANCER>
<ELEM_CMV_INTRON>
<ELEM_CPPT>
<ELEM_GRNA_SCAFFOLD>
<ELEM_IRES>
<ELEM_LTR_3>
<ELEM_LTR_5>
<ELEM_MCS>
<ELEM_POLYA_BGH>
<ELEM_POLYA_SV40>
<ELEM_PSI>
<ELEM_TRACRRNA>
<ELEM_WPRE>
<ORI_2MU>
<ORI_COLE1>
<ORI_F1>
<ORI_P15A>
<ORI_PSC101>
<ORI_RSF>
<ORI_SV40>
<PROM_AMPR>
<PROM_CAG>
<PROM_CMV>
<PROM_EF1A>
<PROM_LAC>
<PROM_RSV>
<PROM_SP6>
<PROM_SV40>
<PROM_T3>
<PROM_T5>
<PROM_T7>
<PROM_U6>
<REPORTER_EGFP>
<REPORTER_GFP>
<REPORTER_MCHERRY>
<REPORTER_MEMERALD>
<REPORTER_NANOLUC>
<REPORTER_YFP>
<SP_CELEGANS>
<SP_DROSOPHILA>
<SP_ECOLI>
<SP_HUMAN>
<SP_MOUSE>
<SP_RAT>
<SP_SYNTHETIC>
<SP_YEAST>
<SP_ZEBRAFISH>
<TAG_FLAG>
<TAG_GST>
<TAG_HA>
<TAG_HIS>
<TAG_MYC>
<TAG_NLS>
<TAG_V5>
<VEC_AAV>
<VEC_BACTERIAL>
<VEC_CRISPR>
<VEC_GATEWAY>
<VEC_INSECT>
<VEC_LENTIVIRAL>
<VEC_MAMMALIAN>
<VEC_PLANT>
<VEC_REPORTER>
<VEC_RETROVIRAL>
<VEC_YEAST>
tokenization_plasmid_lm.py ADDED
"""HuggingFace-compatible tokenizer for PlasmidLM."""

from __future__ import annotations

import json
import os
import re
from typing import List, Optional

from transformers import PreTrainedTokenizer


DNA_BASES = list("ATCGNatcgn")


class PlasmidLMTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for plasmid sequences with special tokens."""

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        unk_token: str = "<UNK>",
        pad_token: str = "<PAD>",
        sep_token: str = "<SEP>",
        **kwargs,
    ):
        # Load vocab before calling super().__init__
        with open(vocab_file, "r") as f:
            data = json.load(f)

        # Support nested format with "token_to_id" key
        if isinstance(data, dict) and "token_to_id" in data:
            data = data["token_to_id"]

        # Ensure DNA bases are in the vocab (matching PlasmidTokenizer)
        next_id = max(data.values()) + 1 if data else 0
        for base in DNA_BASES:
            if base not in data:
                data[base] = next_id
                next_id += 1

        self._vocab = data
        self._id_to_token = {v: k for k, v in self._vocab.items()}

        # Only pass special tokens that actually exist in the vocab.
        # PreTrainedTokenizer would otherwise create new IDs for them.
        special_kwargs = {}
        for name, tok in [("bos_token", bos_token), ("eos_token", eos_token),
                          ("unk_token", unk_token), ("pad_token", pad_token),
                          ("sep_token", sep_token)]:
            if tok in self._vocab:
                special_kwargs[name] = tok

        super().__init__(**special_kwargs, **kwargs)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def pad_token_id(self) -> int:
        return self._vocab.get("<PAD>", 0)

    @property
    def bos_token_id(self) -> int:
        return self._vocab.get("<BOS>", 1)

    @property
    def eos_token_id(self) -> int:
        return self._vocab.get("<EOS>", 2)

    @property
    def sep_token_id(self) -> int:
        return self._vocab.get("<SEP>", 3)

    def get_vocab(self) -> dict:
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split into special <...> tokens and individual characters."""
        parts = re.split(r"(<[^>]+>)", text)
        tokens = []
        for part in parts:
            if not part or part.isspace():
                continue
            if part.startswith("<") and part.endswith(">"):
                tokens.append(part)
            else:
                tokens.extend(c for c in part if not c.isspace())
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get("<UNK>", 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, "<UNK>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w") as f:
            json.dump(self._vocab, f, indent=2)
        return (vocab_file,)
vocab.json ADDED
{
  "<BOS>": 0,
  "<EOS>": 1,
  "<SEP>": 2,
  "<PAD>": 3,
  "<UNK>": 4,
  "<SEQ>": 5,
  "<AMR_AMPICILLIN>": 6,
  "<AMR_BLASTICIDIN>": 7,
  "<AMR_CHLORAMPHENICOL>": 8,
  "<AMR_GENTAMICIN>": 9,
  "<AMR_HYGROMYCIN>": 10,
  "<AMR_KANAMYCIN>": 11,
  "<AMR_NEOMYCIN>": 12,
  "<AMR_PUROMYCIN>": 13,
  "<AMR_SPECTINOMYCIN>": 14,
  "<AMR_TETRACYCLINE>": 15,
  "<AMR_ZEOCIN>": 16,
  "<BB_LENTIGUIDE_PURO>": 17,
  "<BB_P1316_IGG2A>": 18,
  "<BB_PAAV2>": 19,
  "<BB_PAAV>": 20,
  "<BB_PCDNA31+>": 21,
  "<BB_PCDNA31>": 22,
  "<BB_PCDNA3>": 23,
  "<BB_PCMV>": 24,
  "<BB_PCRII_TOPO>": 25,
  "<BB_PD649>": 26,
  "<BB_PDONR221>": 27,
  "<BB_PDONR223>": 28,
  "<BB_PEGFP_C1>": 29,
  "<BB_PEGFP_N1>": 30,
  "<BB_PET28A>": 31,
  "<BB_PHAGE>": 32,
  "<BB_PLX_TRC317>": 33,
  "<BB_PTT3>": 34,
  "<BB_PUC19>": 35,
  "<BB_UNKNOWN>": 36,
  "<COPY_HIGH>": 37,
  "<COPY_LOW>": 38,
  "<ELEM_AAV_ITR>": 39,
  "<ELEM_CMV_ENHANCER>": 40,
  "<ELEM_CMV_INTRON>": 41,
  "<ELEM_CPPT>": 42,
  "<ELEM_GRNA_SCAFFOLD>": 43,
  "<ELEM_IRES>": 44,
  "<ELEM_LTR_3>": 45,
  "<ELEM_LTR_5>": 46,
  "<ELEM_MCS>": 47,
  "<ELEM_POLYA_BGH>": 48,
  "<ELEM_POLYA_SV40>": 49,
  "<ELEM_PSI>": 50,
  "<ELEM_TRACRRNA>": 51,
  "<ELEM_WPRE>": 52,
  "<ORI_2MU>": 53,
  "<ORI_COLE1>": 54,
  "<ORI_F1>": 55,
  "<ORI_P15A>": 56,
  "<ORI_PSC101>": 57,
  "<ORI_RSF>": 58,
  "<ORI_SV40>": 59,
  "<PROM_AMPR>": 60,
  "<PROM_CAG>": 61,
  "<PROM_CMV>": 62,
  "<PROM_EF1A>": 63,
  "<PROM_LAC>": 64,
  "<PROM_RSV>": 65,
  "<PROM_SP6>": 66,
  "<PROM_SV40>": 67,
  "<PROM_T3>": 68,
  "<PROM_T5>": 69,
  "<PROM_T7>": 70,
  "<PROM_U6>": 71,
  "<REPORTER_EGFP>": 72,
  "<REPORTER_GFP>": 73,
  "<REPORTER_MCHERRY>": 74,
  "<REPORTER_MEMERALD>": 75,
  "<REPORTER_NANOLUC>": 76,
  "<REPORTER_YFP>": 77,
  "<SP_CELEGANS>": 78,
  "<SP_DROSOPHILA>": 79,
  "<SP_ECOLI>": 80,
  "<SP_HUMAN>": 81,
  "<SP_MOUSE>": 82,
  "<SP_RAT>": 83,
  "<SP_SYNTHETIC>": 84,
  "<SP_YEAST>": 85,
  "<SP_ZEBRAFISH>": 86,
  "<TAG_FLAG>": 87,
  "<TAG_GST>": 88,
  "<TAG_HA>": 89,
  "<TAG_HIS>": 90,
  "<TAG_MYC>": 91,
  "<TAG_NLS>": 92,
  "<TAG_V5>": 93,
  "<VEC_AAV>": 94,
  "<VEC_BACTERIAL>": 95,
  "<VEC_CRISPR>": 96,
  "<VEC_GATEWAY>": 97,
  "<VEC_INSECT>": 98,
  "<VEC_LENTIVIRAL>": 99,
  "<VEC_MAMMALIAN>": 100,
  "<VEC_PLANT>": 101,
  "<VEC_REPORTER>": 102,
  "<VEC_RETROVIRAL>": 103,
  "<VEC_YEAST>": 104,
  "A": 105,
  "T": 106,
  "C": 107,
  "G": 108,
  "N": 109,
  "a": 110,
  "t": 111,
  "c": 112,
  "g": 113,
  "n": 114
}