import os
import re
import torch
from typing import Optional


class CheckpointManager:
    """Save and restore training state (model, AMP scaler, optimizer,
    LR scheduler) as ``.pth`` files named ``epoch_<E>_sample_<S>.pth``.
    """

    # Filenames produced by _format_filename; groups = (epoch, last_sample_idx).
    _FILENAME_RE = re.compile(r"^epoch_(\d+)_sample_(\d+)\.pth$")

    def __init__(self, relative_dir: str = "checkpoints"):
        """
        Args:
            relative_dir (str): Checkpoint directory, resolved relative to
                this source file. Created if it does not exist.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.ckpt_dir = os.path.join(base_dir, relative_dir)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def _format_filename(self, epoch: int, last_sample_idx: int) -> str:
        """Return the absolute path for the checkpoint of (epoch, sample)."""
        return os.path.join(
            self.ckpt_dir, f"epoch_{epoch}_sample_{last_sample_idx}.pth"
        )

    def save(self, model, scaler, optimizer, scheduler, epoch, last_sample_idx):
        """Serialize all training state into a single checkpoint file."""
        state = {
            "epoch": epoch,
            "last_sample_idx": last_sample_idx,
            "model": model.state_dict(),
            "scaler": scaler.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        filename = self._format_filename(epoch, last_sample_idx)
        torch.save(state, filename)
        # Bug fix: the message previously never interpolated the path
        # (f-string had no placeholder).
        print(f"[checkpoint.py] Saved checkpoint to {filename}")

    def _latest_checkpoint(self):
        """Return the path of the newest checkpoint, or None if none exist.

        "Newest" is the numerically largest (epoch, last_sample_idx) pair
        parsed from the filename. Numeric comparison fixes the previous
        lexicographic sort, which ranked epoch_9 after epoch_10.
        """
        best_name = None
        best_key = (-1, -1)
        for fname in os.listdir(self.ckpt_dir):
            match = self._FILENAME_RE.match(fname)
            if match is None:
                continue  # ignore files not written by save()
            key = (int(match.group(1)), int(match.group(2)))
            if key > best_key:
                best_key, best_name = key, fname
        if best_name is None:
            return None
        return os.path.join(self.ckpt_dir, best_name)

    def load(self, model, scaler=None, optimizer=None, scheduler=None, filename=None):
        """Restore state from ``filename``, or from the latest checkpoint.

        Only the components passed as non-None are restored; the model is
        always required.

        Returns:
            tuple[int, int]: (epoch, last_sample_idx) of the loaded
            checkpoint, or (0, 0) when no checkpoint is available.
        """
        if filename is None:
            filename = self._latest_checkpoint()
            if filename is None:
                print(f"No checkpoints found in {self.ckpt_dir}")
                return 0, 0
        # map_location="cpu" lets CPU-only hosts read GPU-written checkpoints;
        # load_state_dict then copies tensors onto each module's own device.
        checkpoint = torch.load(filename, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        # `is not None` (not truthiness): an object whose __bool__ happens to
        # be falsy must still be restored.
        if scaler is not None:
            scaler.load_state_dict(checkpoint["scaler"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])
        # Bug fix: interpolate the actual path (f-string had no placeholder).
        print(f"[checkpoint.py] Loaded checkpoint from {filename}")
        return checkpoint["epoch"], checkpoint["last_sample_idx"]