McClain committed
Commit 2fd6e3e · verified · 1 parent: 2ed2e2a

Upload 9 files

README.md ADDED
@@ -0,0 +1,258 @@
+ # PlasmidGPT-GRPO: Reinforcement Learning Fine-tuned Plasmid Generator
+
+ [![W&B Run](https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg)](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50)
+
+ **A biologically constrained plasmid design model trained with reinforcement learning to generate functional DNA sequences.**
+
+ This model is a fine-tuned version of [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2) (itself based on the original [PlasmidGPT](https://github.com/lingxusb/PlasmidGPT) by Bin Shao), optimized with **Group Relative Policy Optimization (GRPO)** to generate plasmids that satisfy biological constraints.
+
+ ## 🎯 Key Improvements Over Base Model
+
+ This RL-fine-tuned model has been trained to generate plasmids that:
+
+ - ✅ Contain the **correct numbers** of essential genetic elements (ori, promoters, terminators, markers, CDS)
+ - ✅ Avoid **repeat regions** (repeats ≥50 bp are penalized)
+ - ✅ Are **shorter and more efficient** (compactness is rewarded)
+ - ✅ Maintain **proper gene cassette organization** (promoter → CDS → terminator)
+ - ✅ Achieve up to the maximum **reward score of 1.0** for optimal plasmid design
+
+ ### Reward Structure
+
+ The model was trained with a custom bioinformatics reward function that scores sequences on the following components:
+
+ | Component | Min | Max | Weight | Description |
+ |-----------|-----|-----|--------|-------------|
+ | **Origin of Replication (ori)** | 1 | 1 | 1.5× | Essential for plasmid replication |
+ | **Promoters** | 1 | 1 | 1.0× | Drive gene expression |
+ | **Terminators** | 0 | 2 | 0.5× | Stop transcription |
+ | **Selectable Markers** | 1 | 2 | 1.0× | Antibiotic resistance |
+ | **Coding Sequences (CDS)** | 1 | 5 | 1.0× | Functional genes |
+
+ **Additional Scoring:**
+ - **Repeat Penalty**: −0.1 per repeat region ≥50 bp (including reverse complements)
+ - **Length Bonus**: Rewards shorter, more compact sequences (up to +0.5)
+ - **Location Awareness**: Bonuses for correct gene cassette ordering and proximity
+
+ **Maximum reward:** 1.0 (a perfect plasmid with all constraints satisfied)
+
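+ For intuition, here is a minimal sketch of how a clamped, weighted component score of this shape can be combined and normalized. The function and variable names are illustrative, not the actual `plasmidrl` implementation:
+
+ ```python
+ # Illustrative sketch only: clamp each feature count into its allowed range,
+ # weight it, and normalize so a fully satisfied plasmid scores 1.0.
+
+ def component_score(count: int, lo: int, hi: int) -> float:
+     """1.0 if the feature count falls within [lo, hi], else 0.0."""
+     return 1.0 if lo <= count <= hi else 0.0
+
+ def combined_reward(counts: dict) -> float:
+     # (min, max, weight) per feature class, matching the table above
+     spec = {
+         "ori":        (1, 1, 1.5),
+         "promoter":   (1, 1, 1.0),
+         "terminator": (0, 2, 0.5),
+         "marker":     (1, 2, 1.0),
+         "cds":        (1, 5, 1.0),
+     }
+     total_weight = sum(w for _, _, w in spec.values())
+     raw = sum(w * component_score(counts.get(name, 0), lo, hi)
+               for name, (lo, hi, w) in spec.items())
+     return raw / total_weight
+
+ # Example: one ori, one promoter, one terminator, one marker, two CDS
+ print(combined_reward({"ori": 1, "promoter": 1, "terminator": 1,
+                        "marker": 1, "cds": 2}))  # 1.0
+ ```
+
+ Repeat penalties and the length bonus then adjust this base score, so only a sequence that also avoids long repeats reaches the 1.0 maximum.
+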
+ ## 🚀 Quick Start
+
+ ### Basic Sequence Generation
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "McClain/plasmidgpt-grpo-rl",
+     trust_remote_code=True
+ ).to(device)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "McClain/plasmidgpt-grpo-rl",
+     trust_remote_code=True
+ )
+
+ # Generate optimized plasmid sequences from a seed sequence
+ start_sequence = 'ATGGCTAGCGAATTC'
+ input_ids = tokenizer.encode(start_sequence, return_tensors='pt').to(device)
+
+ outputs = model.generate(
+     input_ids,
+     max_length=400,
+     num_return_sequences=5,
+     temperature=0.8,
+     do_sample=True,
+     top_k=50,
+     top_p=0.95,
+     pad_token_id=tokenizer.pad_token_id,
+     eos_token_id=tokenizer.eos_token_id
+ )
+
+ for i, output in enumerate(outputs):
+     sequence = tokenizer.decode(output, skip_special_tokens=True)
+     print(f"Plasmid {i+1}: {len(sequence)} bp")
+ ```
+
+ ### Scoring Generated Plasmids
+
+ To evaluate plasmids with the same reward function used during training:
+
+ ```python
+ # plasmidkit is required for feature annotation:
+ # pip install plasmidkit
+ # plasmidrl ships with the training repo: github.com/McClain-Thiel/PlasmidRL
+
+ from plasmidrl.rewards import Scorer, RewardConfig
+
+ # Use the same config as training
+ reward_config = RewardConfig(
+     punish_mode=True,
+     length_reward_mode=False,
+     repeat_penalty_enabled=True,
+     repeat_min_length=50,
+     repeat_penalty_per_region=0.1,
+     ori_min=1, ori_max=1, ori_weight=1.5,
+     promoter_min=1, promoter_max=1, promoter_weight=1.0,
+     terminator_min=0, terminator_max=2, terminator_weight=0.5,
+     marker_min=1, marker_max=2, marker_weight=1.0,
+     cds_min=1, cds_max=5, cds_weight=1.0,
+     location_aware=True
+ )
+
+ scorer = Scorer(reward_config)
+
+ # `generated_sequence` is a DNA string, e.g. one produced in the example above
+ score, components = scorer.score(generated_sequence)
+
+ print(f"Reward Score: {score:.3f}")
+ print(f"Components: {components}")
+ ```
+
+ ## 📊 Training Details
+
+ ### Training Configuration
+
+ - **Base Model**: [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2)
+ - **RL Algorithm**: GRPO (Group Relative Policy Optimization); see the sketch after this list
+ - **Training Steps**: 2,500
+ - **Training Repository**: [PlasmidRL](https://github.com/McClain-Thiel/PlasmidRL)
+ - **W&B Run**: [u3wt9c50](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50)
+
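+ In TRL terms, this amounts to plugging the bioinformatics scorer into `GRPOTrainer` as a reward function. The following is a hypothetical sketch, not the actual training script: the dataset handling, `prompt_dataset`, and hyperparameters are assumptions (see the W&B run for the real configuration):
+
+ ```python
+ # Hypothetical sketch: wiring a sequence-level reward into TRL's GRPOTrainer.
+ # `scorer` is configured as in "Scoring Generated Plasmids" above;
+ # `prompt_dataset` (rows with a "prompt" column of seed sequences) is assumed.
+ from trl import GRPOConfig, GRPOTrainer
+
+ def plasmid_reward(completions, **kwargs):
+     # One scalar reward per generated sequence in the sampled group
+     return [scorer.score(seq)[0] for seq in completions]
+
+ trainer = GRPOTrainer(
+     model="McClain/plasmidgpt-addgene-gpt2",
+     reward_funcs=plasmid_reward,
+     args=GRPOConfig(output_dir="plasmidgpt-grpo", max_steps=2500),
+     train_dataset=prompt_dataset,
+ )
+ trainer.train()
+ ```
+
+ GRPO samples a group of completions per prompt and uses within-group reward differences as the advantage signal, which is why a single scalar score per sequence suffices.
+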
+ ### Model Architecture
+
+ | Parameter | Value |
+ |-----------|-------|
+ | **Architecture** | GPT-2 (decoder-only Transformer) |
+ | **Parameters** | 110 million |
+ | **Layers** | 12 |
+ | **Hidden Size** | 768 |
+ | **Attention Heads** | 12 |
+ | **Context Length** | 2048 tokens |
+ | **Vocabulary Size** | 30,002 |
+
+ ### Framework Versions
+
+ - **TRL**: 0.23.1
+ - **Transformers**: 4.57.0
+ - **PyTorch**: 2.8.0
+ - **Datasets**: 4.1.1
+ - **Tokenizers**: 0.22.1
+
+ ## 🧬 Use Cases
+
+ 1. **Optimized Plasmid Design**: Generate plasmids that satisfy specific biological constraints
+ 2. **Synthetic Biology**: Create novel genetic constructs for molecular cloning
+ 3. **Gene Cassette Engineering**: Design properly organized promoter-CDS-terminator cassettes
+ 4. **Compact Plasmid Construction**: Generate shorter plasmids while maintaining functionality
+ 5. **Repeat-Free Sequences**: Avoid problematic repeat regions in plasmid design
+
+ ## 🔗 Related Resources
+
+ ### Original PlasmidGPT
+
+ This model builds upon the original PlasmidGPT work:
+
+ - **Paper**: [PlasmidGPT: a generative framework for plasmid design and annotation](https://www.biorxiv.org/content/10.1101/2024.09.30.615762v1) (bioRxiv 2024.09.30.615762)
+ - **Author**: Bin Shao (lingxusb)
+ - **Original Repository**: [github.com/lingxusb/PlasmidGPT](https://github.com/lingxusb/PlasmidGPT)
+ - **Original Model**: [huggingface.co/lingxusb/PlasmidGPT](https://huggingface.co/lingxusb/PlasmidGPT)
+
+ ### Training Infrastructure
+
+ - **Training Code**: [github.com/McClain-Thiel/PlasmidRL](https://github.com/McClain-Thiel/PlasmidRL)
+ - **W&B Project**: [ucl-cssb/PlasmidRL](https://wandb.ai/ucl-cssb/PlasmidRL)
+ - **Base Model**: [McClain/plasmidgpt-addgene-gpt2](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2)
+
+ ## 📚 Citations
+
+ If you use this model, please cite:
+
+ ### This RL Model
+
+ ```bibtex
+ @misc{thiel2024plasmidgpt_grpo,
+   title={PlasmidGPT-GRPO: Reinforcement Learning for Functional Plasmid Design},
+   author={Thiel, McClain},
+   year={2024},
+   howpublished={\url{https://github.com/McClain-Thiel/PlasmidRL}},
+   note={Training run: https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50}
+ }
+ ```
+
+ ### Original PlasmidGPT
+
+ ```bibtex
+ @article{shao2024plasmidgpt,
+   title={PlasmidGPT: a generative framework for plasmid design and annotation},
+   author={Shao, Bin and others},
+   journal={bioRxiv},
+   year={2024},
+   doi={10.1101/2024.09.30.615762},
+   url={https://www.biorxiv.org/content/10.1101/2024.09.30.615762v1}
+ }
+ ```
+
+ ### GRPO Algorithm
+
+ ```bibtex
+ @article{shao2024deepseekmath,
+   title={{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+   author={Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+   journal={arXiv preprint arXiv:2402.03300},
+   year={2024}
+ }
+ ```
+
+ ### TRL Library
+
+ ```bibtex
+ @misc{vonwerra2022trl,
+   title={{TRL: Transformer Reinforcement Learning}},
+   author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
+   year={2020},
+   publisher={GitHub},
+   howpublished={\url{https://github.com/huggingface/trl}}
+ }
+ ```
+
+ ## ⚙️ Technical Details
+
+ ### Reward Function Components
+
+ The bioinformatics reward function (`src/rewards/bioinformatics/scorer.py`) includes:
+
+ 1. **Feature Counting**: Uses [PlasmidKit](https://github.com/jbloomlab/plasmidkit) for automated annotation
+ 2. **Overlap Merging**: Merges overlapping features (80% overlap threshold)
+ 3. **CDS Filtering**: Removes CDS annotations that overlap with ori/promoter/terminator/marker features
+ 4. **Strand Awareness**: Considers strand orientation for gene cassette scoring
+ 5. **Repeat Detection**: Finds direct and reverse-complement repeats using k-mer indexing (see the sketch after this list)
+ 6. **Proximity Scoring**: Rewards features within 300 bp of each other for proper cassette formation
+
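+ A minimal sketch of the k-mer-indexed repeat search described in item 5, under the ≥50 bp constraint from the reward table. The helper names are illustrative, not the `plasmidrl` implementation:
+
+ ```python
+ # Illustrative sketch: index every 50-bp window of the sequence, then report
+ # position pairs that match directly or as reverse complements.
+
+ def revcomp(seq: str) -> str:
+     return seq.translate(str.maketrans("ACGT", "TGCA"))[::-1]
+
+ def repeat_seed_pairs(seq: str, min_len: int = 50) -> set:
+     index = {}
+     for i in range(len(seq) - min_len + 1):
+         index.setdefault(seq[i:i + min_len], []).append(i)
+     hits = set()
+     for i in range(len(seq) - min_len + 1):
+         window = seq[i:i + min_len]
+         for j in index.get(window, []):           # direct repeats
+             if j > i:
+                 hits.add((i, j))
+         for j in index.get(revcomp(window), []):  # inverted repeats
+             if j != i:
+                 hits.add((min(i, j), max(i, j)))
+     return hits
+ ```
+
+ A full implementation would merge overlapping seed pairs into maximal repeat regions before applying the −0.1 per-region penalty.
+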
+ ### Training Hyperparameters
+
+ View the complete hyperparameters and metrics on [W&B](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50).
+
+ ## ⚠️ Important Notes
+
+ - **Research Use Only**: Generated plasmids should be validated before experimental use
+ - **Annotation Dependency**: Scoring requires `plasmidkit` for feature annotation
+ - **Compute Requirements**: A GPU is recommended for generation (a CPU fallback is available)
+ - **Sequence Validation**: Always verify that generated sequences contain the expected features
+
+ ## 📄 License
+
+ This model inherits its license from the original PlasmidGPT repository. Please refer to the [original repository](https://github.com/lingxusb/PlasmidGPT) for details.
+
+ ## 🙏 Acknowledgments
+
+ - **Bin Shao (lingxusb)** for the original PlasmidGPT model and architecture
+ - **Addgene** for providing the training data (153k plasmid sequences)
+ - **HuggingFace TRL team** for the GRPO implementation
+ - **UCL CSSB** for computational resources
+
+ ---
+
+ **Model Version**: grpo-production-20251110_132247
+ **Training Date**: November 10, 2025
+ **Last Updated**: November 13, 2025
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "activation_function": "gelu_new",
+   "architectures": [
+     "GPT2LMHeadModel"
+   ],
+   "attn_pdrop": 0.1,
+   "bos_token_id": 30000,
+   "dtype": "float32",
+   "embd_pdrop": 0.1,
+   "eos_token_id": 30001,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "model_type": "gpt2",
+   "n_ctx": 2048,
+   "n_embd": 768,
+   "n_head": 12,
+   "n_inner": null,
+   "n_layer": 12,
+   "n_positions": 2048,
+   "pad_token_id": 3,
+   "reorder_and_upcast_attn": false,
+   "resid_pdrop": 0.1,
+   "scale_attn_by_inverse_layer_idx": false,
+   "scale_attn_weights": true,
+   "summary_activation": null,
+   "summary_first_dropout": 0.1,
+   "summary_proj_to_labels": true,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "task_specific_params": {
+     "text-generation": {
+       "do_sample": true,
+       "max_length": 50
+     }
+   },
+   "transformers_version": "4.57.0",
+   "use_cache": true,
+   "vocab_size": 30002
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 30000,
+   "eos_token_id": [
+     30001
+   ],
+   "pad_token_id": 3,
+   "transformers_version": "4.57.0"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e290ca1ff16f34af23f74de1d660398209b66fad9fea9ba6065f5b1426ce1eb
+ size 235269120
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<s>",
+   "eos_token": "</s>",
+   "pad_token": "[PAD]"
+ }
test_generation.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print(f"Using device: {device}\n")
+
+ # Smoke test: load the checkpoint from the current directory (this repo's root)
+ print("Loading RL-optimized PlasmidGPT-GRPO model...")
+ model = AutoModelForCausalLM.from_pretrained(
+     ".",
+     trust_remote_code=True
+ ).to(device)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     ".",
+     trust_remote_code=True
+ )
+
+ print("Generating optimized plasmid sequences...\n")
+
+ start_sequence = 'ATGGCTAGCGAATTCGGCGCGCCT'
+ print(f"Start sequence: {start_sequence}\n")
+
+ input_ids = tokenizer.encode(start_sequence, return_tensors='pt').to(device)
+
+ outputs = model.generate(
+     input_ids,
+     max_length=400,
+     num_return_sequences=3,
+     temperature=0.8,
+     do_sample=True,
+     top_k=50,
+     top_p=0.95,
+     pad_token_id=tokenizer.pad_token_id,
+     eos_token_id=tokenizer.eos_token_id
+ )
+
+ print("=" * 80)
+ for i, output in enumerate(outputs, 1):
+     sequence = tokenizer.decode(output, skip_special_tokens=True)
+     print(f"\nPlasmid {i}:")
+     print(f"  Length: {len(sequence)} bp")
+     print(f"  First 100 bp: {sequence[:100]}")
+     print(f"  Last 100 bp: {sequence[-100:]}")
+     print("\n" + "=" * 80)
+
+ print("\nNote: These sequences are generated by an RL-optimized model trained to:")
+ print("  ✓ Include proper genetic elements (ori, promoters, CDS, markers)")
+ print("  ✓ Avoid repeat regions ≥ 50 bp")
+ print("  ✓ Generate compact, functional plasmids")
+ print("  ✓ Organize genes in proper cassettes (promoter → CDS → terminator)")
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,71 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "30000": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "30001": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "max_length": null,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_to_multiple_of": null,
+   "pad_token": "[PAD]",
+   "pad_token_type_id": 0,
+   "padding_side": "left",
+   "tokenizer_class": "PreTrainedTokenizerFast"
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e38dae2f73a0f51976b1a463bd135d46624f945c6fd07f96a168b9f33e315d7
+ size 7377