# ternary-codon-encoder / test_inference.py
# Uploaded by geestaltt — commit c75f273 (verified): "Upload folder using huggingface_hub"
import torch
from trainable_codon_encoder import TrainableCodonEncoder
def test(checkpoint_path=None):
    """Smoke-test inference of the Ternary Codon Encoder on a single codon.

    Builds a ``TrainableCodonEncoder`` with ``latent_dim=16`` and
    ``hidden_dim=64``, loads the weights stored under the
    ``"model_state_dict"`` key of the checkpoint, embeds one codon index and
    prints the embedding's shape and its first five values.

    Args:
        checkpoint_path: Path to the checkpoint file. When ``None`` (the
            default), ``pytorch_model.bin`` in the same directory as this
            script is used, so the test does not depend on the current
            working directory.

    Raises:
        AssertionError: If the embedding is not at least 2-D or contains
            NaN/Inf values.
    """
    from pathlib import Path

    print("Testing Ternary Codon Encoder inference...")

    if checkpoint_path is None:
        # Resolve next to this file rather than the CWD. The sibling import of
        # ``trainable_codon_encoder`` already relies on the script directory,
        # so the checkpoint lookup should as well.
        checkpoint_path = Path(__file__).resolve().parent / "pytorch_model.bin"

    model = TrainableCodonEncoder(latent_dim=16, hidden_dim=64)
    # SECURITY NOTE: weights_only=False fully unpickles the file, which can
    # execute arbitrary code embedded in it. Only load checkpoints from a
    # trusted source. If the checkpoint holds only tensors and plain
    # containers, weights_only=True would be safer (TODO: verify).
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Index 14 is assumed to be ATG. That matches an A,C,G,T base-4 ordering
    # (0*16 + 3*4 + 2 = 14); confirm against the encoder's codon table.
    codon_idx = torch.tensor([14])
    with torch.no_grad():
        z_hyp = model(codon_idx)

    # Fail loudly on a malformed embedding instead of always reporting success.
    # The [0, :5] slice below already requires at least two dimensions.
    assert z_hyp.dim() >= 2, f"expected a batched embedding, got shape {tuple(z_hyp.shape)}"
    assert torch.isfinite(z_hyp).all(), "embedding contains NaN or Inf values"

    print(f"Codon index: {codon_idx.item()}")
    print(f"Hyperbolic Embedding shape: {z_hyp.shape}")
    print(f"Embedding: {z_hyp[0, :5]}...")
    print("Test passed!")
if __name__ == "__main__":
    # Run the inference smoke test only when this file is executed directly,
    # not when it is imported as a module.
    test()