from pathlib import Path

import torch

from modeling_ftn import FTNConfig, FTNForCausalLM


def load_model(repo_dir: str | Path, device: str = 'cpu') -> FTNForCausalLM:
    """Load the FTN causal language model stored in *repo_dir*.

    Reads ``best_checkpoint.pt`` from ``repo_dir``, builds an ``FTNConfig``
    from the checkpoint's ``'config'`` entry (passed as keyword arguments),
    instantiates ``FTNForCausalLM`` with it, and loads the checkpoint's
    ``'state_dict'`` entry into the model.

    Args:
        repo_dir: Directory containing ``best_checkpoint.pt``.
        device: Device string used both as ``map_location`` when reading the
            checkpoint and as the target of ``model.to``.

    Returns:
        The model, moved to ``device`` and with ``eval()`` already called.

    Raises:
        KeyError: If the checkpoint lacks the ``'config'`` or ``'state_dict'``
            entry.
    """
    checkpoint_path = Path(repo_dir) / 'best_checkpoint.pt'

    # NOTE(security): torch.load deserializes via pickle. Depending on the
    # installed PyTorch version's default for ``weights_only``, a malicious
    # file can execute arbitrary code, so only load checkpoints from trusted
    # sources. Passing ``weights_only=True`` is the safer option if the
    # checkpoint holds only tensors and plain containers -- TODO confirm
    # against how the checkpoint is written before enabling it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Fail with a message that names the file and the missing entries instead
    # of a bare KeyError raised from deep inside the loading sequence.
    missing = [key for key in ('config', 'state_dict') if key not in checkpoint]
    if missing:
        raise KeyError(
            f"checkpoint {checkpoint_path} is missing required entries: {missing}"
        )

    config = FTNConfig(**checkpoint['config'])
    model = FTNForCausalLM(config)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model