| from pathlib import Path |
| import torch, time, shutil |
|
|
def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0, epoch=0, metrics=None, config=None):
    """Atomically save a training checkpoint to *path*.

    The state dict is first written to a temporary sibling file and then
    moved into place with ``Path.replace`` (an atomic rename on POSIX), so
    a crash mid-write never leaves a truncated checkpoint at *path*.  A
    copy is also kept at ``<path's dir>/latest.pt`` for easy resumption.

    Args:
        path: Destination file for the checkpoint (str or Path).
        model: Module whose ``state_dict()`` is saved under ``'model'``.
        optimizer: Optional optimizer; its state is saved under ``'optimizer'``.
        scheduler: Optional LR scheduler; saved under ``'scheduler'``.
        step: Global training step stored in the checkpoint.
        epoch: Epoch counter stored in the checkpoint.
        metrics: Optional dict of metrics; ``None`` is stored as ``{}``.
        config: Optional run configuration object, stored as-is.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Append '.tmp' to the full file name rather than using with_suffix(),
    # which raises ValueError for paths without a usable name/suffix and is
    # easy to get wrong with multi-dot file names.
    tmp = path.with_name(path.name + '.tmp')
    state = {
        'model': model.state_dict(),
        'step': step,
        'epoch': epoch,
        'metrics': metrics or {},
        'saved_at': time.time(),
        'config': config,
    }
    if optimizer is not None:
        state['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()
    try:
        torch.save(state, tmp)
        tmp.replace(path)  # atomic swap: readers never see a partial file
    except BaseException:
        # Don't leak a stale/partial .tmp file if serialization or the
        # rename fails (disk full, keyboard interrupt, ...).
        tmp.unlink(missing_ok=True)
        raise
    # Maintain a 'latest.pt' alias next to the checkpoint, unless the
    # checkpoint itself is already named latest.pt.
    latest = path.parent / 'latest.pt'
    if latest != path:
        shutil.copy2(path, latest)
|
|
def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
    """Restore training state from the checkpoint file at *path*.

    Loads the checkpoint dict, applies the ``'model'`` weights to *model*,
    and — when both the object is supplied and the matching key exists in
    the checkpoint — restores optimizer and scheduler state as well.

    Returns:
        The full checkpoint dict, so callers can read ``'step'``,
        ``'epoch'``, ``'metrics'``, ``'config'``, etc.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    # Optional components: restored best-effort, only if present on both sides.
    for key, target in (('optimizer', optimizer), ('scheduler', scheduler)):
        if target is not None and key in checkpoint:
            target.load_state_dict(checkpoint[key])
    return checkpoint
|
|