razor5050's picture
Add tokenizer, inference code, model card, and 20-query report
ca2f8ca verified
from pathlib import Path
import torch, time, shutil
def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0, epoch=0, metrics=None, config=None):
    """Write a training checkpoint to ``path`` and refresh ``latest.pt``.

    The saved dict always holds ``'model'``, ``'step'``, ``'epoch'``,
    ``'metrics'``, ``'saved_at'`` and ``'config'``; ``'optimizer'`` and
    ``'scheduler'`` are added only when those objects are given.

    Both the checkpoint and the sibling ``latest.pt`` are first written to a
    ``*.tmp`` file in the same directory and then renamed into place, so an
    interrupted save does not leave a truncated file under the final name.
    Temp files are removed if the write fails.

    Args:
        path: Destination file; missing parent directories are created.
        model: Object exposing ``state_dict()``.
        optimizer: Optional object exposing ``state_dict()``.
        scheduler: Optional object exposing ``state_dict()``.
        step: Step counter stored under ``'step'``.
        epoch: Epoch counter stored under ``'epoch'``.
        metrics: Optional metrics stored under ``'metrics'``; a falsy value
            is stored as ``{}``.
        config: Arbitrary value stored under ``'config'``.

    Raises:
        Whatever ``torch.save``, the copy, or the rename raises; the error is
        re-raised after temp-file cleanup.
    """

    def _discard(leftover):
        # Best-effort cleanup: never let a failed unlink mask the real error.
        try:
            leftover.unlink()
        except OSError:
            pass

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    state = {
        'model': model.state_dict(),
        'step': step,
        'epoch': epoch,
        'metrics': metrics or {},
        'saved_at': time.time(),
        'config': config,
    }
    if optimizer is not None:
        state['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()

    # Stage next to the target so the rename stays within one directory.
    tmp = path.with_suffix(path.suffix + '.tmp')
    try:
        torch.save(state, tmp)
        tmp.replace(path)
    except BaseException:
        _discard(tmp)
        raise

    latest = path.parent / 'latest.pt'
    if latest != path:
        # Copy via a temp file too: a direct copy2 onto latest.pt could leave
        # it half-written if the process dies mid-copy.
        latest_tmp = latest.with_suffix(latest.suffix + '.tmp')
        try:
            shutil.copy2(path, latest_tmp)
            latest_tmp.replace(latest)
        except BaseException:
            _discard(latest_tmp)
            raise
def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
    """Restore state from a checkpoint file and return the loaded dict.

    The ``'model'`` entry is always applied to ``model``. The ``'optimizer'``
    and ``'scheduler'`` entries are applied only when the matching object is
    passed in and the entry is present in the file.

    Args:
        path: Checkpoint file to read.
        model: Object exposing ``load_state_dict()``.
        optimizer: Optional object exposing ``load_state_dict()``.
        scheduler: Optional object exposing ``load_state_dict()``.
        map_location: Forwarded to ``torch.load``; defaults to ``'cpu'``.

    Returns:
        The full dict read from ``path``.

    Raises:
        KeyError: If the file has no ``'model'`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])

    optional_parts = (('optimizer', optimizer), ('scheduler', scheduler))
    for key, target in optional_parts:
        if target is None or key not in checkpoint:
            continue
        target.load_state_dict(checkpoint[key])

    return checkpoint