| import torch | |
| from trainable_codon_encoder import TrainableCodonEncoder | |
def test():
    """Smoke-test inference of the trained codon encoder on a single codon.

    Builds a ``TrainableCodonEncoder``, loads the weights stored under the
    ``"model_state_dict"`` key of ``pytorch_model.bin`` (resolved against the
    current working directory), encodes codon index 14, prints a short
    summary and verifies the output before reporting success.

    Raises:
        AssertionError: If the embedding's leading (batch) dimension does not
            match the number of input codons, or if the embedding contains
            NaN/Inf values.
    """
    print("Testing Ternary Codon Encoder inference...")
    model = TrainableCodonEncoder(latent_dim=16, hidden_dim=64)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only run this against a checkpoint from a trusted
    # source; if the file holds nothing but tensors and plain containers,
    # weights_only=True is the safer setting.
    checkpoint = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # ATG index is 14
    codon_idx = torch.tensor([14])
    with torch.no_grad():
        z_hyp = model(codon_idx)

    print(f"Codon index: {codon_idx.item()}")
    print(f"Hyperbolic Embedding shape: {z_hyp.shape}")
    print(f"Embedding: {z_hyp[0, :5]}...")

    # Verify the output before claiming success; explicit raises (rather than
    # `assert` statements) keep the checks active under `python -O`.
    if z_hyp.shape[0] != codon_idx.shape[0]:
        raise AssertionError(
            f"expected {codon_idx.shape[0]} embedding row(s), got shape {tuple(z_hyp.shape)}"
        )
    if not torch.isfinite(z_hyp).all():
        raise AssertionError("embedding contains NaN or Inf values")

    print("Test passed!")
if __name__ == "__main__":
    # Run the inference smoke test only when this file is executed as a
    # script, not when it is imported as a module.
    test()