Spaces:
Runtime error
Runtime error
File size: 491 Bytes
0c7049d |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from pathlib import Path
import torch
class InvalidModelNameError(ValueError, AssertionError):
    """Raised when a checkpoint file name lacks a ``.pth``/``.pt`` extension.

    Subclasses ``ValueError`` (the conventional type for a bad argument) and
    also ``AssertionError`` so callers written against the previous
    ``assert``-based check continue to catch it.
    """


def save_model(model: torch.nn.Module,
               model_name: str,
               target_dir: str) -> Path:
    """Save ``model``'s ``state_dict`` to ``target_dir / model_name``.

    Only the tensors returned by ``model.state_dict()`` are written, not the
    module object itself; restore them into a freshly constructed module with
    ``model.load_state_dict(torch.load(path))``.

    Args:
        model: Module whose ``state_dict`` is saved.
        model_name: File name for the checkpoint; must end with ``.pth`` or
            ``.pt``.
        target_dir: Directory to save into; created (including any missing
            parents) if it does not exist.

    Returns:
        The path the checkpoint was written to.

    Raises:
        InvalidModelNameError: If ``model_name`` does not end with ``.pth``
            or ``.pt`` (a ``ValueError`` and ``AssertionError`` subclass).
    """
    # Validate before touching the filesystem so a bad name does not leave an
    # empty directory behind. An explicit raise is used instead of ``assert``
    # because assertions are stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise InvalidModelNameError("Model name should end with .pth or .pt")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path
|