File size: 391 Bytes
06c8a6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Utilities for saving and loading PyTorch model weights (state dicts).
import torch
import torch
import os


def save_model(model: torch.nn.Module, path: str) -> str:
    """Save ``model``'s state dict to ``path``, creating parent directories.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        path: Destination file path. Any missing parent directories are
            created first.

    Returns:
        The ``path`` argument, unchanged.
    """
    parent_folder = os.path.dirname(path)
    # A bare filename has no directory component (dirname returns ""), and
    # os.makedirs("") raises FileNotFoundError -- only create real parents.
    if parent_folder:
        os.makedirs(parent_folder, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path

def load_model(
    model: torch.nn.Module,
    path: str,
    *,
    map_location: "str | torch.device | None" = "cpu",
    strict: bool = True,
) -> torch.nn.Module:
    """Load a state dict from ``path`` into ``model`` in place and return it.

    Args:
        model: Module that receives the weights. ``load_state_dict`` copies
            values into its existing parameters, so they keep their device.
        path: File previously written by ``save_model`` / ``torch.save``.
        map_location: Where ``torch.load`` places the loaded tensors before
            they are copied into ``model``. Defaults to ``"cpu"`` so that a
            checkpoint saved from a GPU can be loaded on a CPU-only machine;
            pass ``None`` to restore tensors to the devices they were saved on.
        strict: Forwarded to ``load_state_dict``; when True, the checkpoint's
            keys must match the model's keys exactly.

    Returns:
        The same ``model`` object, with the loaded weights.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` on a key mismatch
            (when ``strict``) or a tensor shape mismatch.
    """
    # NOTE(security): torch.load may unpickle the file (depending on the
    # installed torch's weights_only default); only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict, strict=strict)
    return model