|
|
|
|
|
""" |
|
|
Test the simplified quickstart guide examples |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import GPT2LMHeadModel |
|
|
|
|
|
print("Testing simplified CodonGPT quickstart guide...")

try:
    # Step 1: fetch the repo's custom tokenizer / logits-processor modules
    # from the Hub so they can be imported as local files below.
    print("\n1. Testing custom component downloads...")
    from huggingface_hub import hf_hub_download

    hf_hub_download(repo_id="naniltx/codonGPT", filename="tokenizer.py", local_dir="./")
    hf_hub_download(repo_id="naniltx/codonGPT", filename="synonymous_logit_processor.py", local_dir="./")
    print("β Custom components downloaded successfully")

    # Step 2: import the freshly downloaded components.
    print("\n2. Testing custom component imports...")
    from tokenizer import CodonTokenizer
    from synonymous_logit_processor import SynonymMaskingLogitsProcessor
    print("β Custom components imported successfully")

    # Step 3: load the model weights directly from the Hugging Face Hub.
    print("\n3. Testing direct model loading from Hugging Face...")
    model = GPT2LMHeadModel.from_pretrained("naniltx/codonGPT")
    model.eval()  # inference mode: disables dropout
    print("β Model loaded directly from HF successfully")

    # Step 4: instantiate the codon-level tokenizer.
    print("\n4. Testing custom tokenizer...")
    tokenizer = CodonTokenizer()
    print(f"β Tokenizer loaded successfully (vocab size: {tokenizer.vocab_size})")

    # Step 5: generate a short continuation of a seed DNA sequence.
    print("\n5. Testing basic sequence generation...")
    input_sequence = "ATGAAACCC"
    # Split into 3-base codons; assumes len(input_sequence) is a multiple of 3.
    input_codons = [input_sequence[i:i+3] for i in range(0, len(input_sequence), 3)]
    input_tokens = [tokenizer.bos_token_id] + tokenizer.convert_tokens_to_ids(input_codons)
    input_tensor = torch.tensor([input_tokens])

    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            max_length=input_tensor.size(1) + 3,  # allow at most 3 new codon tokens
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Keep only the newly generated tokens, dropping pad/eos markers,
    # then decode each token id back to its codon string.
    generated_tokens = outputs[0][input_tensor.size(1):].tolist()
    generated_codons = [tokenizer.decode([token_id]) for token_id in generated_tokens
                        if token_id not in [tokenizer.pad_token_id, tokenizer.eos_token_id]]
    generated_sequence = ''.join(generated_codons)

    print(f"β Input sequence: {input_sequence}")
    print(f"β Generated sequence: {generated_sequence}")

    # Step 6: synonym-aware generation — the optimized codons must encode
    # the same amino-acid sequence as the originals.
    print("\n6. Testing synonym-aware generation...")
    from synonymous_logit_processor import generate_candidate_codons_with_generate
    from Bio.Seq import Seq

    initial_codons = ["ATG", "AAA", "CCC"]
    optimized_codons = generate_candidate_codons_with_generate(
        initial_codons,
        model=model,
        tokenizer=tokenizer,
        temperature=1.0,
        top_k=50,
        top_p=0.9
    )

    print(f"β Original: {initial_codons}")
    print(f"β Optimized: {optimized_codons}")

    # Translate both codon lists and confirm the protein is unchanged.
    original_aa = ''.join([str(Seq(codon).translate()) for codon in initial_codons])
    optimized_aa = ''.join([str(Seq(codon).translate()) for codon in optimized_codons])
    print(f"β Original AA: {original_aa}")
    print(f"β Optimized AA: {optimized_aa}")
    print(f"β AA preserved: {original_aa == optimized_aa}")

    print("\nπ All simplified quickstart tests passed!")

except Exception as e:
    print(f"\nβ Test failed with error: {e}")
    import traceback
    traceback.print_exc()
    # Propagate the failure to the caller/CI: without a non-zero exit code
    # this script always exits 0, so a broken quickstart would look green.
    import sys
    sys.exit(1)