McClain committed
Commit ee76406 · verified · 1 parent: b2e3372

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,74 +1,42 @@
 ---
-language:
-- en
-license: apache-2.0
 library_name: transformers
+license: apache-2.0
 tags:
-- biology
-- genomics
-- dna
-- plasmid
-- synthetic-biology
-- causal-lm
-- protein-engineering
-datasets:
-- custom
+- biology
+- genomics
+- plasmid
+- dna
+- causal-lm
+- synthetic-biology
+language:
+- en
 pipeline_tag: text-generation
-model-index:
-- name: PlasmidLM
-  results:
-  - task:
-      type: text-generation
-      name: Plasmid DNA Generation
-    metrics:
-    - name: Eval Loss
-      type: loss
-      value: 0.093
-    - name: Token Accuracy
-      type: accuracy
-      value: 0.961
 ---
 
 # PlasmidLM
 
-A 17M-parameter transformer language model for conditional generation of synthetic plasmid DNA sequences.
+A 17.7M parameter autoregressive language model for **plasmid DNA sequence generation**, trained on ~108K plasmid sequences from Addgene.
 
-## Model Description
+## Model Details
 
-PlasmidLM generates plasmid DNA sequences conditioned on functional component specifications. Given a prompt specifying desired elements (antibiotic resistance genes, origins of replication, promoters, reporters, etc.), it autoregressively generates a complete DNA sequence containing those elements.
-
-**Architecture**: LLaMA-style transformer decoder with RoPE, RMSNorm, and GELU activations.
-
-| Parameter | Value |
-|-----------|-------|
-| Parameters | 17M |
+| Property | Value |
+|---|---|
+| Parameters | 17.7M |
+| Architecture | Transformer decoder (dense MLP), LLaMA-style |
 | Hidden size | 384 |
 | Layers | 10 |
 | Attention heads | 8 |
-| Context length | 16,384 tokens |
-| Vocabulary | 120 tokens |
-
-The vocabulary consists of 5 DNA bases (A, T, C, G, N), control tokens (BOS, EOS, SEP, PAD, UNK), and ~100 categorical tokens representing functional plasmid components (e.g., `<AMR_KANAMYCIN>`, `<ORI_COLE1>`, `<PROM_T7>`).
+| Intermediate size | 1,536 |
+| Max sequence length | 16,384 tokens |
+| Tokenizer | Character-level (single DNA bases) |
+| Vocab size | 120 |
 
-## Training
-
-Pretrained with causal language modeling on ~108K plasmid sequences derived from the [Addgene](https://www.addgene.org/) repository, annotated with functional components via [pLannotate](https://github.com/barricklab/pLannotate).
+### Training
 
+- **Data**: ~108K plasmid sequences from Addgene, annotated with functional components via pLannotate
 - **Steps**: 15,000
-- **Epochs**: ~2.3
 - **Eval loss**: 0.093
 - **Token accuracy**: 96.1%
-- **Optimizer**: AdamW
-- **Precision**: bf16
-
-## Intended Use
-
-This is a **base pretrained model**. It has learned the statistical patterns of plasmid DNA sequences and their relationship to categorical component tokens. It can be used for:
-
-- **Direct generation**: Prompt with component tokens to generate plasmid sequences
-- **Fine-tuning**: Post-train with reinforcement learning (GRPO/PPO) to improve motif placement accuracy
-- **Embeddings**: Use hidden states as learned representations of plasmid sequences
-- **Research**: Study the learned structure of synthetic DNA
 
 ## Usage
 
@@ -78,14 +46,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 model = AutoModelForCausalLM.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
 
-# Generate a plasmid with kanamycin resistance and ColE1 origin
+# Condition on antibiotic resistance + origin of replication
 prompt = "<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>"
 inputs = tokenizer(prompt, return_tensors="pt")
-outputs = model.generate(**inputs, max_new_tokens=4096, temperature=0.8, do_sample=True)
-sequence = tokenizer.decode(outputs[0], skip_special_tokens=False)
-print(sequence)
+outputs = model.generate(**inputs, max_new_tokens=4096, temperature=0.8, do_sample=True, top_p=0.95)
+print(tokenizer.decode(outputs[0].tolist()))
 ```
 
+The model generates plasmid DNA sequences conditioned on functional annotations (antibiotic resistance markers, origins of replication, promoters, reporters, etc.) provided as special tokens in the prompt.
+
 ## Input Format
 
 ```
@@ -94,31 +63,37 @@ print(sequence)
 
 The model generates DNA bases (A/T/C/G) after the `<SEP>` token until it produces `<EOS>` or hits the maximum length.
 
-## Component Categories
+## Special Tokens
 
-| Category | Examples | Count |
-|----------|----------|-------|
-| Antibiotic Resistance (AMR) | Kanamycin, Ampicillin, Chloramphenicol, ... | 11 |
-| Origin of Replication (ORI) | ColE1, F1, P15A, pSC101, SV40, ... | 7 |
-| Promoter (PROM) | CMV, T7, U6, EF1a, CAG, ... | 11 |
-| Reporter | EGFP, mCherry, YFP, NanoLuc, ... | 6 |
-| Vector Type (VEC) | Lentiviral, CRISPR, Bacterial, AAV, ... | 10 |
-| Other | Tags, elements, species, backbones | ~55 |
+| Token | Purpose |
+|---|---|
+| `<BOS>` | Beginning of sequence |
+| `<EOS>` | End of sequence |
+| `<SEP>` | Separator between prompt annotations and DNA sequence |
+| `<PAD>` | Padding |
+| `<AMR_*>` | Antibiotic resistance markers (e.g., `<AMR_KANAMYCIN>`, `<AMR_AMPICILLIN>`) |
+| `<ORI_*>` | Origins of replication (e.g., `<ORI_COLE1>`, `<ORI_P15A>`) |
+| `<PROM_*>` | Promoters (e.g., `<PROM_CMV>`, `<PROM_T7>`) |
+| `<REP_*>` | Reporters (e.g., `<REP_EGFP>`, `<REP_MCHERRY>`) |
+
+## Related Models
+
+- [McClain/PlasmidLM-kmer6](https://huggingface.co/McClain/PlasmidLM-kmer6) — kmer6 tokenizer, 19.3M params, dense
+- [McClain/PlasmidLM-kmer6-MoE](https://huggingface.co/McClain/PlasmidLM-kmer6-MoE) — kmer6 tokenizer, 78.3M total params, Mixture-of-Experts
 
 ## Limitations
 
-- This is a **pretrained base model** -- it learns sequence statistics but has not been optimized for motif placement accuracy. Post-training with RL significantly improves functional element fidelity.
-- Generated sequences are **not experimentally validated**. Always verify computationally (e.g., with pLannotate) and experimentally before synthesis.
-- The model was trained on Addgene plasmids, which are biased toward commonly deposited vectors (mammalian expression, bacterial cloning, CRISPR).
-- Maximum context of 16K tokens (~16 kbp), which covers most but not all plasmids.
+- This is a **pretrained base model** -- generated sequences are not optimized for functional element placement. Post-training with RL improves fidelity.
+- Generated sequences are **not experimentally validated**. Always verify computationally and experimentally before synthesis.
+- Trained on Addgene plasmids, which are biased toward commonly deposited vectors.
+- Maximum context of 16K tokens (~16 kbp).
 
 ## Citation
 
 ```bibtex
 @misc{thiel2026plasmidlm,
-  title={PlasmidLM: Language Models for Conditional Plasmid DNA Generation},
+  title={PlasmidLM: Language Models for Plasmid DNA Generation},
   author={Thiel, McClain},
-  year={2026},
-  url={https://huggingface.co/McClain/PlasmidLM}
+  year={2026}
 }
 ```
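The `decode` call in the updated usage snippet keeps special tokens, so the output string interleaves annotation tokens and bases. A minimal post-processing sketch in plain Python (the decoded string below is hypothetical, not real model output):

```python
# Hypothetical decoded output: annotation tokens, <SEP>, DNA payload, <EOS>
decoded = "<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>ATGCGTAACGT<EOS>"

def extract_dna(decoded: str) -> str:
    """Return only the base payload between <SEP> and <EOS> (if present)."""
    seq = decoded.split("<SEP>", 1)[1] if "<SEP>" in decoded else decoded
    seq = seq.split("<EOS>", 1)[0]
    # Keep only the five bases the vocabulary defines
    return "".join(c for c in seq if c in "ATCGN")

print(extract_dna(decoded))  # → ATGCGTAACGT
```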
config.json CHANGED
@@ -25,5 +25,8 @@
     "tokenization_plasmid_lm.PlasmidLMTokenizer",
     null
     ]
-  }
+  },
+  "use_moe": false,
+  "tie_word_embeddings": true,
+  "use_cache": false
 }
configuration_plasmid_lm.py CHANGED
@@ -18,6 +18,16 @@ class PlasmidLMConfig(PretrainedConfig):
         max_position_embeddings: int = 16384,
         rope_theta: float = 10000.0,
         tie_word_embeddings: bool = True,
+        # MoE
+        use_moe: bool = False,
+        num_experts: int = 6,
+        num_experts_per_tok: int = 2,
+        moe_intermediate_size: int | None = None,
+        aux_loss_coef: float = 0.01,
+        # Tokenizer metadata (informational, saved in checkpoint)
+        tokenizer_type: str = "char",
+        kmer_k: int | None = None,
+        kmer_stride: int | None = None,
         **kwargs,
     ):
         self.hidden_size = hidden_size
@@ -28,6 +38,16 @@ class PlasmidLMConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
         self.max_position_embeddings = max_position_embeddings
         self.rope_theta = rope_theta
+        # MoE
+        self.use_moe = use_moe
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.moe_intermediate_size = moe_intermediate_size or intermediate_size
+        self.aux_loss_coef = aux_loss_coef
+        # Tokenizer metadata
+        self.tokenizer_type = tokenizer_type
+        self.kmer_k = kmer_k
+        self.kmer_stride = kmer_stride
         super().__init__(
             vocab_size=vocab_size,
             tie_word_embeddings=tie_word_embeddings,
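The new `use_moe` / `num_experts` fields account for most of the size gap to the 78.3M MoE sibling model. A rough, hypothetical back-of-envelope sketch (assumes bias-free up/down projections at the published 384/1536 sizes and an MoE block in every layer; exact totals also include embeddings and attention):

```python
# Hypothetical parameter arithmetic for the MLP blocks only
hidden, intermediate, layers = 384, 1536, 10
num_experts = 6

dense_mlp = 2 * hidden * intermediate          # up_proj + down_proj, no bias
router = hidden * num_experts                  # bias-free routing Linear
moe_block = num_experts * dense_mlp + router   # all experts live in every layer

print(f"dense MLPs: {dense_mlp * layers:,}")   # roughly 11.8M of the 17.7M dense model
print(f"MoE blocks: {moe_block * layers:,}")   # roughly 70.8M, the bulk of the MoE variant
```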
modeling_plasmid_lm.py CHANGED
@@ -14,6 +14,7 @@ from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from .configuration_plasmid_lm import PlasmidLMConfig
+from .moe import PlasmidLMSparseMoE
 
 
 def _rope_freqs(dim: int, max_len: int, base: float) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -48,6 +49,7 @@ class PlasmidLMAttention(nn.Module):
         rope_sin: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         position_offset: int = 0,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         B, S, _ = hidden_states.shape
         q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
@@ -63,8 +65,11 @@
             v = torch.cat([past_key_value[1], v], dim=2)
         new_kv = (k, v)
 
-        use_causal = past_key_value is None
-        attn = F.scaled_dot_product_attention(q, k, v, is_causal=use_causal)
+        if attention_mask is not None:
+            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+        else:
+            use_causal = past_key_value is None
+            attn = F.scaled_dot_product_attention(q, k, v, is_causal=use_causal)
         out = attn.transpose(1, 2).reshape(B, S, -1)
         return self.o_proj(out), new_kv
 
@@ -86,7 +91,11 @@ class PlasmidLMDecoderLayer(nn.Module):
         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.self_attn = PlasmidLMAttention(config)
         self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.mlp = PlasmidLMMLP(config)
+        self.use_moe = config.use_moe
+        if self.use_moe:
+            self.moe = PlasmidLMSparseMoE(config)
+        else:
+            self.mlp = PlasmidLMMLP(config)
 
     def forward(
         self,
@@ -95,15 +104,21 @@
         rope_sin: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         position_offset: int = 0,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        attn_out, new_kv = self.self_attn(hidden_states, rope_cos, rope_sin, past_key_value, position_offset)
+        attn_out, new_kv = self.self_attn(hidden_states, rope_cos, rope_sin, past_key_value, position_offset, attention_mask)
         hidden_states = residual + attn_out
 
         residual = hidden_states
-        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
-        return hidden_states, new_kv
+        if self.use_moe:
+            moe_out, aux_loss = self.moe(self.post_attention_layernorm(hidden_states))
+            hidden_states = residual + moe_out
+        else:
+            hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
+            aux_loss = torch.tensor(0.0, device=hidden_states.device)
+        return hidden_states, new_kv, aux_loss
 
 
 class PlasmidLMPreTrainedModel(PreTrainedModel):
@@ -111,17 +126,6 @@ class PlasmidLMPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
 
-    def _init_weights(self, module):
-        if isinstance(module, PlasmidLMModel):
-            # Recompute RoPE buffers — they are non-persistent so not saved in
-            # safetensors. from_pretrained's fast-init path zeros them out.
-            head_dim = self.config.hidden_size // self.config.num_attention_heads
-            cos, sin = _rope_freqs(
-                head_dim, self.config.max_position_embeddings, self.config.rope_theta
-            )
-            module.rope_cos = cos
-            module.rope_sin = sin
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, PlasmidLMModel):
             module.gradient_checkpointing = value
@@ -136,48 +140,114 @@ class PlasmidLMModel(PlasmidLMPreTrainedModel):
         self.layers = nn.ModuleList([PlasmidLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        head_dim = config.hidden_size // config.num_attention_heads
-        cos, sin = _rope_freqs(head_dim, config.max_position_embeddings, config.rope_theta)
-        self.register_buffer("rope_cos", cos, persistent=False)
-        self.register_buffer("rope_sin", sin, persistent=False)
+        # Lazy RoPE: computed on first forward call to ensure correct device
+        # placement after from_pretrained (which uses meta device tensors).
+        self.register_buffer("rope_cos", None, persistent=False)
+        self.register_buffer("rope_sin", None, persistent=False)
 
         self.gradient_checkpointing = False
         self.post_init()
 
+    def _init_rope(self, device: torch.device) -> None:
+        """Compute and cache RoPE cos/sin on the given device."""
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        cos, sin = _rope_freqs(head_dim, self.config.max_position_embeddings, self.config.rope_theta)
+        self.register_buffer("rope_cos", cos.to(device), persistent=False)
+        self.register_buffer("rope_sin", sin.to(device), persistent=False)
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
+    def _build_4d_attention_mask(
+        self,
+        attention_mask: Optional[torch.Tensor],
+        seq_len: int,
+        past_seq_len: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> Optional[torch.Tensor]:
+        """Build a 4D causal+padding mask for SDPA.
+
+        Returns (B, 1, S, S+past) float mask with 0 for attend and -inf for ignore,
+        or None if no masking is needed (no padding, no past KV).
+        """
+        if attention_mask is None and past_seq_len == 0:
+            # No padding and no cache — SDPA's is_causal=True handles this
+            return None
+
+        total_len = past_seq_len + seq_len
+        # Causal mask: each query position can attend to itself and all prior positions
+        causal = torch.triu(
+            torch.full((seq_len, total_len), float("-inf"), device=device, dtype=dtype),
+            diagonal=past_seq_len + 1,
+        )  # (S, S+past)
+        mask_4d = causal.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S+past)
+
+        if attention_mask is not None:
+            # attention_mask is (B, total_len) with 1=attend, 0=ignore
+            # Use a large finite negative instead of -inf for padding mask.
+            # With left-padding, the first padding positions can only attend
+            # to other padding positions (causal blocks future). If we use
+            # -inf, ALL keys are blocked → softmax([-inf,...]) = NaN.
+            # Using min_dtype keeps at least the self-attention score finite,
+            # so softmax produces a valid (though meaningless) output.
+            min_dtype = torch.finfo(dtype).min
+            pad_mask = torch.where(
                attention_mask[:, None, None, :].bool(),
+                torch.zeros(1, device=device, dtype=dtype),
+                torch.tensor(min_dtype, device=device, dtype=dtype),
+            )  # (B, 1, 1, total_len)
+            mask_4d = mask_4d + pad_mask
+
+        return mask_4d
+
     def forward(
         self,
         input_ids: torch.Tensor,
         past_key_values: Optional[list] = None,
         position_offset: int = 0,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, list]:
+    ) -> Tuple[torch.Tensor, list, torch.Tensor]:
+        # Lazy RoPE init: compute on first forward for correct device placement
+        if self.rope_cos is None:
+            self._init_rope(input_ids.device)
+
         hidden_states = self.embed_tokens(input_ids)
+
+        past_seq_len = past_key_values[0][0].shape[2] if past_key_values else 0
+        mask_4d = self._build_4d_attention_mask(
+            attention_mask, input_ids.shape[1], past_seq_len,
+            input_ids.device, hidden_states.dtype,
+        )
+
         new_kv_caches = []
+        total_aux_loss = torch.tensor(0.0, device=input_ids.device)
         for i, layer in enumerate(self.layers):
             past_kv = past_key_values[i] if past_key_values else None
             if self.gradient_checkpointing and self.training:
                 # Gradient checkpointing recomputes activations on backward — no past_kv during training
                 def make_ckpt_fn(l):
                     def fn(h, cos, sin):
-                        out, kv = l(h, cos, sin, None, 0)
-                        return out, kv[0], kv[1]
+                        out, kv, aux = l(h, cos, sin, None, 0)
+                        return out, kv[0], kv[1], aux
                     return fn
-                hidden_states, k, v = torch.utils.checkpoint.checkpoint(
+                hidden_states, k, v, layer_aux = torch.utils.checkpoint.checkpoint(
                     make_ckpt_fn(layer), hidden_states, self.rope_cos, self.rope_sin,
                     use_reentrant=False,
                 )
                 new_kv = (k, v)
             else:
-                hidden_states, new_kv = layer(hidden_states, self.rope_cos, self.rope_sin, past_kv, position_offset)
+                hidden_states, new_kv, layer_aux = layer(
+                    hidden_states, self.rope_cos, self.rope_sin, past_kv, position_offset, mask_4d
+                )
             new_kv_caches.append(new_kv)
+            total_aux_loss = total_aux_loss + layer_aux
         hidden_states = self.norm(hidden_states)
-        return hidden_states, new_kv_caches
+        return hidden_states, new_kv_caches, total_aux_loss
 
 
 class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
@@ -220,6 +290,7 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
+            "attention_mask": attention_mask,
             "use_cache": True,
         }
 
@@ -257,7 +328,9 @@
         if kv_list is not None:
             position_offset = kv_list[0][0].shape[2]
 
-        hidden_states, new_kv_list = self.model(input_ids, kv_list, position_offset)
+        hidden_states, new_kv_list, aux_loss = self.model(
+            input_ids, kv_list, position_offset, attention_mask=attention_mask
+        )
         logits = self.lm_head(hidden_states)
 
         loss = None
@@ -269,6 +342,7 @@
                 shift_labels.view(-1),
                 ignore_index=-100,
             )
+            loss = loss + self.config.aux_loss_coef * aux_loss
 
         new_cache = None
         if use_cache:
@@ -289,8 +363,8 @@
         top_k: int = 50,
     ) -> torch.Tensor:
         """Simple autoregressive generation with KV cache."""
-        # Prefill
-        hidden_states, kv_caches = self.model(input_ids)
+        # Prefill (aux_loss ignored during generation)
+        hidden_states, kv_caches, _ = self.model(input_ids)
         logits = self.lm_head(hidden_states[:, -1:, :]).squeeze(1)
         cur_len = input_ids.shape[1]
 
@@ -305,7 +379,7 @@
             next_token = torch.multinomial(probs, 1)
             input_ids = torch.cat([input_ids, next_token], dim=1)
 
-            hidden_states, kv_caches = self.model(next_token, kv_caches, cur_len)
+            hidden_states, kv_caches, _ = self.model(next_token, kv_caches, cur_len)
             logits = self.lm_head(hidden_states).squeeze(1)
             cur_len += 1

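The NaN pitfall that `_build_4d_attention_mask` guards against (a fully masked, left-padded row) can be demonstrated without torch; a plain-Python softmax sketch (the `-1e9` constant stands in for `torch.finfo(dtype).min`):

```python
import math

def softmax(row):
    """Max-subtracted softmax, as real attention kernels compute it."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

NEG = -1e9  # large finite negative, standing in for torch.finfo(dtype).min

# All keys masked with -inf: -inf - (-inf) is NaN, so the whole row is NaN
nan_row = softmax([float("-inf")] * 3)
# All keys masked with a finite minimum: valid (if meaningless) uniform weights
ok_row = softmax([NEG] * 3)

print(any(math.isnan(x) for x in nan_row))  # → True
print(ok_row[0])                            # → 0.3333333333333333
```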
moe.py ADDED
@@ -0,0 +1,99 @@
+"""Mixture of Experts (Mixtral-style) for PlasmidLM."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .configuration_plasmid_lm import PlasmidLMConfig
+
+
+class PlasmidLMExpertMLP(nn.Module):
+    """Single expert MLP — same architecture as PlasmidLMMLP."""
+
+    def __init__(self, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.act = nn.GELU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.up_proj(x)))
+
+
+class PlasmidLMSparseMoE(nn.Module):
+    """Sparse Mixture of Experts with top-k routing and load balancing loss.
+
+    Implements Mixtral-style routing: softmax over all experts, then select
+    top-k. The output is a weighted sum of the selected experts' outputs.
+
+    Load balancing auxiliary loss: L_aux = N * sum(f_i * P_i) where
+    f_i = fraction of tokens routed to expert i, P_i = mean routing
+    probability for expert i.
+    """
+
+    def __init__(self, config: PlasmidLMConfig):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        intermediate = config.moe_intermediate_size
+
+        self.router = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+        self.experts = nn.ModuleList(
+            [PlasmidLMExpertMLP(config.hidden_size, intermediate) for _ in range(self.num_experts)]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            hidden_states: (batch, seq_len, hidden_size)
+
+        Returns:
+            output: (batch, seq_len, hidden_size)
+            aux_loss: scalar load balancing loss
+        """
+        batch, seq_len, hidden = hidden_states.shape
+        flat = hidden_states.view(-1, hidden)  # (B*S, H)
+        num_tokens = flat.shape[0]
+
+        # Router: compute probabilities over experts
+        router_logits = self.router(flat)  # (B*S, num_experts)
+        router_probs = F.softmax(router_logits, dim=-1)
+
+        # Top-k selection
+        top_weights, top_indices = torch.topk(router_probs, self.top_k, dim=-1)  # (B*S, top_k)
+        # Normalize selected weights to sum to 1
+        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
+
+        # Dispatch: loop over experts, gather tokens, compute, scatter back
+        output = torch.zeros_like(flat)
+        for i, expert in enumerate(self.experts):
+            # Mask for tokens where expert i is in the top-k
+            mask = (top_indices == i).any(dim=-1)  # (B*S,)
+            if not mask.any():
+                continue
+            expert_input = flat[mask]  # (n_tokens, H)
+            expert_output = expert(expert_input)  # (n_tokens, H)
+            # Weight for this expert for selected tokens
+            # Find which top-k slot(s) matched and get corresponding weight
+            match_positions = (top_indices[mask] == i)  # (n_tokens, top_k)
+            weights = (top_weights[mask] * match_positions.float()).sum(dim=-1, keepdim=True)  # (n_tokens, 1)
+            output[mask] += weights * expert_output
+
+        output = output.view(batch, seq_len, hidden)
+
+        # Load balancing auxiliary loss
+        # f_i: fraction of tokens dispatched to expert i
+        # P_i: mean routing probability assigned to expert i
+        with torch.no_grad():
+            # Count tokens per expert (based on top-k assignments)
+            expert_counts = torch.zeros(self.num_experts, device=flat.device)
+            for k in range(self.top_k):
+                expert_counts.scatter_add_(0, top_indices[:, k], torch.ones(num_tokens, device=flat.device))
+            f = expert_counts / (num_tokens * self.top_k)  # fraction
+
+        P = router_probs.mean(dim=0)  # (num_experts,)
+        aux_loss = self.num_experts * (f * P).sum()
+
+        return output, aux_loss
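The load-balancing formula in the docstring (`L_aux = N * sum(f_i * P_i)`) can be sanity-checked without torch. A pure-Python sketch with toy routing probabilities (not the module above): a uniform router scores exactly 1.0, and routing collapse onto one expert scores higher.

```python
def route(probs_per_token, num_experts, top_k):
    """Top-k routing stats plus the Mixtral-style load-balancing loss."""
    counts = [0] * num_experts
    mean_prob = [0.0] * num_experts  # P_i: mean routing probability per expert
    n = len(probs_per_token)
    for probs in probs_per_token:
        for i, p in enumerate(probs):
            mean_prob[i] += p / n
        # Indices of the top-k experts for this token (stable sort breaks ties)
        top = sorted(range(num_experts), key=lambda i: -probs[i])[:top_k]
        for i in top:
            counts[i] += 1
    # f_i: fraction of (token, slot) assignments sent to expert i
    f = [c / (n * top_k) for c in counts]
    aux = num_experts * sum(fi * pi for fi, pi in zip(f, mean_prob))
    return aux

aux_uniform = route([[0.25, 0.25, 0.25, 0.25]] * 8, num_experts=4, top_k=2)
print(round(aux_uniform, 6))  # → 1.0  (uniform P_i makes the loss exactly 1)

aux_collapsed = route([[0.97, 0.01, 0.01, 0.01]] * 8, num_experts=4, top_k=2)
print(aux_collapsed > 1.0)  # → True  (collapse onto expert 0 is penalized)
```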
tokenization_plasmid_lm.py CHANGED
@@ -62,6 +62,22 @@ class PlasmidLMTokenizer(PreTrainedTokenizer):
     def vocab_size(self) -> int:
         return len(self._vocab)
 
+    @property
+    def pad_token_id(self) -> int:
+        return self._vocab.get("<PAD>", 0)
+
+    @property
+    def bos_token_id(self) -> int:
+        return self._vocab.get("<BOS>", 1)
+
+    @property
+    def eos_token_id(self) -> int:
+        return self._vocab.get("<EOS>", 2)
+
+    @property
+    def sep_token_id(self) -> int:
+        return self._vocab.get("<SEP>", 3)
+
     def get_vocab(self) -> dict:
         return dict(self._vocab)
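The tokenizer is character-level over bases plus multi-character `<...>` special tokens. A standalone sketch of how such a prompt splits into tokens and ids (hypothetical mini-vocab, not the shipped `PlasmidLMTokenizer`; the real vocab has ~120 entries):

```python
import re

# Hypothetical mini-vocab; ids 0-3 mirror the property defaults above
vocab = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<SEP>": 3,
         "<AMR_KANAMYCIN>": 4, "<ORI_COLE1>": 5,
         "A": 6, "T": 7, "C": 8, "G": 9, "N": 10}

def tokenize(text: str) -> list[str]:
    """Split into special tokens (<...>) and single characters."""
    return re.findall(r"<[^>]+>|.", text)

toks = tokenize("<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>ATCG")
ids = [vocab[t] for t in toks]
print(toks)  # → ['<BOS>', '<AMR_KANAMYCIN>', '<ORI_COLE1>', '<SEP>', 'A', 'T', 'C', 'G']
print(ids)   # → [1, 4, 5, 3, 6, 7, 8, 9]
```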