Refresh weights and simplify README; add W&B link

#2
by McClain - opened
Files changed (2)
  1. README.md +34 -247
  2. model.safetensors +1 -1
README.md CHANGED
@@ -1,258 +1,45 @@
1
- # PlasmidGPT-GRPO: Reinforcement Learning Fine-tuned Plasmid Generator
2
 
3
- [![W&B Run](https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg)](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50)
4
 
5
- **A biologically-constrained plasmid design model trained with reinforcement learning to generate functional DNA sequences.**
6
 
7
- This model is a fine-tuned version of [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2) (itself based on the original [PlasmidGPT](https://github.com/lingxusb/PlasmidGPT) by Bin Shao), optimized using **Group Relative Policy Optimization (GRPO)** to generate plasmids that satisfy biological constraints.
 
 
8
 
9
- ## 🎯 Key Improvements Over Base Model
 
10
 
11
- This RL-fine-tuned model has been trained to generate plasmids that:
12
-
13
- - ✅ Contain **correct numbers** of essential genetic elements (ori, promoters, terminators, markers, CDS)
14
- - ✅ Avoid **repeat regions** (>50 bp repeats penalized)
15
- - ✅ Generate **shorter, more efficient** sequences (rewarded for compactness)
16
- - ✅ Maintain **proper gene cassette organization** (promoter → CDS → terminator)
17
- - ✅ Achieve up to **1.0 reward score** for optimal plasmid design
18
-
19
- ### Reward Structure
20
-
21
- The model was trained using a custom bioinformatics reward function that scores sequences based on:
22
-
23
- | Component | Min | Max | Weight | Description |
24
- |-----------|-----|-----|--------|-------------|
25
- | **Origin of Replication (ori)** | 1 | 1 | 1.5× | Essential for plasmid replication |
26
- | **Promoters** | 1 | 1 | 1.0× | Drive gene expression |
27
- | **Terminators** | 0 | 2 | 0.5× | Stop transcription |
28
- | **Selectable Markers** | 1 | 2 | 1.0× | Antibiotic resistance |
29
- | **Coding Sequences (CDS)** | 1 | 5 | 1.0× | Functional genes |
30
-
31
- **Additional Scoring:**
32
- - **Repeat Penalty**: -0.1 per repeat region ≥50 bp (including reverse complements)
33
- - **Length Bonus**: Rewards for shorter, more compact sequences (up to +0.5)
34
- - **Location Awareness**: Bonuses for correct gene cassette ordering and proximity
35
-
36
- **Maximum reward:** 1.0 (perfect plasmid with all constraints satisfied)
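The count-based part of the reward table above can be illustrated with a minimal sketch. The scoring formula here is an assumption made for illustration (the weighted fraction of components whose count falls inside its `[min, max]` range, minus 0.1 per repeat region); the length bonus and location-aware terms are left out, and the names `Bound`, `BOUNDS`, and `sketch_reward` are hypothetical, not part of the PlasmidRL code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Bound:
    lo: int        # minimum allowed count
    hi: int        # maximum allowed count
    weight: float  # relative weight of the component


# Min / max / weight values copied from the reward table.
BOUNDS = {
    "ori": Bound(1, 1, 1.5),
    "promoter": Bound(1, 1, 1.0),
    "terminator": Bound(0, 2, 0.5),
    "marker": Bound(1, 2, 1.0),
    "cds": Bound(1, 5, 1.0),
}


def sketch_reward(counts, n_repeats=0, repeat_penalty=0.1):
    """Assumed formula: weighted share of in-range components minus repeat penalty."""
    total = sum(b.weight for b in BOUNDS.values())
    hit = sum(
        b.weight
        for name, b in BOUNDS.items()
        if b.lo <= counts.get(name, 0) <= b.hi
    )
    return hit / total - repeat_penalty * n_repeats


# A plasmid with one ori, one promoter, one marker and two CDS, no repeats.
print(sketch_reward({"ori": 1, "promoter": 1, "marker": 1, "cds": 2}))
```

Under this assumed formula a sequence with every component in range and no repeats reaches the stated maximum of 1.0, while a second ori drops the in-range share to 3.5 / 5.0.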
37
-
38
- ## 🚀 Quick Start
39
-
40
- ### Basic Sequence Generation
41
-
42
- ```python
43
- import torch
44
- from transformers import AutoTokenizer, AutoModelForCausalLM
45
-
46
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
47
-
48
- model = AutoModelForCausalLM.from_pretrained(
49
-     "McClain/plasmidgpt-grpo-rl",
50
-     trust_remote_code=True
51
- ).to(device)
52
- model.eval()
53
-
54
- tokenizer = AutoTokenizer.from_pretrained(
55
-     "McClain/plasmidgpt-grpo-rl",
56
-     trust_remote_code=True
57
- )
58
-
59
- # Generate optimized plasmid sequence
60
- start_sequence = 'ATGGCTAGCGAATTC'
61
- input_ids = tokenizer.encode(start_sequence, return_tensors='pt').to(device)
62
-
63
- outputs = model.generate(
64
-     input_ids,
65
-     max_length=400,
66
-     num_return_sequences=5,
67
-     temperature=0.8,
68
-     do_sample=True,
69
-     top_k=50,
70
-     top_p=0.95,
71
-     pad_token_id=tokenizer.pad_token_id,
72
-     eos_token_id=tokenizer.eos_token_id
73
- )
74
-
75
- for i, output in enumerate(outputs):
76
-     sequence = tokenizer.decode(output, skip_special_tokens=True)
77
-     print(f"Plasmid {i+1}: {len(sequence)} bp")
78
  ```
79
-
80
- ### Scoring Generated Plasmids
81
-
82
- To evaluate plasmids using the same reward function from training:
83
-
84
- ```python
85
- # Install plasmidkit for annotation
86
- # pip install plasmidkit
87
-
88
- from plasmidrl.rewards import Scorer, RewardConfig
89
-
90
- # Use the same config as training
91
- reward_config = RewardConfig(
92
-     punish_mode=True,
93
-     length_reward_mode=False,
94
-     repeat_penalty_enabled=True,
95
-     repeat_min_length=50,
96
-     repeat_penalty_per_region=0.1,
97
-     ori_min=1, ori_max=1, ori_weight=1.5,
98
-     promoter_min=1, promoter_max=1, promoter_weight=1.0,
99
-     terminator_min=0, terminator_max=2, terminator_weight=0.5,
100
-     marker_min=1, marker_max=2, marker_weight=1.0,
101
-     cds_min=1, cds_max=5, cds_weight=1.0,
102
-     location_aware=True
103
- )
104
-
105
- scorer = Scorer(reward_config)
106
- score, components = scorer.score(generated_sequence)
107
-
108
- print(f"Reward Score: {score:.3f}")
109
- print(f"Components: {components}")
110
  ```
111
 
112
- ## 📊 Training Details
113
-
114
- ### Training Configuration
115
-
116
- - **Base Model**: [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2)
117
- - **RL Algorithm**: GRPO (Group Relative Policy Optimization)
118
- - **Training Steps**: 2,500 steps
119
- - **Training Repository**: [PlasmidRL](https://github.com/McClain-Thiel/PlasmidRL)
120
- - **W&B Run**: [u3wt9c50](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50)
121
-
122
- ### Model Architecture
123
-
124
- | Parameter | Value |
125
- |-----------|-------|
126
- | **Architecture** | GPT-2 (Decoder-only Transformer) |
127
- | **Parameters** | 110 million |
128
- | **Layers** | 12 |
129
- | **Hidden Size** | 768 |
130
- | **Attention Heads** | 12 |
131
- | **Context Length** | 2048 tokens |
132
- | **Vocabulary Size** | 30,002 |
133
-
134
- ### Framework Versions
135
-
136
- - **TRL**: 0.23.1
137
- - **Transformers**: 4.57.0
138
- - **PyTorch**: 2.8.0
139
- - **Datasets**: 4.1.1
140
- - **Tokenizers**: 0.22.1
141
-
142
- ## 🧬 Use Cases
143
-
144
- 1. **Optimized Plasmid Design**: Generate plasmids that satisfy specific biological constraints
145
- 2. **Synthetic Biology**: Create novel genetic constructs for molecular cloning
146
- 3. **Gene Cassette Engineering**: Design properly organized promoter-CDS-terminator cassettes
147
- 4. **Compact Plasmid Construction**: Generate shorter plasmids while maintaining functionality
148
- 5. **Repeat-Free Sequences**: Avoid problematic repeat regions in plasmid design
149
-
150
- ## 🔗 Related Resources
151
-
152
- ### Original PlasmidGPT
153
-
154
- This model builds upon the original PlasmidGPT work:
155
-
156
- - **Paper**: [PlasmidGPT: a generative framework for plasmid design and annotation](https://www.biorxiv.org/content/10.1101/2024.09.30.615762v1) (bioRxiv 2024.09.30.615762)
157
- - **Author**: Bin Shao (lingxusb)
158
- - **Original Repository**: [github.com/lingxusb/PlasmidGPT](https://github.com/lingxusb/PlasmidGPT)
159
- - **Original Model**: [huggingface.co/lingxusb/PlasmidGPT](https://huggingface.co/lingxusb/PlasmidGPT)
160
-
161
- ### Training Infrastructure
162
-
163
- - **Training Code**: [github.com/McClain-Thiel/PlasmidRL](https://github.com/McClain-Thiel/PlasmidRL)
164
- - **W&B Project**: [ucl-cssb/PlasmidRL](https://wandb.ai/ucl-cssb/PlasmidRL)
165
- - **Base Model**: [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2)
166
-
167
- ## 📚 Citations
168
-
169
- If you use this model, please cite:
170
-
171
- ### This RL Model
172
-
173
- ```bibtex
174
- @misc{thiel2024plasmidgpt_grpo,
175
- title={PlasmidGPT-GRPO: Reinforcement Learning for Functional Plasmid Design},
176
- author={Thiel, McClain},
177
- year={2024},
178
- howpublished={\url{https://github.com/McClain-Thiel/PlasmidRL}},
179
- note={Training run: https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50}
180
- }
181
- ```
182
-
183
- ### Original PlasmidGPT
184
-
185
- ```bibtex
186
- @article{shao2024plasmidgpt,
187
- title={PlasmidGPT: a generative framework for plasmid design and annotation},
188
- author={Shao, Bin and others},
189
- journal={bioRxiv},
190
- year={2024},
191
- doi={10.1101/2024.09.30.615762},
192
- url={https://www.biorxiv.org/content/10.1101/2024.09.30.615762v1}
193
- }
194
  ```
195
-
196
- ### GRPO Algorithm
197
-
198
- ```bibtex
199
- @article{shao2024deepseekmath,
200
- title={{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
201
- author={Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
202
- journal={arXiv preprint arXiv:2402.03300},
203
- year={2024}
204
- }
205
- ```
206
-
207
- ### TRL Library
208
-
209
- ```bibtex
210
- @misc{vonwerra2022trl,
211
- title={{TRL: Transformer Reinforcement Learning}},
212
- author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
213
- year={2020},
214
- publisher={GitHub},
215
- howpublished={\url{https://github.com/huggingface/trl}}
216
- }
217
  ```
218
 
219
- ## ⚙️ Technical Details
220
-
221
- ### Reward Function Components
222
-
223
- The bioinformatics reward function (`src/rewards/bioinformatics/scorer.py`) includes:
224
-
225
- 1. **Feature Counting**: Uses [PlasmidKit](https://github.com/jbloomlab/plasmidkit) for automated annotation
226
- 2. **Overlap Merging**: Intelligently merges overlapping features (80% threshold)
227
- 3. **CDS Filtering**: Removes CDS annotations overlapping with ori/promoter/terminator/marker
228
- 4. **Strand Awareness**: Considers strand orientation for gene cassette scoring
229
- 5. **Repeat Detection**: Finds direct and reverse complement repeats using k-mer indexing
230
- 6. **Proximity Scoring**: Rewards features within 300 bp for proper cassette formation
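The repeat-detection step in the list above (direct and reverse-complement repeats via k-mer indexing) can be sketched roughly as follows. This is an illustrative assumption about how such an index might work, not the code in `scorer.py`; `reverse_complement` and `find_repeated_kmers` are hypothetical names, and merging adjacent hits into repeat regions is left out.

```python
from collections import defaultdict

# Translation table for complementing A/C/G/T; other symbols pass through.
_COMPLEMENT = str.maketrans("ACGT", "TGCA")


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of an uppercase DNA string."""
    return seq.translate(_COMPLEMENT)[::-1]


def find_repeated_kmers(seq: str, k: int = 50) -> dict:
    """Map each canonical k-mer seen more than once to its start positions.

    A k-mer and its reverse complement share one canonical key (the
    lexicographically smaller of the two), so repeats on either strand
    end up in the same bucket.
    """
    seq = seq.upper()
    index = defaultdict(list)
    for start in range(len(seq) - k + 1):
        kmer = seq[start:start + k]
        canonical = min(kmer, reverse_complement(kmer))
        index[canonical].append(start)
    return {kmer: pos for kmer, pos in index.items() if len(pos) > 1}


# Toy example with k=4: "AAAC" at position 0 and its reverse complement
# "GTTT" at position 6 are reported under the same key.
print(find_repeated_kmers("AAACGGGTTT", k=4))
```

A single long repeat produces many overlapping k-mer hits, so a real scorer would still need to collapse consecutive positions into one region before applying a per-region penalty.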
231
-
232
- ### Training Hyperparameters
233
-
234
- View complete hyperparameters and metrics on [W&B](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50).
235
-
236
- ## ⚠️ Important Notes
237
-
238
- - **Research Use Only**: Generated plasmids should be validated before experimental use
239
- - **Annotation Dependency**: Scoring requires `plasmidkit` for feature annotation
240
- - **Compute Requirements**: GPU recommended for generation (CPU fallback available)
241
- - **Sequence Validation**: Always verify generated sequences contain expected features
242
-
243
- ## 📄 License
244
-
245
- This model inherits licensing from the original PlasmidGPT repository. Please refer to the [original repository](https://github.com/lingxusb/PlasmidGPT) for details.
246
-
247
- ## 🙏 Acknowledgments
248
-
249
- - **Bin Shao (lingxusb)** for the original PlasmidGPT model and architecture
250
- - **Addgene** for providing the training data (153k plasmid sequences)
251
- - **HuggingFace TRL team** for the GRPO implementation
252
- - **UCL CSSB** for computational resources
253
-
254
- ---
255
-
256
- **Model Version**: grpo-production-20251110_132247
257
- **Training Date**: November 10, 2025
258
- **Last Updated**: November 13, 2025
 
1
+ # PlasmidGPT-GRPO
2
 
3
+ PlasmidGPT-GRPO is a GRPO-trained causal language model for plasmid/DNA sequence generation.
4
 
5
+ This update refreshes the weights (model.safetensors) and streamlines the documentation.
6
 
7
+ ## Weights
8
+ - `model.safetensors` (updated)
9
+ - All tokenizer/config files remain unchanged.
10
 
11
+ ## Training Run
12
+ - Weights and metrics: https://wandb.ai/ucl-cssb/PlasmidRL/runs/ty13u43j/overview
13
 
14
+ ## Usage
15
+ Install:
16
  ```
17
+ pip install torch transformers safetensors
18
  ```
19
 
20
+ Load and generate:
21
  ```
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer
23
+
24
+ model_id = "UCL-CSSB/PlasmidGPT-GRPO"
25
+ tok = AutoTokenizer.from_pretrained(model_id)
26
+ if tok.pad_token is None:
27
+     tok.pad_token = tok.eos_token
28
+ model = AutoModelForCausalLM.from_pretrained(model_id)
29
+
30
+ inputs = tok(["ATG"], return_tensors="pt")
31
+ out = model.generate(
32
+     **inputs,
33
+     max_new_tokens=128,
34
+     do_sample=True,
35
+     temperature=0.7,
36
+     top_p=0.9,
37
+     pad_token_id=tok.eos_token_id,
38
+     eos_token_id=tok.eos_token_id,
39
+ )
40
+ print(tok.decode(out[0], skip_special_tokens=True))
41
  ```
42
 
43
+ Notes:
44
+ - Use sampling (`do_sample=True` with `temperature`/`top_p`) for diverse sequences; set `do_sample=False` for deterministic greedy decoding.
45
+ - Runs on CPU, CUDA, or Apple MPS depending on your PyTorch install.
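The device note above could be put into practice with a small helper like the one below. It is a sketch, not part of the README: `pick_device` is a hypothetical name, and the fallback to CPU when PyTorch is missing is only there so the snippet runs anywhere.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then Apple MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


try:
    import torch

    device = pick_device(
        torch.cuda.is_available(),
        torch.backends.mps.is_available(),
    )
except ImportError:
    # PyTorch is not installed; nothing to accelerate.
    device = "cpu"

print(f"Selected device: {device}")

# With the README's objects, the model and inputs would then be moved with:
#   model = model.to(device)
#   inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
```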
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ba508e7a9d4bfb9c095f95c11fe0e7a1131f6a9076e89852bdd22f67ca00c324
3
  size 438696576
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:353de867743e69096257539c5ae44131947d9e41ef8a9a0ffdd863b3cff9eee6
3
  size 438696576
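Because only the `oid` changes while `size` stays at 438696576, a local download can be checked against the new pointer by hashing the file. The following is a minimal sketch using only the standard library; `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, and the local path in the usage comment is an assumption.

```python
import hashlib

# New pointer contents as shown in the diff above.
NEW_POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:353de867743e69096257539c5ae44131947d9e41ef8a9a0ffdd863b3cff9eee6
size 438696576
"""


def parse_lfs_pointer(text: str) -> dict:
    """Split a git-lfs pointer file into version, hash algorithm, digest, size."""
    fields = dict(line.strip().split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path: str, pointer: dict, chunk_size: int = 1 << 20) -> bool:
    """Hash the file in chunks and compare digest and byte size to the pointer."""
    hasher = hashlib.new(pointer["algo"])
    size = 0
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            hasher.update(chunk)
            size += len(chunk)
    return size == pointer["size"] and hasher.hexdigest() == pointer["digest"]


pointer = parse_lfs_pointer(NEW_POINTER)
print(pointer["algo"], pointer["size"])

# Usage, assuming the weights were downloaded to ./model.safetensors:
#   matches_pointer("model.safetensors", pointer)
```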