# PlasmidGPT-GRPO / test_generation.py
# Uploaded by McClain in commit 280a94d (verified), "Upload 9 files".
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model/tokenizer directory and the fixed prompt used by this smoke test.
MODEL_PATH = "."
START_SEQUENCE = "ATGGCTAGCGAATTCGGCGCGCCT"

# Closing notes printed after the samples. The check mark and arrow are
# written as escapes (U+2713, U+2192) so the text cannot be re-garbled by a
# wrong source-encoding round-trip, which is how the previous version ended
# up printing "βœ“" / "β†’".
NOTE_LINES = (
    "\nNote: These sequences are generated by an RL-optimized model trained to:",
    " \u2713 Include proper genetic elements (ori, promoters, CDS, markers)",
    " \u2713 Avoid repeat regions > 50 bp",
    " \u2713 Generate compact, functional plasmids",
    " \u2713 Organize genes in proper cassettes (promoter \u2192 CDS \u2192 terminator)",
)


def generate_sequences(
    model,
    tokenizer,
    start_sequence,
    device,
    *,
    max_length=400,
    num_return_sequences=3,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
):
    """Sample continuations of ``start_sequence`` and return them decoded.

    Args:
        model: Object exposing ``generate`` (e.g. a causal LM).
        tokenizer: Tokenizer exposing ``encode``, ``decode``,
            ``pad_token_id`` and ``eos_token_id``.
        start_sequence: Prompt text to continue.
        device: Device the prompt tensor is moved to.
        max_length: Passed to ``generate``; counts tokens, prompt included,
            not base pairs.
        num_return_sequences, temperature, top_k, top_p: Sampling settings
            passed straight through to ``generate``.

    Returns:
        A list of decoded strings, one per returned sequence, with special
        tokens removed.
    """
    input_ids = tokenizer.encode(start_sequence, return_tensors="pt").to(device)
    # One unpadded prompt: every position is a real token, so the mask is all
    # ones. Passing it explicitly stops generate() from having to infer it.
    attention_mask = torch.ones_like(input_ids)

    # If the tokenizer defines no pad token, fall back to EOS explicitly
    # instead of handing generate() a None.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


def summarize_sequence(index, sequence, edge=100):
    """Build the printable report lines for one generated sequence.

    Args:
        index: 1-based position of the sequence in the batch.
        sequence: Decoded sequence text.
        edge: How many leading/trailing characters to show (>= 0).

    Returns:
        A list of four lines: header, length, first ``edge`` characters and
        last ``edge`` characters.
    """
    # NOTE(review): len() counts characters of the decoded text. Reporting it
    # as "bp" assumes decode() yields bare nucleotides with no separators;
    # TODO confirm for this tokenizer.
    # Explicit start index: a plain sequence[-edge:] would return the whole
    # string when edge == 0.
    tail = sequence[max(len(sequence) - edge, 0):]
    return [
        f"\nPlasmid {index}:",
        f" Length: {len(sequence)} bp",
        f" First {edge} bp: {sequence[:edge]}",
        f" Last {edge} bp: {tail}",
    ]


def main():
    """Load the model from MODEL_PATH, sample plasmids and print a summary."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    print("Loading RL-optimized PlasmidGPT-GRPO model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
    )

    print("Generating optimized plasmid sequences...\n")
    print(f"Start sequence: {START_SEQUENCE}\n")
    sequences = generate_sequences(model, tokenizer, START_SEQUENCE, device)

    print("=" * 80)
    for i, sequence in enumerate(sequences, 1):
        for line in summarize_sequence(i, sequence):
            print(line)
    print("\n" + "=" * 80)

    for line in NOTE_LINES:
        print(line)


# Guarded so that importing this module (e.g. pytest collecting the
# test_*.py file) does not download/load the model and run generation.
if __name__ == "__main__":
    main()