| from pathlib import Path | |
| import torch | |
| from modeling_ftn import FTNConfig, FTNForCausalLM | |
def load_model(
    repo_dir: str | Path,
    device: str | torch.device = 'cpu',
    *,
    checkpoint_name: str = 'best_checkpoint.pt',
) -> 'FTNForCausalLM':
    """Build an ``FTNForCausalLM`` from a checkpoint stored in ``repo_dir``.

    The checkpoint file must hold a mapping with a ``'config'`` entry (keyword
    arguments for ``FTNConfig``) and a ``'state_dict'`` entry (weights passed
    to ``model.load_state_dict``).

    Args:
        repo_dir: Directory containing the checkpoint file.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to (e.g. ``'cpu'``, ``'cuda'``).
        checkpoint_name: File name of the checkpoint inside ``repo_dir``.
            Defaults to ``'best_checkpoint.pt'``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint lacks ``'config'`` or ``'state_dict'``.
    """
    checkpoint_path = Path(repo_dir) / checkpoint_name
    # SECURITY: torch.load unpickles the file. On torch versions where
    # ``weights_only`` defaults to False, a malicious checkpoint can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    # map_location places tensors directly on the target device, which
    # avoids an extra copy through a device the caller may not have.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Fail with an actionable message instead of a bare KeyError('config').
    missing = [key for key in ('config', 'state_dict') if key not in checkpoint]
    if missing:
        raise KeyError(
            f'checkpoint {checkpoint_path} is missing required key(s): {missing}'
        )

    config = FTNConfig(**checkpoint['config'])
    model = FTNForCausalLM(config)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    # Inference mode: disables dropout / training-only behavior in submodules.
    model.eval()
    return model