File size: 491 Bytes
0c7049d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from pathlib import Path
import torch
class InvalidModelNameError(ValueError, AssertionError):
    """Raised when ``save_model`` gets a file name without a ``.pth``/``.pt`` suffix.

    Subclasses ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` raised by the earlier ``assert``-based check
    keep working unchanged.
    """


def save_model(model: torch.nn.Module,
               model_name: str,
               target_dir: str) -> Path:
    """Save ``model``'s ``state_dict`` to ``target_dir / model_name``.

    The target directory (and any missing parents) is created if needed.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        model_name: File name for the checkpoint; must end with ``.pth`` or ``.pt``.
        target_dir: Directory to write into (anything ``pathlib.Path`` accepts).

    Returns:
        The path the state dict was written to.

    Raises:
        InvalidModelNameError: (both a ``ValueError`` and an ``AssertionError``)
            if ``model_name`` does not end with ``.pth`` or ``.pt``. Raised
            before any directory is created.
    """
    # Validate with an explicit raise rather than ``assert`` so the check is
    # not stripped under ``python -O``, and do it before touching the
    # filesystem so a bad name leaves no empty directory behind.
    if not model_name.endswith((".pth", ".pt")):
        raise InvalidModelNameError("Model name should end with .pth or .pt")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path