"""Smoke test: load a TrainableCodonEncoder checkpoint and embed one codon."""

import torch

from trainable_codon_encoder import TrainableCodonEncoder


def test(checkpoint_path: str = "pytorch_model.bin", codon_index: int = 14) -> None:
    """Load the encoder checkpoint and run a single-codon forward pass.

    Builds a ``TrainableCodonEncoder(latent_dim=16, hidden_dim=64)``, restores
    the weights stored under the ``"model_state_dict"`` key of the checkpoint,
    switches the model to eval mode and feeds it a one-element index tensor.
    The output's shape and the first five components of its first row are
    printed.

    Args:
        checkpoint_path: Path of the checkpoint file, resolved relative to the
            current working directory when not absolute.
        codon_index: Integer codon index to embed. The default, 14, is the
            index the original author's note associates with ATG (the mapping
            itself is defined outside this file).

    Raises:
        ValueError: If the model output contains NaN or infinite values.
        FileNotFoundError: Propagated from ``torch.load`` when the checkpoint
            file does not exist.
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    print("Testing Ternary Codon Encoder inference...")
    model = TrainableCodonEncoder(latent_dim=16, hidden_dim=64)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only use this with checkpoints
    # from a trusted source. It is kept as-is because the checkpoint may
    # contain non-tensor entries that weights_only=True would reject.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    codon_idx = torch.tensor([codon_index])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        z_hyp = model(codon_idx)

    print(f"Codon index: {codon_idx.item()}")
    print(f"Hyperbolic Embedding shape: {z_hyp.shape}")
    print(f"Embedding: {z_hyp[0, :5]}...")

    # Do not report success for a numerically broken embedding.
    if not torch.isfinite(z_hyp).all():
        raise ValueError("Embedding contains NaN or infinite values")

    print("Test passed!")


if __name__ == "__main__":
    test()