"""Checkpoint save, prune, and resume utilities."""

from __future__ import annotations

import glob
import os
import re
from pathlib import Path

import torch

# Matches the file names written by ``save_checkpoint`` and captures the step.
_CHECKPOINT_NAME = re.compile(r"ckpt_step_(\d+)\.pt")


def _list_checkpoints(output_dir: str) -> list[str]:
    """Return the checkpoint paths in ``output_dir`` ordered by ascending step.

    Ordering uses the integer step parsed from the file name rather than the
    raw string, so steps wider than the 7-digit zero padding still sort
    correctly. Files that match the glob but do not carry a numeric step are
    ignored. A missing directory yields an empty list.
    """
    # Escape the directory so characters such as "[" are not read as a pattern.
    pattern = os.path.join(glob.escape(os.fspath(output_dir)), "ckpt_step_*.pt")
    by_step: list[tuple[int, str]] = []
    for candidate in glob.glob(pattern):
        match = _CHECKPOINT_NAME.fullmatch(os.path.basename(candidate))
        if match is not None:
            by_step.append((int(match.group(1)), candidate))
    by_step.sort()
    return [candidate for _, candidate in by_step]


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    scaler: torch.GradScaler | None,
    step: int,
    config: dict[str, object],
    output_dir: str,
    keep: int = 5,
) -> str:
    """Persist a resumable training checkpoint.

    The checkpoint stores the step, the state dicts of ``model``,
    ``optimizer``, ``scheduler`` and (when given) ``scaler``, the CPU RNG
    state, the RNG state of every CUDA device when CUDA is available, and
    ``config``. It is written to ``<output_dir>/ckpt_step_<step:07d>.pt``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: Scheduler whose ``state_dict()`` is saved.
        scaler: Optional gradient scaler; ``None`` is stored when absent.
        step: Training step used in the file name and stored in the payload.
        config: Run configuration stored alongside the weights.
        output_dir: Target directory; created (with parents) if missing.
        keep: Number of newest checkpoints to retain after saving. Values
            less than or equal to zero disable pruning.

    Returns:
        The path of the written checkpoint as a string.
    """
    directory = Path(output_dir)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / f"ckpt_step_{step:07d}.pt"
    # The ".tmp" suffix keeps the partial file out of the "ckpt_step_*.pt"
    # glob, so an interrupted save is never picked up as the latest checkpoint.
    tmp_path = path.with_name(path.name + ".tmp")

    payload = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict() if scaler is not None else None,
        "rng_cpu": torch.get_rng_state(),
        "rng_gpu": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "config": config,
    }

    try:
        torch.save(payload, tmp_path)
        # Publish the finished file in a single rename step.
        os.replace(tmp_path, path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)

    _prune_old_checkpoints(output_dir, keep=keep)
    return str(path)


def load_latest_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None,
    scheduler: torch.optim.lr_scheduler.LambdaLR | None,
    scaler: torch.GradScaler | None,
    output_dir: str,
    device: str | torch.device,
) -> int:
    """Load the most recent checkpoint and return the step to resume from.

    "Most recent" means the highest step number found in ``output_dir``.
    Model weights are always restored; optimizer, scheduler and scaler states
    are restored only for the objects that are passed in. The CPU RNG state
    is restored, and the CUDA RNG states are restored when the checkpoint has
    them and CUDA is available.

    Args:
        model: Module that receives the saved weights.
        optimizer: Optional optimizer to restore.
        scheduler: Optional scheduler to restore.
        scaler: Optional gradient scaler to restore; skipped when the
            checkpoint stored no scaler state.
        output_dir: Directory to search for ``ckpt_step_*.pt`` files.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The step stored in the checkpoint, or ``0`` when no checkpoint exists.
    """
    checkpoints = _list_checkpoints(output_dir)
    if not checkpoints:
        return 0

    # NOTE(review): torch.load unpickles the file, so only resume from
    # directories you trust. Consider passing ``weights_only`` explicitly,
    # since its default differs between PyTorch releases -- confirm against
    # the version this project pins.
    checkpoint = torch.load(checkpoints[-1], map_location=device)

    model.load_state_dict(checkpoint["model"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler"])
    if scaler is not None and checkpoint.get("scaler") is not None:
        scaler.load_state_dict(checkpoint["scaler"])

    # ``map_location`` also relocates the saved RNG tensors, but the RNG
    # setters only accept CPU byte tensors, so move them back first.
    torch.set_rng_state(checkpoint["rng_cpu"].cpu())
    gpu_states = checkpoint.get("rng_gpu")
    if gpu_states is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all([state.cpu() for state in gpu_states])

    return int(checkpoint["step"])


def _prune_old_checkpoints(output_dir: str, keep: int = 5) -> None:
    """Delete all but the ``keep`` highest-step checkpoints in ``output_dir``.

    Args:
        output_dir: Directory containing ``ckpt_step_*.pt`` files.
        keep: Number of newest checkpoints to retain. Values less than or
            equal to zero leave the directory untouched, so a negative value
            cannot turn the slice below into "delete the oldest N".
    """
    if keep <= 0:
        return
    for stale in _list_checkpoints(output_dir)[:-keep]:
        os.remove(stale)