File size: 3,622 Bytes
ee7017c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python3
"""
Smoke-test the examples from the simplified CodonGPT quickstart guide.
Requires network access to download files from the Hugging Face Hub.
"""

import os
import sys

import torch
from transformers import GPT2LMHeadModel

print("Testing simplified CodonGPT quickstart guide...")

# Script-style smoke test of the quickstart guide. Any failure is reported with
# a traceback and converted into a non-zero exit status so that CI jobs and
# shell callers can detect it.
try:
    # Test 1: Download custom components (simulate what users would do)
    print("\n1. Testing custom component downloads...")
    from huggingface_hub import hf_hub_download

    # Download custom tokenizer and processor to the current directory.
    tokenizer_path = hf_hub_download(
        repo_id="naniltx/codonGPT", filename="tokenizer.py", local_dir="./"
    )
    hf_hub_download(
        repo_id="naniltx/codonGPT",
        filename="synonymous_logit_processor.py",
        local_dir="./",
    )
    # When run as `python path/to/script.py`, sys.path[0] is the script's
    # directory, not the working directory the files were downloaded into.
    # Make the download directory importable so the imports below work no
    # matter where the script is launched from.
    download_dir = os.path.dirname(os.path.abspath(tokenizer_path))
    if download_dir not in sys.path:
        sys.path.insert(0, download_dir)
    print("✓ Custom components downloaded successfully")

    # Test 2: Import custom components
    print("\n2. Testing custom component imports...")
    from tokenizer import CodonTokenizer
    from synonymous_logit_processor import SynonymMaskingLogitsProcessor
    print("✓ Custom components imported successfully")

    # Test 3: Load model directly from HF
    print("\n3. Testing direct model loading from Hugging Face...")
    model = GPT2LMHeadModel.from_pretrained("naniltx/codonGPT")
    model.eval()
    print("✓ Model loaded directly from HF successfully")

    # Test 4: Load custom tokenizer
    print("\n4. Testing custom tokenizer...")
    tokenizer = CodonTokenizer()
    print(f"✓ Tokenizer loaded successfully (vocab size: {tokenizer.vocab_size})")

    # Test 5: Basic sequence generation
    print("\n5. Testing basic sequence generation...")
    # Sampling is random; fix the seed so runs of this test are reproducible.
    torch.manual_seed(0)
    input_sequence = "ATGAAACCC"
    # Split the nucleotide string into consecutive 3-letter codons.
    input_codons = [input_sequence[i:i + 3] for i in range(0, len(input_sequence), 3)]
    input_tokens = [tokenizer.bos_token_id] + tokenizer.convert_tokens_to_ids(input_codons)
    input_tensor = torch.tensor([input_tokens])

    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            # The prompt contains no padding, so every position is attended.
            attention_mask=torch.ones_like(input_tensor),
            # Equivalent to max_length = prompt length + 3, stated directly.
            max_new_tokens=3,
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_tokens = outputs[0][input_tensor.size(1):].tolist()
    skip_ids = {tokenizer.pad_token_id, tokenizer.eos_token_id}
    generated_codons = [
        tokenizer.decode([token_id])
        for token_id in generated_tokens
        if token_id not in skip_ids
    ]
    generated_sequence = ''.join(generated_codons)

    print(f"✓ Input sequence: {input_sequence}")
    print(f"✓ Generated sequence: {generated_sequence}")

    # Test 6: Synonym-aware generation
    print("\n6. Testing synonym-aware generation...")
    from synonymous_logit_processor import generate_candidate_codons_with_generate
    from Bio.Seq import Seq

    initial_codons = ["ATG", "AAA", "CCC"]
    optimized_codons = generate_candidate_codons_with_generate(
        initial_codons,
        model=model,
        tokenizer=tokenizer,
        temperature=1.0,
        top_k=50,
        top_p=0.9,
    )

    print(f"✓ Original: {initial_codons}")
    print(f"✓ Optimized: {optimized_codons}")

    # Verify amino acid preservation: translate codon by codon and compare the
    # resulting protein strings. A mismatch is a test failure, not just output.
    original_aa = ''.join(str(Seq(codon).translate()) for codon in initial_codons)
    optimized_aa = ''.join(str(Seq(codon).translate()) for codon in optimized_codons)
    aa_preserved = original_aa == optimized_aa
    print(f"✓ Original AA: {original_aa}")
    print(f"✓ Optimized AA: {optimized_aa}")
    print(f"✓ AA preserved: {aa_preserved}")
    if not aa_preserved:
        raise AssertionError(
            f"Synonym-aware generation changed the protein: "
            f"{original_aa!r} -> {optimized_aa!r}"
        )

    print("\n🎉 All simplified quickstart tests passed!")

except Exception as e:
    print(f"\n❌ Test failed with error: {e}")
    import traceback
    traceback.print_exc()
    # Report the failure to the caller (shell, CI) through the exit status.
    sys.exit(1)