McClain commited on
Commit
db2462a
·
verified ·
1 Parent(s): c7cbbd2

Update model to grpo-production-20251110 (90% pass rate, 3 ORI types at temp 1.3)

Browse files
Files changed (2) hide show
  1. README.md +91 -30
  2. model.safetensors +1 -1
README.md CHANGED
@@ -1,45 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # PlasmidGPT-GRPO
2
 
3
- PlasmidGPT-GRPO is a GRPO-trained causal language model for plasmid/DNA sequence generation.
4
 
5
- This update refreshes the weights (model.safetensors) and streamlines the documentation.
6
 
7
- ## Weights
8
- - `model.safetensors` (updated)
9
- - All tokenizer/config files remain unchanged.
10
 
11
- ## Training Run
12
- - Weights and metrics: https://wandb.ai/ucl-cssb/PlasmidRL/runs/ty13u43j/overview
13
 
14
- ## Usage
15
- Install:
16
- ```
17
- pip install torch transformers safetensors
18
- ```
19
 
20
- Load and generate:
21
- ```
 
22
  from transformers import AutoModelForCausalLM, AutoTokenizer
23
 
24
- model_id = "UCL-CSSB/PlasmidGPT-GRPO"
25
- tok = AutoTokenizer.from_pretrained(model_id)
26
- if tok.pad_token is None:
27
- tok.pad_token = tok.eos_token
28
- model = AutoModelForCausalLM.from_pretrained(model_id)
29
 
30
- inputs = tok(["ATG"], return_tensors="pt")
31
- out = model.generate(
 
 
 
32
  **inputs,
33
- max_new_tokens=128,
 
34
  do_sample=True,
35
- temperature=0.7,
36
- top_p=0.9,
37
- pad_token_id=tok.eos_token_id,
38
- eos_token_id=tok.eos_token_id,
39
  )
40
- print(tok.decode(out[0], skip_special_tokens=True))
 
 
41
  ```
42
 
43
- Notes:
44
- - Use sampling (temperature/top_p) for diverse sequences; disable for deterministic output.
45
- - Runs on CPU, CUDA, or Apple MPS depending on your PyTorch install.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: McClain/plasmidgpt-addgene-gpt2
3
+ library_name: transformers
4
+ model_name: PlasmidGPT-GRPO
5
+ tags:
6
+ - generated_from_trainer
7
+ - grpo
8
+ - trl
9
+ - biology
10
+ - plasmid
11
+ - dna
12
+ - synthetic-biology
13
+ license: mit
14
+ datasets:
15
+ - McClain/plasmids-ncbi-addgene
16
+ pipeline_tag: text-generation
17
+ ---
18
+
19
  # PlasmidGPT-GRPO
20
 
21
+ A generative model for plasmid DNA sequences, fine-tuned with Group Relative Policy Optimization (GRPO) reinforcement learning.
22
 
23
+ ## Model Description
24
 
25
+ This model is a fine-tuned version of [PlasmidGPT](https://huggingface.co/McClain/plasmidgpt-addgene-gpt2) optimized using GRPO to generate valid, functional plasmid sequences with:
26
+ - **Origin of replication (ORI)** - Required for plasmid maintenance
27
+ - **Antibiotic resistance marker (AMR)** - Required for selection
28
 
29
+ ### Performance
 
30
 
31
+ At temperature 1.3, this model achieves:
32
+ - **90% QC pass rate** (valid ORI + AMR)
33
+ - **3 unique ORI types** (ColE1, Col(pHAD28), Col440I)
34
+ - **100% unique sequences** (no duplicates)
 
35
 
36
+ ## Quick Start
37
+
38
+ ```python
39
  from transformers import AutoModelForCausalLM, AutoTokenizer
40
 
41
+ model = AutoModelForCausalLM.from_pretrained("UCL-CSSB/PlasmidGPT-GRPO")
42
+ tokenizer = AutoTokenizer.from_pretrained("UCL-CSSB/PlasmidGPT-GRPO")
 
 
 
43
 
44
+ # Generate a plasmid starting with ATG (start codon)
45
+ prompt = "ATG"
46
+ inputs = tokenizer(prompt, return_tensors="pt")
47
+
48
+ outputs = model.generate(
49
  **inputs,
50
+ max_new_tokens=2000,
51
+ temperature=1.3,
52
  do_sample=True,
53
+ pad_token_id=tokenizer.eos_token_id
 
 
 
54
  )
55
+
56
+ sequence = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
+ print(sequence)
58
  ```
59
 
60
+ ## Training
61
+
62
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/ucl-cssb/PlasmidRL/runs/u3wt9c50)
63
+
64
+ This model was trained with GRPO (Group Relative Policy Optimization), a method introduced in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
65
+
66
+ The reward function optimizes for:
67
+ 1. Presence of a valid origin of replication (ORI)
68
+ 2. Presence of a valid antibiotic resistance marker (AMR)
69
+ 3. Absence of long repetitive sequences
70
+
71
+ ### Framework Versions
72
+
73
+ - TRL: 0.23.1
74
+ - Transformers: 4.57.0
75
+ - PyTorch: 2.8.0
76
+ - Datasets: 4.1.1
77
+ - Tokenizers: 0.22.1
78
+
79
+ ## Recommended Sampling Parameters
80
+
81
+ | Temperature | Pass Rate | ORI Diversity | Notes |
82
+ |-------------|-----------|---------------|-------|
83
+ | 0.8 | 37% | 1 type | Collapsed - avoid |
84
+ | 0.95 | 63% | 2 types | Conservative |
85
+ | 1.15 | 76% | 2 types | Balanced |
86
+ | **1.3** | **90%** | **3 types** | **Recommended** |
87
+
88
+ ## Citation
89
+
90
+ ```bibtex
91
+ @article{shao2024deepseekmath,
92
+ title={{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
93
+ author={Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
94
+ year=2024,
95
+ eprint={arXiv:2402.03300},
96
+ }
97
+
98
+ @misc{vonwerra2022trl,
99
+ title={{TRL: Transformer Reinforcement Learning}},
100
+ author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
101
+ year=2020,
102
+ journal={GitHub repository},
103
+ publisher={GitHub},
104
+ howpublished={\url{https://github.com/huggingface/trl}}
105
+ }
106
+ ```
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:353de867743e69096257539c5ae44131947d9e41ef8a9a0ffdd863b3cff9eee6
3
  size 438696576
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba508e7a9d4bfb9c095f95c11fe0e7a1131f6a9076e89852bdd22f67ca00c324
3
  size 438696576