# ftn-v2-boundary-small / load_model.py
# Uploaded by umm-dev: "Upload FTN research checkpoints and metadata" (commit 7793080, verified)
from pathlib import Path
import torch
from modeling_ftn import FTNConfig, FTNForCausalLM
def load_model(
    repo_dir: str | Path,
    device: str | torch.device = 'cpu',
    *,
    checkpoint_name: str = 'best_checkpoint.pt',
    strict: bool = True,
) -> 'FTNForCausalLM':
    """Load an FTN causal-LM checkpoint from a local repository directory.

    The checkpoint file is expected to be a dict with a ``'config'`` entry
    (keyword arguments for ``FTNConfig``) and a ``'state_dict'`` entry
    (the model weights).

    Args:
        repo_dir: Directory containing the checkpoint file.
        device: Device the returned model is moved to.
        checkpoint_name: Checkpoint filename inside ``repo_dir``.
        strict: Forwarded to ``load_state_dict``; when True, missing or
            unexpected parameter keys raise an error.

    Returns:
        The ``FTNForCausalLM`` with weights loaded, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks ``'config'`` or ``'state_dict'``.
    """
    checkpoint_path = Path(repo_dir) / checkpoint_name
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. Consider weights_only=True on torch versions that
    # support it, provided the stored config holds only plain Python types.
    #
    # Deserialize onto CPU first. Mapping straight to `device` and then moving
    # the freshly built model there would briefly keep two copies of the
    # weights on the accelerator.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config = FTNConfig(**checkpoint['config'])
    model = FTNForCausalLM(config)
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
    # Weights are now copied into the model's parameters; release the
    # checkpoint's copy before moving the model to its target device.
    del checkpoint
    model.to(device)
    model.eval()
    return model