File size: 1,079 Bytes
ca2f8ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from pathlib import Path
import torch, time, shutil

def _atomic_write(dest, write_to):
    """Produce *dest* by writing a sibling ``<name>.tmp`` file and renaming it over *dest*.

    ``write_to`` is called with the temp ``Path`` and must create the file there.
    If it (or the rename) raises -- including ``KeyboardInterrupt`` -- the temp
    file is removed and the exception re-raised, so *dest* keeps its previous
    contents and no ``*.tmp`` debris accumulates.

    NOTE(review): no fsync is performed, so this protects against process
    crashes/interrupts, not necessarily against power loss.
    """
    tmp = dest.with_name(dest.name + '.tmp')
    try:
        write_to(tmp)
        # Same directory => same filesystem, so Path.replace is a rename.
        tmp.replace(dest)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise

def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0, epoch=0, metrics=None, config=None):
    """Save a training checkpoint to *path* and mirror it to ``latest.pt``.

    The saved object is a dict with keys ``'model'`` (``model.state_dict()``),
    ``'step'``, ``'epoch'``, ``'metrics'`` (``{}`` when *metrics* is falsy),
    ``'saved_at'`` (``time.time()`` at save) and ``'config'`` (stored as given),
    plus ``'optimizer'`` / ``'scheduler'`` state dicts when those objects are
    passed. Parent directories are created as needed.

    Both *path* and ``<path.parent>/latest.pt`` are written via a temp file and
    a rename, so a failure or interrupt mid-save never leaves either file
    half-written and never leaves ``*.tmp`` files behind. The ``latest.pt``
    copy is skipped when *path* already is ``latest.pt``.

    Args:
        path: Destination file (``str`` or ``Path``).
        model: Object exposing ``state_dict()``.
        optimizer: Optional object exposing ``state_dict()``.
        scheduler: Optional object exposing ``state_dict()``.
        step: Training step recorded in the checkpoint.
        epoch: Epoch recorded in the checkpoint.
        metrics: Optional metrics mapping; ``None``/empty is stored as ``{}``.
        config: Arbitrary config object stored verbatim. NOTE: non-primitive
            objects here may not load under ``torch.load(weights_only=True)``.

    Raises:
        Whatever ``state_dict()``, ``torch.save`` or the filesystem raise; the
        previous checkpoint files are left intact in that case.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    state = {
        'model': model.state_dict(),
        'step': step,
        'epoch': epoch,
        'metrics': metrics or {},
        'saved_at': time.time(),
        'config': config,
    }
    if optimizer is not None:
        state['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()

    _atomic_write(path, lambda tmp: torch.save(state, tmp))

    latest = path.parent / 'latest.pt'
    if latest != path:
        # Copy into a temp file first so readers never see a partial latest.pt.
        _atomic_write(latest, lambda tmp: shutil.copy2(path, tmp))

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu', strict=None, weights_only=None):
    """Load a checkpoint in ``save_checkpoint``'s format and restore state in place.

    ``model.load_state_dict`` is always called with ``state['model']``. The
    optimizer and scheduler are restored only when both the object is passed
    and the checkpoint contains the matching ``'optimizer'`` / ``'scheduler'``
    entry; otherwise they are silently left unchanged.

    SECURITY: ``torch.load`` unpickles the file. Unless ``weights_only=True``
    is in effect, only load checkpoints from trusted sources.

    Args:
        path: Checkpoint file to read.
        model: Object exposing ``load_state_dict``.
        optimizer: Optional object exposing ``load_state_dict``.
        scheduler: Optional object exposing ``load_state_dict``.
        map_location: Forwarded to ``torch.load`` (default ``'cpu'``).
        strict: When not ``None``, forwarded to ``model.load_state_dict``;
            ``None`` keeps that method's own default (strict for ``nn.Module``).
        weights_only: When not ``None``, forwarded to ``torch.load``; ``None``
            keeps torch's own default, which differs across PyTorch versions.
            Pass ``False`` for trusted checkpoints whose ``config`` holds
            non-primitive objects.

    Returns:
        The full loaded checkpoint dict (including ``step``/``epoch``/
        ``metrics`` etc. as written by the saver).

    Raises:
        KeyError: if the checkpoint has no ``'model'`` entry.
        RuntimeError: from ``model.load_state_dict`` on a key mismatch when
            loading strictly.
    """
    load_kwargs = {'map_location': map_location}
    if weights_only is not None:
        load_kwargs['weights_only'] = weights_only
    state = torch.load(path, **load_kwargs)

    if strict is None:
        model.load_state_dict(state['model'])
    else:
        model.load_state_dict(state['model'], strict=strict)

    if optimizer is not None and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])
    return state