from pathlib import Path
import shutil
import time

import torch


def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0,
                    epoch=0, metrics=None, config=None):
    """Atomically save a training checkpoint to *path*.

    The state is first written to a sibling ``<name>.tmp`` file and then
    renamed over the target, so a crash mid-write never leaves a corrupt
    checkpoint at *path*. A copy is also maintained at ``<dir>/latest.pt``
    (skipped when *path* itself is ``latest.pt``).

    Args:
        path: Destination file; parent directories are created as needed.
        model: Module whose ``state_dict()`` is saved under ``'model'``.
        optimizer: Optional optimizer; its state is saved under ``'optimizer'``.
        scheduler: Optional LR scheduler; saved under ``'scheduler'``.
        step: Global training step stored under ``'step'``.
        epoch: Epoch counter stored under ``'epoch'``.
        metrics: Optional dict of metrics; ``None`` is stored as ``{}``.
        config: Arbitrary run configuration, stored as-is under ``'config'``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {
        'model': model.state_dict(),
        'step': step,
        'epoch': epoch,
        'metrics': metrics or {},
        'saved_at': time.time(),
        'config': config,
    }
    if optimizer is not None:
        state['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()
    # Write-then-rename for atomicity; clean up the temp file on failure
    # so a failed save doesn't leave a stale .tmp next to the checkpoint.
    tmp = path.with_suffix(path.suffix + '.tmp')
    try:
        torch.save(state, tmp)
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    latest = path.parent / 'latest.pt'
    if latest != path:
        shutil.copy2(path, latest)


def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
    """Load a checkpoint written by :func:`save_checkpoint`.

    Restores the model weights in place; optimizer/scheduler state is
    restored only when both the object is given and the checkpoint
    contains the corresponding entry (so older checkpoints still load).

    Args:
        path: Checkpoint file to read.
        model: Module to receive the saved ``'model'`` state dict.
        optimizer: Optional optimizer to restore.
        scheduler: Optional LR scheduler to restore.
        map_location: Forwarded to ``torch.load`` (default ``'cpu'``).

    Returns:
        The full checkpoint dict (step, epoch, metrics, config, ...).
    """
    # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True),
    # which rejects checkpoints whose 'config' holds arbitrary Python
    # objects — confirm against the torch version pinned by this project.
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state['model'])
    if optimizer is not None and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])
    return state