File size: 1,335 Bytes
54c5666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import torch
from typing import Optional


def save_checkpoint(checkpoint_dir: str, epoch: int, model, optimizer, scheduler, val_loss: float, config) -> str:
    """Serialize a training checkpoint to ``checkpoint_dir`` and return its path.

    Args:
        checkpoint_dir: Directory for checkpoint files; created if missing.
        epoch: Epoch number, embedded in the filename.
        model: Module to save. If it is a ``DataParallel``/``DDP``-style
            wrapper (exposes ``.module``), the underlying module's state dict
            is saved so keys carry no ``"module."`` prefix.
        optimizer: Optimizer whose state is saved, or ``None``.
        scheduler: LR scheduler whose state is saved, or ``None``.
        val_loss: Validation loss recorded alongside the weights.
        config: Arbitrary run configuration stored verbatim in the payload.

    Returns:
        Filesystem path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
    # BUG FIX: the original tested hasattr(model, "state_dict"), which is true
    # for every nn.Module (wrappers included), so the unwrap branch was
    # unreachable and wrapped models were saved with "module."-prefixed keys.
    # Detect the wrapper via its .module attribute instead.
    model_to_save = model.module if hasattr(model, "module") else model
    payload = {
        "epoch": epoch,
        "model_state_dict": model_to_save.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "val_loss": val_loss,
        "config": config,
    }
    torch.save(payload, path)
    return path


def load_checkpoint(path: str, model, optimizer=None, scheduler=None) -> Optional[int]:
    """Restore training state from the checkpoint at ``path``.

    Loads model weights (non-strict, so partially matching checkpoints are
    tolerated) and, when the corresponding objects are supplied and the
    payload carries their state, the optimizer and scheduler state as well.

    Args:
        path: Checkpoint file produced by ``save_checkpoint``.
        model: Module whose weights are restored in place.
        optimizer: Optional optimizer to restore.
        scheduler: Optional LR scheduler to restore.

    Returns:
        The stored epoch number (0 if absent), or ``None`` when no file
        exists at ``path``.
    """
    if not os.path.exists(path):
        return None
    payload = torch.load(path, map_location="cpu")
    model.load_state_dict(payload["model_state_dict"], strict=False)
    opt_state = payload.get("optimizer_state_dict")
    if optimizer is not None and opt_state:
        optimizer.load_state_dict(opt_state)
    sched_state = payload.get("scheduler_state_dict")
    if scheduler is not None and sched_state:
        scheduler.load_state_dict(sched_state)
    return int(payload.get("epoch", 0))