File size: 696 Bytes
c75f273 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
from trainable_codon_encoder import TrainableCodonEncoder
def test(
    checkpoint_path: str = "pytorch_model.bin",
    codon_index: int = 14,
    latent_dim: int = 16,
    hidden_dim: int = 64,
) -> None:
    """Smoke-test inference of a trained ``TrainableCodonEncoder``.

    Builds the encoder, loads weights from ``checkpoint_path`` onto the CPU,
    embeds a single codon index, and prints the embedding's shape and its
    first few values.

    Args:
        checkpoint_path: Path to the saved checkpoint. Either a dict holding
            the weights under ``"model_state_dict"`` or a bare state dict.
        codon_index: Codon index to embed. The original script notes that
            index 14 is ATG; this presumably depends on the encoder's codon
            ordering — confirm against ``TrainableCodonEncoder``.
        latent_dim: ``latent_dim`` passed to the encoder; must match the
            checkpoint's architecture.
        hidden_dim: ``hidden_dim`` passed to the encoder; must match the
            checkpoint's architecture.

    Raises:
        RuntimeError: If the produced embedding contains NaN or Inf values.
    """
    print("Testing Ternary Codon Encoder inference...")
    model = TrainableCodonEncoder(latent_dim=latent_dim, hidden_dim=hidden_dim)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code embedded in the file. Only load checkpoints from trusted
    # sources; consider weights_only=True if the checkpoint contains only
    # tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Accept both a training checkpoint that wraps the weights and a bare
    # state dict (the usual content of a file named pytorch_model.bin).
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    codon_idx = torch.tensor([codon_index])
    with torch.no_grad():
        z_hyp = model(codon_idx)

    print(f"Codon index: {codon_idx.item()}")
    print(f"Hyperbolic Embedding shape: {z_hyp.shape}")
    # Assumes the output is at least 2-D (batch, features) — indexes row 0.
    print(f"Embedding: {z_hyp[0, :5]}...")

    # Do not report success on a numerically broken output.
    if not torch.isfinite(z_hyp).all():
        raise RuntimeError("Embedding contains non-finite values (NaN/Inf)")
    print("Test passed!")
# Script entry point: run the inference smoke test only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    test()
|