# codonGPT / test_quickstart.py
# Author: anuj2054 — test files for the quickstart guide (commit ee7017c, verified)
#!/usr/bin/env python3
"""
Test the simplified quickstart guide examples
"""
import sys

import torch
from transformers import GPT2LMHeadModel

print("Testing simplified CodonGPT quickstart guide...")
try:
    # Test 1: Download custom components (simulate what users would do)
    print("\n1. Testing custom component downloads...")
    from huggingface_hub import hf_hub_download
    # Download custom tokenizer and processor to current directory so the
    # plain `from tokenizer import ...` imports below resolve.
    hf_hub_download(repo_id="naniltx/codonGPT", filename="tokenizer.py", local_dir="./")
    hf_hub_download(repo_id="naniltx/codonGPT", filename="synonymous_logit_processor.py", local_dir="./")
    print("βœ“ Custom components downloaded successfully")

    # Test 2: Import custom components (just downloaded into the CWD).
    print("\n2. Testing custom component imports...")
    from tokenizer import CodonTokenizer
    from synonymous_logit_processor import SynonymMaskingLogitsProcessor
    print("βœ“ Custom components imported successfully")

    # Test 3: Load model directly from HF.
    print("\n3. Testing direct model loading from Hugging Face...")
    model = GPT2LMHeadModel.from_pretrained("naniltx/codonGPT")
    model.eval()  # inference mode: disables dropout etc.
    print("βœ“ Model loaded directly from HF successfully")

    # Test 4: Load custom tokenizer.
    print("\n4. Testing custom tokenizer...")
    tokenizer = CodonTokenizer()
    print(f"βœ“ Tokenizer loaded successfully (vocab size: {tokenizer.vocab_size})")

    # Test 5: Basic sequence generation.
    print("\n5. Testing basic sequence generation...")
    input_sequence = "ATGAAACCC"
    # Split the DNA string into non-overlapping codons (triplets).
    input_codons = [input_sequence[i:i + 3] for i in range(0, len(input_sequence), 3)]
    input_tokens = [tokenizer.bos_token_id] + tokenizer.convert_tokens_to_ids(input_codons)
    input_tensor = torch.tensor([input_tokens])
    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            max_length=input_tensor.size(1) + 3,  # generate up to 3 new tokens
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Keep only the newly generated tokens (drop the prompt prefix).
    generated_tokens = outputs[0][input_tensor.size(1):].tolist()
    generated_codons = [tokenizer.decode([token_id]) for token_id in generated_tokens
                        if token_id not in [tokenizer.pad_token_id, tokenizer.eos_token_id]]
    generated_sequence = ''.join(generated_codons)
    print(f"βœ“ Input sequence: {input_sequence}")
    print(f"βœ“ Generated sequence: {generated_sequence}")

    # Test 6: Synonym-aware generation (codon optimization that must
    # preserve the encoded amino-acid sequence).
    print("\n6. Testing synonym-aware generation...")
    from synonymous_logit_processor import generate_candidate_codons_with_generate
    from Bio.Seq import Seq
    initial_codons = ["ATG", "AAA", "CCC"]
    optimized_codons = generate_candidate_codons_with_generate(
        initial_codons,
        model=model,
        tokenizer=tokenizer,
        temperature=1.0,
        top_k=50,
        top_p=0.9
    )
    print(f"βœ“ Original: {initial_codons}")
    print(f"βœ“ Optimized: {optimized_codons}")

    # Verify amino acid preservation: translating each codon individually
    # must give the same amino-acid string before and after optimization.
    original_aa = ''.join([str(Seq(codon).translate()) for codon in initial_codons])
    optimized_aa = ''.join([str(Seq(codon).translate()) for codon in optimized_codons])
    print(f"βœ“ Original AA: {original_aa}")
    print(f"βœ“ Optimized AA: {optimized_aa}")
    print(f"βœ“ AA preserved: {original_aa == optimized_aa}")

    print("\nπŸŽ‰ All simplified quickstart tests passed!")
except Exception as e:
    print(f"\n❌ Test failed with error: {e}")
    import traceback
    traceback.print_exc()
    # Report failure via the exit code so CI / shell callers can detect it.
    sys.exit(1)