File size: 696 Bytes
c75f273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from trainable_codon_encoder import TrainableCodonEncoder

def test():
    """Smoke-test inference of a trained ``TrainableCodonEncoder`` checkpoint.

    Builds an encoder with ``latent_dim=16`` and ``hidden_dim=64``, loads its
    weights from ``pytorch_model.bin`` in the current working directory
    (mapped to CPU), embeds a single codon index in eval mode without gradient
    tracking, checks the output, and prints a short summary.

    The checkpoint may be either a training checkpoint holding the weights
    under ``"model_state_dict"`` or a bare state dict.

    Raises:
        AssertionError: If the embedding's leading (batch) dimension does not
            match the number of input codons, or it contains NaN/Inf values.
    """
    print("Testing Ternary Codon Encoder inference...")
    model = TrainableCodonEncoder(latent_dim=16, hidden_dim=64)

    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints from a trusted
    # source; switch to weights_only=True if the checkpoint contains only
    # tensors and primitive containers.
    checkpoint = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)

    # Accept both a training checkpoint ({"model_state_dict": ...}) and a bare
    # state dict, which is what a file named pytorch_model.bin often holds.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    # ATG index is 14. This matches codons enumerated in lexicographic ACGT
    # order (A=0, C=1, G=2, T=3 -> 0*16 + 3*4 + 2 = 14).
    # NOTE(review): confirm against TrainableCodonEncoder's codon table.
    codon_idx = torch.tensor([14])
    with torch.no_grad():
        z_hyp = model(codon_idx)

    # Verify the output instead of reporting success unconditionally.
    # The summary below indexes z_hyp[0, :5], so at least 2 dims are required.
    assert z_hyp.dim() >= 2 and z_hyp.shape[0] == codon_idx.shape[0], (
        f"unexpected embedding shape {tuple(z_hyp.shape)}"
    )
    assert torch.isfinite(z_hyp).all(), "embedding contains NaN or Inf values"

    print(f"Codon index: {codon_idx.item()}")
    print(f"Hyperbolic Embedding shape: {z_hyp.shape}")
    print(f"Embedding: {z_hyp[0, :5]}...")
    print("Test passed!")

if __name__ == "__main__":
    test()