Upload folder using huggingface_hub
Browse files- README.md +48 -73
- config.json +4 -1
- configuration_plasmid_lm.py +20 -0
- modeling_plasmid_lm.py +106 -32
- moe.py +99 -0
- tokenization_plasmid_lm.py +16 -0
README.md
CHANGED
|
@@ -1,74 +1,42 @@
|
|
| 1 |
---
|
| 2 |
-
language:
|
| 3 |
-
- en
|
| 4 |
-
license: apache-2.0
|
| 5 |
library_name: transformers
|
|
|
|
| 6 |
tags:
|
| 7 |
-
- biology
|
| 8 |
-
- genomics
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
- custom
|
| 16 |
pipeline_tag: text-generation
|
| 17 |
-
model-index:
|
| 18 |
-
- name: PlasmidLM
|
| 19 |
-
results:
|
| 20 |
-
- task:
|
| 21 |
-
type: text-generation
|
| 22 |
-
name: Plasmid DNA Generation
|
| 23 |
-
metrics:
|
| 24 |
-
- name: Eval Loss
|
| 25 |
-
type: loss
|
| 26 |
-
value: 0.093
|
| 27 |
-
- name: Token Accuracy
|
| 28 |
-
type: accuracy
|
| 29 |
-
value: 0.961
|
| 30 |
---
|
| 31 |
|
| 32 |
# PlasmidLM
|
| 33 |
|
| 34 |
-
A
|
| 35 |
|
| 36 |
-
## Model
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
| Parameter | Value |
|
| 43 |
-
|-----------|-------|
|
| 44 |
-
| Parameters | 17M |
|
| 45 |
| Hidden size | 384 |
|
| 46 |
| Layers | 10 |
|
| 47 |
| Attention heads | 8 |
|
| 48 |
-
|
|
| 49 |
-
|
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
| 53 |
-
## Training
|
| 54 |
-
|
| 55 |
-
Pretrained with causal language modeling on ~108K plasmid sequences derived from the [Addgene](https://www.addgene.org/) repository, annotated with functional components via [pLannotate](https://github.com/barricklab/pLannotate).
|
| 56 |
|
|
|
|
| 57 |
- **Steps**: 15,000
|
| 58 |
-
- **Epochs**: ~2.3
|
| 59 |
- **Eval loss**: 0.093
|
| 60 |
- **Token accuracy**: 96.1%
|
| 61 |
-
- **Optimizer**: AdamW
|
| 62 |
-
- **Precision**: bf16
|
| 63 |
-
|
| 64 |
-
## Intended Use
|
| 65 |
-
|
| 66 |
-
This is a **base pretrained model**. It has learned the statistical patterns of plasmid DNA sequences and their relationship to categorical component tokens. It can be used for:
|
| 67 |
-
|
| 68 |
-
- **Direct generation**: Prompt with component tokens to generate plasmid sequences
|
| 69 |
-
- **Fine-tuning**: Post-train with reinforcement learning (GRPO/PPO) to improve motif placement accuracy
|
| 70 |
-
- **Embeddings**: Use hidden states as learned representations of plasmid sequences
|
| 71 |
-
- **Research**: Study the learned structure of synthetic DNA
|
| 72 |
|
| 73 |
## Usage
|
| 74 |
|
|
@@ -78,14 +46,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
| 78 |
model = AutoModelForCausalLM.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
|
| 79 |
tokenizer = AutoTokenizer.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
prompt = "<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>"
|
| 83 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 84 |
-
outputs = model.generate(**inputs, max_new_tokens=4096, temperature=0.8, do_sample=True)
|
| 85 |
-
|
| 86 |
-
print(sequence)
|
| 87 |
```
|
| 88 |
|
|
|
|
|
|
|
| 89 |
## Input Format
|
| 90 |
|
| 91 |
```
|
|
@@ -94,31 +63,37 @@ print(sequence)
|
|
| 94 |
|
| 95 |
The model generates DNA bases (A/T/C/G) after the `<SEP>` token until it produces `<EOS>` or hits the maximum length.
|
| 96 |
|
| 97 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
| Antibiotic Resistance (AMR) | Kanamycin, Ampicillin, Chloramphenicol, ... | 11 |
|
| 102 |
-
| Origin of Replication (ORI) | ColE1, F1, P15A, pSC101, SV40, ... | 7 |
|
| 103 |
-
| Promoter (PROM) | CMV, T7, U6, EF1a, CAG, ... | 11 |
|
| 104 |
-
| Reporter | EGFP, mCherry, YFP, NanoLuc, ... | 6 |
|
| 105 |
-
| Vector Type (VEC) | Lentiviral, CRISPR, Bacterial, AAV, ... | 10 |
|
| 106 |
-
| Other | Tags, elements, species, backbones | ~55 |
|
| 107 |
|
| 108 |
## Limitations
|
| 109 |
|
| 110 |
-
- This is a **pretrained base model** --
|
| 111 |
-
- Generated sequences are **not experimentally validated**. Always verify computationally
|
| 112 |
-
-
|
| 113 |
-
- Maximum context of 16K tokens (~16 kbp)
|
| 114 |
|
| 115 |
## Citation
|
| 116 |
|
| 117 |
```bibtex
|
| 118 |
@misc{thiel2026plasmidlm,
|
| 119 |
-
title={PlasmidLM: Language Models for
|
| 120 |
author={Thiel, McClain},
|
| 121 |
-
year={2026}
|
| 122 |
-
url={https://huggingface.co/McClain/PlasmidLM}
|
| 123 |
}
|
| 124 |
```
|
|
|
|
| 1 |
---
|
|
|
|
|
|
|
|
|
|
| 2 |
library_name: transformers
|
| 3 |
+
license: apache-2.0
|
| 4 |
tags:
|
| 5 |
+
- biology
|
| 6 |
+
- genomics
|
| 7 |
+
- plasmid
|
| 8 |
+
- dna
|
| 9 |
+
- causal-lm
|
| 10 |
+
- synthetic-biology
|
| 11 |
+
language:
|
| 12 |
+
- en
|
|
|
|
| 13 |
pipeline_tag: text-generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
---
|
| 15 |
|
| 16 |
# PlasmidLM
|
| 17 |
|
| 18 |
+
A 17.7M parameter autoregressive language model for **plasmid DNA sequence generation**, trained on ~108K plasmid sequences from Addgene.
|
| 19 |
|
| 20 |
+
## Model Details
|
| 21 |
|
| 22 |
+
| Property | Value |
|
| 23 |
+
|---|---|
|
| 24 |
+
| Parameters | 17.7M |
|
| 25 |
+
| Architecture | Transformer decoder (dense MLP), LLaMA-style |
|
|
|
|
|
|
|
|
|
|
| 26 |
| Hidden size | 384 |
|
| 27 |
| Layers | 10 |
|
| 28 |
| Attention heads | 8 |
|
| 29 |
+
| Intermediate size | 1,536 |
|
| 30 |
+
| Max sequence length | 16,384 tokens |
|
| 31 |
+
| Tokenizer | Character-level (single DNA bases) |
|
| 32 |
+
| Vocab size | 120 |
|
| 33 |
|
| 34 |
+
### Training
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
- **Data**: ~108K plasmid sequences from Addgene, annotated with functional components via pLannotate
|
| 37 |
- **Steps**: 15,000
|
|
|
|
| 38 |
- **Eval loss**: 0.093
|
| 39 |
- **Token accuracy**: 96.1%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
## Usage
|
| 42 |
|
|
|
|
| 46 |
model = AutoModelForCausalLM.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
|
| 47 |
tokenizer = AutoTokenizer.from_pretrained("McClain/PlasmidLM", trust_remote_code=True)
|
| 48 |
|
| 49 |
+
# Condition on antibiotic resistance + origin of replication
|
| 50 |
prompt = "<BOS><AMR_KANAMYCIN><ORI_COLE1><SEP>"
|
| 51 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 52 |
+
outputs = model.generate(**inputs, max_new_tokens=4096, temperature=0.8, do_sample=True, top_p=0.95)
|
| 53 |
+
print(tokenizer.decode(outputs[0].tolist()))
|
|
|
|
| 54 |
```
|
| 55 |
|
| 56 |
+
The model generates plasmid DNA sequences conditioned on functional annotations (antibiotic resistance markers, origins of replication, promoters, reporters, etc.) provided as special tokens in the prompt.
|
| 57 |
+
|
| 58 |
## Input Format
|
| 59 |
|
| 60 |
```
|
|
|
|
| 63 |
|
| 64 |
The model generates DNA bases (A/T/C/G) after the `<SEP>` token until it produces `<EOS>` or hits the maximum length.
|
| 65 |
|
| 66 |
+
## Special Tokens
|
| 67 |
+
|
| 68 |
+
| Token | Purpose |
|
| 69 |
+
|---|---|
|
| 70 |
+
| `<BOS>` | Beginning of sequence |
|
| 71 |
+
| `<EOS>` | End of sequence |
|
| 72 |
+
| `<SEP>` | Separator between prompt annotations and DNA sequence |
|
| 73 |
+
| `<PAD>` | Padding |
|
| 74 |
+
| `<AMR_*>` | Antibiotic resistance markers (e.g., `<AMR_KANAMYCIN>`, `<AMR_AMPICILLIN>`) |
|
| 75 |
+
| `<ORI_*>` | Origins of replication (e.g., `<ORI_COLE1>`, `<ORI_P15A>`) |
|
| 76 |
+
| `<PROM_*>` | Promoters (e.g., `<PROM_CMV>`, `<PROM_T7>`) |
|
| 77 |
+
| `<REP_*>` | Reporters (e.g., `<REP_EGFP>`, `<REP_MCHERRY>`) |
|
| 78 |
+
|
| 79 |
+
## Related Models
|
| 80 |
|
| 81 |
+
- [McClain/PlasmidLM-kmer6](https://huggingface.co/McClain/PlasmidLM-kmer6) — kmer6 tokenizer, 19.3M params, dense
|
| 82 |
+
- [McClain/PlasmidLM-kmer6-MoE](https://huggingface.co/McClain/PlasmidLM-kmer6-MoE) — kmer6 tokenizer, 78.3M total params, Mixture-of-Experts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
## Limitations
|
| 85 |
|
| 86 |
+
- This is a **pretrained base model** -- generated sequences are not optimized for functional element placement. Post-training with RL improves fidelity.
|
| 87 |
+
- Generated sequences are **not experimentally validated**. Always verify computationally and experimentally before synthesis.
|
| 88 |
+
- Trained on Addgene plasmids, which are biased toward commonly deposited vectors.
|
| 89 |
+
- Maximum context of 16K tokens (~16 kbp).
|
| 90 |
|
| 91 |
## Citation
|
| 92 |
|
| 93 |
```bibtex
|
| 94 |
@misc{thiel2026plasmidlm,
|
| 95 |
+
title={PlasmidLM: Language Models for Plasmid DNA Generation},
|
| 96 |
author={Thiel, McClain},
|
| 97 |
+
year={2026}
|
|
|
|
| 98 |
}
|
| 99 |
```
|
config.json
CHANGED
|
@@ -25,5 +25,8 @@
|
|
| 25 |
"tokenization_plasmid_lm.PlasmidLMTokenizer",
|
| 26 |
null
|
| 27 |
]
|
| 28 |
-
}
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
|
|
|
| 25 |
"tokenization_plasmid_lm.PlasmidLMTokenizer",
|
| 26 |
null
|
| 27 |
]
|
| 28 |
+
},
|
| 29 |
+
"use_moe": false,
|
| 30 |
+
"tie_word_embeddings": true,
|
| 31 |
+
"use_cache": false
|
| 32 |
}
|
configuration_plasmid_lm.py
CHANGED
|
@@ -18,6 +18,16 @@ class PlasmidLMConfig(PretrainedConfig):
|
|
| 18 |
max_position_embeddings: int = 16384,
|
| 19 |
rope_theta: float = 10000.0,
|
| 20 |
tie_word_embeddings: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
**kwargs,
|
| 22 |
):
|
| 23 |
self.hidden_size = hidden_size
|
|
@@ -28,6 +38,16 @@ class PlasmidLMConfig(PretrainedConfig):
|
|
| 28 |
self.rms_norm_eps = rms_norm_eps
|
| 29 |
self.max_position_embeddings = max_position_embeddings
|
| 30 |
self.rope_theta = rope_theta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
super().__init__(
|
| 32 |
vocab_size=vocab_size,
|
| 33 |
tie_word_embeddings=tie_word_embeddings,
|
|
|
|
| 18 |
max_position_embeddings: int = 16384,
|
| 19 |
rope_theta: float = 10000.0,
|
| 20 |
tie_word_embeddings: bool = True,
|
| 21 |
+
# MoE
|
| 22 |
+
use_moe: bool = False,
|
| 23 |
+
num_experts: int = 6,
|
| 24 |
+
num_experts_per_tok: int = 2,
|
| 25 |
+
moe_intermediate_size: int | None = None,
|
| 26 |
+
aux_loss_coef: float = 0.01,
|
| 27 |
+
# Tokenizer metadata (informational, saved in checkpoint)
|
| 28 |
+
tokenizer_type: str = "char",
|
| 29 |
+
kmer_k: int | None = None,
|
| 30 |
+
kmer_stride: int | None = None,
|
| 31 |
**kwargs,
|
| 32 |
):
|
| 33 |
self.hidden_size = hidden_size
|
|
|
|
| 38 |
self.rms_norm_eps = rms_norm_eps
|
| 39 |
self.max_position_embeddings = max_position_embeddings
|
| 40 |
self.rope_theta = rope_theta
|
| 41 |
+
# MoE
|
| 42 |
+
self.use_moe = use_moe
|
| 43 |
+
self.num_experts = num_experts
|
| 44 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 45 |
+
self.moe_intermediate_size = moe_intermediate_size or intermediate_size
|
| 46 |
+
self.aux_loss_coef = aux_loss_coef
|
| 47 |
+
# Tokenizer metadata
|
| 48 |
+
self.tokenizer_type = tokenizer_type
|
| 49 |
+
self.kmer_k = kmer_k
|
| 50 |
+
self.kmer_stride = kmer_stride
|
| 51 |
super().__init__(
|
| 52 |
vocab_size=vocab_size,
|
| 53 |
tie_word_embeddings=tie_word_embeddings,
|
modeling_plasmid_lm.py
CHANGED
|
@@ -14,6 +14,7 @@ from transformers.generation import GenerationMixin
|
|
| 14 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 15 |
|
| 16 |
from .configuration_plasmid_lm import PlasmidLMConfig
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def _rope_freqs(dim: int, max_len: int, base: float) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
@@ -48,6 +49,7 @@ class PlasmidLMAttention(nn.Module):
|
|
| 48 |
rope_sin: torch.Tensor,
|
| 49 |
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 50 |
position_offset: int = 0,
|
|
|
|
| 51 |
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 52 |
B, S, _ = hidden_states.shape
|
| 53 |
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
|
@@ -63,8 +65,11 @@ class PlasmidLMAttention(nn.Module):
|
|
| 63 |
v = torch.cat([past_key_value[1], v], dim=2)
|
| 64 |
new_kv = (k, v)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
out = attn.transpose(1, 2).reshape(B, S, -1)
|
| 69 |
return self.o_proj(out), new_kv
|
| 70 |
|
|
@@ -86,7 +91,11 @@ class PlasmidLMDecoderLayer(nn.Module):
|
|
| 86 |
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 87 |
self.self_attn = PlasmidLMAttention(config)
|
| 88 |
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 89 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
def forward(
|
| 92 |
self,
|
|
@@ -95,15 +104,21 @@ class PlasmidLMDecoderLayer(nn.Module):
|
|
| 95 |
rope_sin: torch.Tensor,
|
| 96 |
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 97 |
position_offset: int = 0,
|
| 98 |
-
|
|
|
|
| 99 |
residual = hidden_states
|
| 100 |
hidden_states = self.input_layernorm(hidden_states)
|
| 101 |
-
attn_out, new_kv = self.self_attn(hidden_states, rope_cos, rope_sin, past_key_value, position_offset)
|
| 102 |
hidden_states = residual + attn_out
|
| 103 |
|
| 104 |
residual = hidden_states
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
class PlasmidLMPreTrainedModel(PreTrainedModel):
|
|
@@ -111,17 +126,6 @@ class PlasmidLMPreTrainedModel(PreTrainedModel):
|
|
| 111 |
base_model_prefix = "model"
|
| 112 |
supports_gradient_checkpointing = True
|
| 113 |
|
| 114 |
-
def _init_weights(self, module):
|
| 115 |
-
if isinstance(module, PlasmidLMModel):
|
| 116 |
-
# Recompute RoPE buffers — they are non-persistent so not saved in
|
| 117 |
-
# safetensors. from_pretrained's fast-init path zeros them out.
|
| 118 |
-
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
| 119 |
-
cos, sin = _rope_freqs(
|
| 120 |
-
head_dim, self.config.max_position_embeddings, self.config.rope_theta
|
| 121 |
-
)
|
| 122 |
-
module.rope_cos = cos
|
| 123 |
-
module.rope_sin = sin
|
| 124 |
-
|
| 125 |
def _set_gradient_checkpointing(self, module, value=False):
|
| 126 |
if isinstance(module, PlasmidLMModel):
|
| 127 |
module.gradient_checkpointing = value
|
|
@@ -136,48 +140,114 @@ class PlasmidLMModel(PlasmidLMPreTrainedModel):
|
|
| 136 |
self.layers = nn.ModuleList([PlasmidLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 137 |
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
self.register_buffer("rope_cos",
|
| 142 |
-
self.register_buffer("rope_sin",
|
| 143 |
|
| 144 |
self.gradient_checkpointing = False
|
| 145 |
self.post_init()
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
def get_input_embeddings(self):
|
| 148 |
return self.embed_tokens
|
| 149 |
|
| 150 |
def set_input_embeddings(self, value):
|
| 151 |
self.embed_tokens = value
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
def forward(
|
| 154 |
self,
|
| 155 |
input_ids: torch.Tensor,
|
| 156 |
past_key_values: Optional[list] = None,
|
| 157 |
position_offset: int = 0,
|
|
|
|
| 158 |
**kwargs,
|
| 159 |
-
) -> Tuple[torch.Tensor, list]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
hidden_states = self.embed_tokens(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
new_kv_caches = []
|
|
|
|
| 162 |
for i, layer in enumerate(self.layers):
|
| 163 |
past_kv = past_key_values[i] if past_key_values else None
|
| 164 |
if self.gradient_checkpointing and self.training:
|
| 165 |
# Gradient checkpointing recomputes activations on backward — no past_kv during training
|
| 166 |
def make_ckpt_fn(l):
|
| 167 |
def fn(h, cos, sin):
|
| 168 |
-
out, kv = l(h, cos, sin, None, 0)
|
| 169 |
-
return out, kv[0], kv[1]
|
| 170 |
return fn
|
| 171 |
-
hidden_states, k, v = torch.utils.checkpoint.checkpoint(
|
| 172 |
make_ckpt_fn(layer), hidden_states, self.rope_cos, self.rope_sin,
|
| 173 |
use_reentrant=False,
|
| 174 |
)
|
| 175 |
new_kv = (k, v)
|
| 176 |
else:
|
| 177 |
-
hidden_states, new_kv = layer(
|
|
|
|
|
|
|
| 178 |
new_kv_caches.append(new_kv)
|
|
|
|
| 179 |
hidden_states = self.norm(hidden_states)
|
| 180 |
-
return hidden_states, new_kv_caches
|
| 181 |
|
| 182 |
|
| 183 |
class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
@@ -220,6 +290,7 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
| 220 |
return {
|
| 221 |
"input_ids": input_ids,
|
| 222 |
"past_key_values": past_key_values,
|
|
|
|
| 223 |
"use_cache": True,
|
| 224 |
}
|
| 225 |
|
|
@@ -257,7 +328,9 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
| 257 |
if kv_list is not None:
|
| 258 |
position_offset = kv_list[0][0].shape[2]
|
| 259 |
|
| 260 |
-
hidden_states, new_kv_list = self.model(
|
|
|
|
|
|
|
| 261 |
logits = self.lm_head(hidden_states)
|
| 262 |
|
| 263 |
loss = None
|
|
@@ -269,6 +342,7 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
| 269 |
shift_labels.view(-1),
|
| 270 |
ignore_index=-100,
|
| 271 |
)
|
|
|
|
| 272 |
|
| 273 |
new_cache = None
|
| 274 |
if use_cache:
|
|
@@ -289,8 +363,8 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
| 289 |
top_k: int = 50,
|
| 290 |
) -> torch.Tensor:
|
| 291 |
"""Simple autoregressive generation with KV cache."""
|
| 292 |
-
# Prefill
|
| 293 |
-
hidden_states, kv_caches = self.model(input_ids)
|
| 294 |
logits = self.lm_head(hidden_states[:, -1:, :]).squeeze(1)
|
| 295 |
cur_len = input_ids.shape[1]
|
| 296 |
|
|
@@ -305,7 +379,7 @@ class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
| 305 |
next_token = torch.multinomial(probs, 1)
|
| 306 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 307 |
|
| 308 |
-
hidden_states, kv_caches = self.model(next_token, kv_caches, cur_len)
|
| 309 |
logits = self.lm_head(hidden_states).squeeze(1)
|
| 310 |
cur_len += 1
|
| 311 |
|
|
|
|
| 14 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 15 |
|
| 16 |
from .configuration_plasmid_lm import PlasmidLMConfig
|
| 17 |
+
from .moe import PlasmidLMSparseMoE
|
| 18 |
|
| 19 |
|
| 20 |
def _rope_freqs(dim: int, max_len: int, base: float) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 49 |
rope_sin: torch.Tensor,
|
| 50 |
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 51 |
position_offset: int = 0,
|
| 52 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 53 |
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 54 |
B, S, _ = hidden_states.shape
|
| 55 |
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
| 65 |
v = torch.cat([past_key_value[1], v], dim=2)
|
| 66 |
new_kv = (k, v)
|
| 67 |
|
| 68 |
+
if attention_mask is not None:
|
| 69 |
+
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
|
| 70 |
+
else:
|
| 71 |
+
use_causal = past_key_value is None
|
| 72 |
+
attn = F.scaled_dot_product_attention(q, k, v, is_causal=use_causal)
|
| 73 |
out = attn.transpose(1, 2).reshape(B, S, -1)
|
| 74 |
return self.o_proj(out), new_kv
|
| 75 |
|
|
|
|
| 91 |
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 92 |
self.self_attn = PlasmidLMAttention(config)
|
| 93 |
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 94 |
+
self.use_moe = config.use_moe
|
| 95 |
+
if self.use_moe:
|
| 96 |
+
self.moe = PlasmidLMSparseMoE(config)
|
| 97 |
+
else:
|
| 98 |
+
self.mlp = PlasmidLMMLP(config)
|
| 99 |
|
| 100 |
def forward(
|
| 101 |
self,
|
|
|
|
| 104 |
rope_sin: torch.Tensor,
|
| 105 |
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 106 |
position_offset: int = 0,
|
| 107 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 108 |
+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 109 |
residual = hidden_states
|
| 110 |
hidden_states = self.input_layernorm(hidden_states)
|
| 111 |
+
attn_out, new_kv = self.self_attn(hidden_states, rope_cos, rope_sin, past_key_value, position_offset, attention_mask)
|
| 112 |
hidden_states = residual + attn_out
|
| 113 |
|
| 114 |
residual = hidden_states
|
| 115 |
+
if self.use_moe:
|
| 116 |
+
moe_out, aux_loss = self.moe(self.post_attention_layernorm(hidden_states))
|
| 117 |
+
hidden_states = residual + moe_out
|
| 118 |
+
else:
|
| 119 |
+
hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
|
| 120 |
+
aux_loss = torch.tensor(0.0, device=hidden_states.device)
|
| 121 |
+
return hidden_states, new_kv, aux_loss
|
| 122 |
|
| 123 |
|
| 124 |
class PlasmidLMPreTrainedModel(PreTrainedModel):
|
|
|
|
| 126 |
base_model_prefix = "model"
|
| 127 |
supports_gradient_checkpointing = True
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def _set_gradient_checkpointing(self, module, value=False):
|
| 130 |
if isinstance(module, PlasmidLMModel):
|
| 131 |
module.gradient_checkpointing = value
|
|
|
|
| 140 |
self.layers = nn.ModuleList([PlasmidLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 141 |
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 142 |
|
| 143 |
+
# Lazy RoPE: computed on first forward call to ensure correct device
|
| 144 |
+
# placement after from_pretrained (which uses meta device tensors).
|
| 145 |
+
self.register_buffer("rope_cos", None, persistent=False)
|
| 146 |
+
self.register_buffer("rope_sin", None, persistent=False)
|
| 147 |
|
| 148 |
self.gradient_checkpointing = False
|
| 149 |
self.post_init()
|
| 150 |
|
| 151 |
+
def _init_rope(self, device: torch.device) -> None:
|
| 152 |
+
"""Compute and cache RoPE cos/sin on the given device."""
|
| 153 |
+
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
| 154 |
+
cos, sin = _rope_freqs(head_dim, self.config.max_position_embeddings, self.config.rope_theta)
|
| 155 |
+
self.register_buffer("rope_cos", cos.to(device), persistent=False)
|
| 156 |
+
self.register_buffer("rope_sin", sin.to(device), persistent=False)
|
| 157 |
+
|
| 158 |
def get_input_embeddings(self):
|
| 159 |
return self.embed_tokens
|
| 160 |
|
| 161 |
def set_input_embeddings(self, value):
|
| 162 |
self.embed_tokens = value
|
| 163 |
|
| 164 |
+
def _build_4d_attention_mask(
|
| 165 |
+
self,
|
| 166 |
+
attention_mask: Optional[torch.Tensor],
|
| 167 |
+
seq_len: int,
|
| 168 |
+
past_seq_len: int,
|
| 169 |
+
device: torch.device,
|
| 170 |
+
dtype: torch.dtype,
|
| 171 |
+
) -> Optional[torch.Tensor]:
|
| 172 |
+
"""Build a 4D causal+padding mask for SDPA.
|
| 173 |
+
|
| 174 |
+
Returns (B, 1, S, S+past) float mask with 0 for attend and -inf for ignore,
|
| 175 |
+
or None if no masking is needed (no padding, no past KV).
|
| 176 |
+
"""
|
| 177 |
+
if attention_mask is None and past_seq_len == 0:
|
| 178 |
+
# No padding and no cache — SDPA's is_causal=True handles this
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
total_len = past_seq_len + seq_len
|
| 182 |
+
# Causal mask: each query position can attend to itself and all prior positions
|
| 183 |
+
causal = torch.triu(
|
| 184 |
+
torch.full((seq_len, total_len), float("-inf"), device=device, dtype=dtype),
|
| 185 |
+
diagonal=past_seq_len + 1,
|
| 186 |
+
) # (S, S+past)
|
| 187 |
+
mask_4d = causal.unsqueeze(0).unsqueeze(0) # (1, 1, S, S+past)
|
| 188 |
+
|
| 189 |
+
if attention_mask is not None:
|
| 190 |
+
# attention_mask is (B, total_len) with 1=attend, 0=ignore
|
| 191 |
+
# Use a large finite negative instead of -inf for padding mask.
|
| 192 |
+
# With left-padding, the first padding positions can only attend
|
| 193 |
+
# to other padding positions (causal blocks future). If we use
|
| 194 |
+
# -inf, ALL keys are blocked → softmax([-inf,...]) = NaN.
|
| 195 |
+
# Using min_dtype keeps at least the self-attention score finite,
|
| 196 |
+
# so softmax produces a valid (though meaningless) output.
|
| 197 |
+
min_dtype = torch.finfo(dtype).min
|
| 198 |
+
pad_mask = torch.where(
|
| 199 |
+
attention_mask[:, None, None, :].bool(),
|
| 200 |
+
torch.zeros(1, device=device, dtype=dtype),
|
| 201 |
+
torch.tensor(min_dtype, device=device, dtype=dtype),
|
| 202 |
+
) # (B, 1, 1, total_len)
|
| 203 |
+
mask_4d = mask_4d + pad_mask
|
| 204 |
+
|
| 205 |
+
return mask_4d
|
| 206 |
+
|
| 207 |
def forward(
|
| 208 |
self,
|
| 209 |
input_ids: torch.Tensor,
|
| 210 |
past_key_values: Optional[list] = None,
|
| 211 |
position_offset: int = 0,
|
| 212 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 213 |
**kwargs,
|
| 214 |
+
) -> Tuple[torch.Tensor, list, torch.Tensor]:
|
| 215 |
+
# Lazy RoPE init: compute on first forward for correct device placement
|
| 216 |
+
if self.rope_cos is None:
|
| 217 |
+
self._init_rope(input_ids.device)
|
| 218 |
+
|
| 219 |
hidden_states = self.embed_tokens(input_ids)
|
| 220 |
+
|
| 221 |
+
past_seq_len = past_key_values[0][0].shape[2] if past_key_values else 0
|
| 222 |
+
mask_4d = self._build_4d_attention_mask(
|
| 223 |
+
attention_mask, input_ids.shape[1], past_seq_len,
|
| 224 |
+
input_ids.device, hidden_states.dtype,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
new_kv_caches = []
|
| 228 |
+
total_aux_loss = torch.tensor(0.0, device=input_ids.device)
|
| 229 |
for i, layer in enumerate(self.layers):
|
| 230 |
past_kv = past_key_values[i] if past_key_values else None
|
| 231 |
if self.gradient_checkpointing and self.training:
|
| 232 |
# Gradient checkpointing recomputes activations on backward — no past_kv during training
|
| 233 |
def make_ckpt_fn(l):
|
| 234 |
def fn(h, cos, sin):
|
| 235 |
+
out, kv, aux = l(h, cos, sin, None, 0)
|
| 236 |
+
return out, kv[0], kv[1], aux
|
| 237 |
return fn
|
| 238 |
+
hidden_states, k, v, layer_aux = torch.utils.checkpoint.checkpoint(
|
| 239 |
make_ckpt_fn(layer), hidden_states, self.rope_cos, self.rope_sin,
|
| 240 |
use_reentrant=False,
|
| 241 |
)
|
| 242 |
new_kv = (k, v)
|
| 243 |
else:
|
| 244 |
+
hidden_states, new_kv, layer_aux = layer(
|
| 245 |
+
hidden_states, self.rope_cos, self.rope_sin, past_kv, position_offset, mask_4d
|
| 246 |
+
)
|
| 247 |
new_kv_caches.append(new_kv)
|
| 248 |
+
total_aux_loss = total_aux_loss + layer_aux
|
| 249 |
hidden_states = self.norm(hidden_states)
|
| 250 |
+
return hidden_states, new_kv_caches, total_aux_loss
|
| 251 |
|
| 252 |
|
| 253 |
class PlasmidLMForCausalLM(PlasmidLMPreTrainedModel, GenerationMixin):
|
|
|
|
| 290 |
return {
|
| 291 |
"input_ids": input_ids,
|
| 292 |
"past_key_values": past_key_values,
|
| 293 |
+
"attention_mask": attention_mask,
|
| 294 |
"use_cache": True,
|
| 295 |
}
|
| 296 |
|
|
|
|
| 328 |
if kv_list is not None:
|
| 329 |
position_offset = kv_list[0][0].shape[2]
|
| 330 |
|
| 331 |
+
hidden_states, new_kv_list, aux_loss = self.model(
|
| 332 |
+
input_ids, kv_list, position_offset, attention_mask=attention_mask
|
| 333 |
+
)
|
| 334 |
logits = self.lm_head(hidden_states)
|
| 335 |
|
| 336 |
loss = None
|
|
|
|
| 342 |
shift_labels.view(-1),
|
| 343 |
ignore_index=-100,
|
| 344 |
)
|
| 345 |
+
loss = loss + self.config.aux_loss_coef * aux_loss
|
| 346 |
|
| 347 |
new_cache = None
|
| 348 |
if use_cache:
|
|
|
|
| 363 |
top_k: int = 50,
|
| 364 |
) -> torch.Tensor:
|
| 365 |
"""Simple autoregressive generation with KV cache."""
|
| 366 |
+
# Prefill (aux_loss ignored during generation)
|
| 367 |
+
hidden_states, kv_caches, _ = self.model(input_ids)
|
| 368 |
logits = self.lm_head(hidden_states[:, -1:, :]).squeeze(1)
|
| 369 |
cur_len = input_ids.shape[1]
|
| 370 |
|
|
|
|
| 379 |
next_token = torch.multinomial(probs, 1)
|
| 380 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 381 |
|
| 382 |
+
hidden_states, kv_caches, _ = self.model(next_token, kv_caches, cur_len)
|
| 383 |
logits = self.lm_head(hidden_states).squeeze(1)
|
| 384 |
cur_len += 1
|
| 385 |
|
moe.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mixture of Experts (Mixtral-style) for PlasmidLM."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from .configuration_plasmid_lm import PlasmidLMConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PlasmidLMExpertMLP(nn.Module):
|
| 13 |
+
"""Single expert MLP — same architecture as PlasmidLMMLP."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, hidden_size: int, intermediate_size: int):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 18 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 19 |
+
self.act = nn.GELU()
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 22 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PlasmidLMSparseMoE(nn.Module):
|
| 26 |
+
"""Sparse Mixture of Experts with top-k routing and load balancing loss.
|
| 27 |
+
|
| 28 |
+
Implements Mixtral-style routing: softmax over all experts, then select
|
| 29 |
+
top-k. The output is a weighted sum of the selected experts' outputs.
|
| 30 |
+
|
| 31 |
+
Load balancing auxiliary loss: L_aux = N * sum(f_i * P_i) where
|
| 32 |
+
f_i = fraction of tokens routed to expert i, P_i = mean routing
|
| 33 |
+
probability for expert i.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: PlasmidLMConfig):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.num_experts = config.num_experts
|
| 39 |
+
self.top_k = config.num_experts_per_tok
|
| 40 |
+
intermediate = config.moe_intermediate_size
|
| 41 |
+
|
| 42 |
+
self.router = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 43 |
+
self.experts = nn.ModuleList(
|
| 44 |
+
[PlasmidLMExpertMLP(config.hidden_size, intermediate) for _ in range(self.num_experts)]
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 48 |
+
"""
|
| 49 |
+
Args:
|
| 50 |
+
hidden_states: (batch, seq_len, hidden_size)
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
output: (batch, seq_len, hidden_size)
|
| 54 |
+
aux_loss: scalar load balancing loss
|
| 55 |
+
"""
|
| 56 |
+
batch, seq_len, hidden = hidden_states.shape
|
| 57 |
+
flat = hidden_states.view(-1, hidden) # (B*S, H)
|
| 58 |
+
num_tokens = flat.shape[0]
|
| 59 |
+
|
| 60 |
+
# Router: compute probabilities over experts
|
| 61 |
+
router_logits = self.router(flat) # (B*S, num_experts)
|
| 62 |
+
router_probs = F.softmax(router_logits, dim=-1)
|
| 63 |
+
|
| 64 |
+
# Top-k selection
|
| 65 |
+
top_weights, top_indices = torch.topk(router_probs, self.top_k, dim=-1) # (B*S, top_k)
|
| 66 |
+
# Normalize selected weights to sum to 1
|
| 67 |
+
top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
|
| 68 |
+
|
| 69 |
+
# Dispatch: loop over experts, gather tokens, compute, scatter back
|
| 70 |
+
output = torch.zeros_like(flat)
|
| 71 |
+
for i, expert in enumerate(self.experts):
|
| 72 |
+
# Mask for tokens where expert i is in the top-k
|
| 73 |
+
mask = (top_indices == i).any(dim=-1) # (B*S,)
|
| 74 |
+
if not mask.any():
|
| 75 |
+
continue
|
| 76 |
+
expert_input = flat[mask] # (n_tokens, H)
|
| 77 |
+
expert_output = expert(expert_input) # (n_tokens, H)
|
| 78 |
+
# Weight for this expert for selected tokens
|
| 79 |
+
# Find which top-k slot(s) matched and get corresponding weight
|
| 80 |
+
match_positions = (top_indices[mask] == i) # (n_tokens, top_k)
|
| 81 |
+
weights = (top_weights[mask] * match_positions.float()).sum(dim=-1, keepdim=True) # (n_tokens, 1)
|
| 82 |
+
output[mask] += weights * expert_output
|
| 83 |
+
|
| 84 |
+
output = output.view(batch, seq_len, hidden)
|
| 85 |
+
|
| 86 |
+
# Load balancing auxiliary loss
|
| 87 |
+
# f_i: fraction of tokens dispatched to expert i
|
| 88 |
+
# P_i: mean routing probability assigned to expert i
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
# Count tokens per expert (based on top-k assignments)
|
| 91 |
+
expert_counts = torch.zeros(self.num_experts, device=flat.device)
|
| 92 |
+
for k in range(self.top_k):
|
| 93 |
+
expert_counts.scatter_add_(0, top_indices[:, k], torch.ones(num_tokens, device=flat.device))
|
| 94 |
+
f = expert_counts / (num_tokens * self.top_k) # fraction
|
| 95 |
+
|
| 96 |
+
P = router_probs.mean(dim=0) # (num_experts,)
|
| 97 |
+
aux_loss = self.num_experts * (f * P).sum()
|
| 98 |
+
|
| 99 |
+
return output, aux_loss
|
tokenization_plasmid_lm.py
CHANGED
|
@@ -62,6 +62,22 @@ class PlasmidLMTokenizer(PreTrainedTokenizer):
|
|
| 62 |
def vocab_size(self) -> int:
|
| 63 |
return len(self._vocab)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def get_vocab(self) -> dict:
|
| 66 |
return dict(self._vocab)
|
| 67 |
|
|
|
|
| 62 |
def vocab_size(self) -> int:
|
| 63 |
return len(self._vocab)
|
| 64 |
|
| 65 |
+
@property
|
| 66 |
+
def pad_token_id(self) -> int:
|
| 67 |
+
return self._vocab.get("<PAD>", 0)
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def bos_token_id(self) -> int:
|
| 71 |
+
return self._vocab.get("<BOS>", 1)
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def eos_token_id(self) -> int:
|
| 75 |
+
return self._vocab.get("<EOS>", 2)
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def sep_token_id(self) -> int:
|
| 79 |
+
return self._vocab.get("<SEP>", 3)
|
| 80 |
+
|
| 81 |
def get_vocab(self) -> dict:
|
| 82 |
return dict(self._vocab)
|
| 83 |
|