File size: 2,475 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f432f
ef18673
b4f432f
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f432f
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Checkpoint save, prune, and resume utilities."""

from __future__ import annotations

import glob
import os
from pathlib import Path

import torch


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    scaler: torch.GradScaler | None,
    step: int,
    config: dict[str, object],
    output_dir: str,
    keep: int = 5,
) -> str:
    """Persist a resumable training checkpoint, then prune older ones.

    The checkpoint is written to ``<output_dir>/ckpt_step_<step>.pt``. The
    step is zero-padded to at least 7 digits. The file stores:

    - the training step and ``config``;
    - model, optimizer and scheduler state;
    - the grad-scaler state (or ``None`` when no scaler is given);
    - the CPU RNG state;
    - all CUDA RNG states (or ``None`` when CUDA is unavailable).

    The data is first written under a temporary ``.tmp`` name and only then
    moved into place with :func:`os.replace`. An interrupted save therefore
    never leaves a truncated ``ckpt_step_*.pt`` file that a later resume
    could pick up.

    Args:
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler whose ``state_dict`` is saved.
        scaler: Optional grad scaler; its state is saved when not ``None``.
        step: Training step; stored in the file and used in its name.
        config: Arbitrary run configuration stored alongside the state.
        output_dir: Directory to write into (created if missing).
        keep: Forwarded to ``_prune_old_checkpoints`` after the save.

    Returns:
        The path of the written checkpoint, as a string.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"ckpt_step_{step:07d}.pt"
    # The ".tmp" suffix keeps the partial file out of the "ckpt_step_*.pt" glob.
    tmp_path = path.with_name(path.name + ".tmp")
    state = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict() if scaler is not None else None,
        "rng_cpu": torch.get_rng_state(),
        "rng_gpu": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "config": config,
    }
    try:
        torch.save(state, tmp_path)
        # Rename within the same directory, so the final name only ever
        # refers to a completely written file.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file on any failure or interrupt, then re-raise.
        tmp_path.unlink(missing_ok=True)
        raise
    _prune_old_checkpoints(output_dir, keep=keep)
    return str(path)


def load_latest_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None,
    scheduler: torch.optim.lr_scheduler.LambdaLR | None,
    scaler: torch.GradScaler | None,
    output_dir: str,
    device: str | torch.device,
) -> int:
    """Restore the newest checkpoint in ``output_dir`` and return its step.

    "Newest" means the largest integer ``N`` among ``ckpt_step_<N>.pt`` file
    names, compared numerically. A plain string sort would rank
    ``ckpt_step_10000000.pt`` below ``ckpt_step_9999999.pt`` once steps
    outgrow the 7-digit zero padding. Matching files whose suffix is not an
    integer (e.g. ``ckpt_step_best.pt``) are ignored.

    What gets restored:

    - Model state: always.
    - Optimizer and scheduler state: only when those objects are given.
    - Scaler state: only when a scaler is given and the checkpoint stored one.
    - CPU RNG state: always.
    - CUDA RNG states: when CUDA is available and the checkpoint has them.

    Args:
        model: Module to load the saved ``state_dict`` into.
        optimizer: Optional optimizer to restore.
        scheduler: Optional LR scheduler to restore.
        scaler: Optional grad scaler to restore.
        output_dir: Directory containing ``ckpt_step_*.pt`` files.
        device: ``map_location`` passed to :func:`torch.load`.

    Returns:
        The stored training step, or ``0`` when no checkpoint is found.
    """
    prefix = "ckpt_step_"
    candidates: list[tuple[int, Path]] = []
    # Path.glob treats output_dir literally, so characters like "[" in the
    # directory name are not interpreted as glob syntax.
    for candidate in Path(output_dir).glob(f"{prefix}*.pt"):
        digits = candidate.stem[len(prefix):]
        if digits.isdecimal():
            candidates.append((int(digits), candidate))
    if not candidates:
        return 0
    _, latest = max(candidates)

    # NOTE(review): newer torch releases default torch.load to weights_only=True,
    # which rejects arbitrary Python objects. If ``config`` ever holds
    # non-primitive values, this load may fail — confirm for the torch version in use.
    checkpoint = torch.load(latest, map_location=device)
    model.load_state_dict(checkpoint["model"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler"])
    if scaler is not None and checkpoint.get("scaler") is not None:
        scaler.load_state_dict(checkpoint["scaler"])

    # map_location also moves the saved RNG tensors (e.g. onto CUDA). The RNG
    # setters only accept CPU ByteTensors, so bring them back before restoring.
    torch.set_rng_state(checkpoint["rng_cpu"].cpu())
    gpu_states = checkpoint.get("rng_gpu")
    if gpu_states is not None and torch.cuda.is_available():
        # Restore only as many device states as this process can see, so
        # resuming with fewer visible GPUs does not index a missing device.
        visible = torch.cuda.device_count()
        torch.cuda.set_rng_state_all([state.cpu() for state in gpu_states[:visible]])
    return int(checkpoint["step"])


def _prune_old_checkpoints(output_dir: str, keep: int = 5) -> None:
    """Delete all but the ``keep`` most recent checkpoints in ``output_dir``.

    Checkpoints are ``ckpt_step_<N>.pt`` files ranked by their integer step
    ``N``. The ranking is numeric, so pruning stays correct once steps
    exceed the 7-digit zero padding. Files matching the glob whose suffix is
    not an integer (e.g. ``ckpt_step_best.pt``) are never deleted.

    Args:
        output_dir: Directory containing the checkpoints.
        keep: Number of newest checkpoints to retain. ``keep <= 0`` disables
            pruning entirely.
    """
    # Guard explicitly. Slicing with [:-keep] keeps everything for 0, but for
    # negative values it would silently delete the oldest |keep| files.
    if keep <= 0:
        return
    prefix = "ckpt_step_"
    ranked: list[tuple[int, Path]] = []
    for candidate in Path(output_dir).glob(f"{prefix}*.pt"):
        digits = candidate.stem[len(prefix):]
        if digits.isdecimal():
            ranked.append((int(digits), candidate))
    ranked.sort()
    for _, stale in ranked[:-keep]:
        os.remove(stale)