# checkpoint.py — utilities for saving and resuming training checkpoints.
| import os | |
| import torch | |
| from typing import Optional | |
class CheckpointManager:
    """Save and restore full training state (model, AMP scaler, optimizer,
    scheduler) as ``.pth`` files named ``epoch_{E}_sample_{S}.pth``.
    """

    def __init__(self, relative_dir: str = "checkpoints"):
        """
        Args:
            relative_dir (str): Checkpoint directory, resolved relative to this
                file's location (so the path is stable regardless of the CWD
                the training script is launched from). Created if missing.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.ckpt_dir = os.path.join(base_dir, relative_dir)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    @staticmethod
    def _parse_indices(fname: str) -> Optional[tuple]:
        """Return ``(epoch, last_sample_idx)`` parsed from a checkpoint
        filename, or None if it doesn't match ``epoch_{E}_sample_{S}.pth``."""
        parts = fname[:-len(".pth")].split("_")
        if len(parts) == 4 and parts[0] == "epoch" and parts[2] == "sample":
            try:
                return int(parts[1]), int(parts[3])
            except ValueError:
                return None
        return None

    def _format_filename(self, epoch: int, last_sample_idx: int) -> str:
        """Absolute path of the checkpoint for (epoch, last_sample_idx)."""
        return os.path.join(self.ckpt_dir, f"epoch_{epoch}_sample_{last_sample_idx}.pth")

    def save(self, model, scaler, optimizer, scheduler, epoch, last_sample_idx):
        """Serialize the full training state to disk.

        Args:
            model / scaler / optimizer / scheduler: objects exposing
                ``state_dict()`` (standard torch training components).
            epoch (int): Current epoch index (embedded in the filename).
            last_sample_idx (int): Index of the last processed sample.
        """
        state = {
            "epoch": epoch,
            "last_sample_idx": last_sample_idx,
            "model": model.state_dict(),
            "scaler": scaler.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        filename = self._format_filename(epoch, last_sample_idx)
        torch.save(state, filename)
        # BUG FIX: previously printed the literal text "(unknown)" instead of
        # the actual checkpoint path.
        print(f"[checkpoint.py] Saved checkpoint to {filename}")

    def load(self, model, scaler=None, optimizer=None, scheduler=None, filename=None):
        """Restore training state from a checkpoint.

        Args:
            model: Required; its weights are always restored.
            scaler / optimizer / scheduler: Restored only if provided.
            filename (str, optional): Explicit checkpoint path. If None, the
                numerically latest checkpoint in ``self.ckpt_dir`` is used.

        Returns:
            tuple: ``(epoch, last_sample_idx)`` from the checkpoint, or
            ``(0, 0)`` when no checkpoint exists (fresh start).
        """
        if filename is None:
            # BUG FIX: the old code took the lexicographically last filename,
            # which orders epoch_10 before epoch_2 and resumes the wrong
            # (older) checkpoint once epoch >= 10. Sort numerically instead.
            candidates = []
            for f in os.listdir(self.ckpt_dir):
                if not f.endswith(".pth"):
                    continue
                key = self._parse_indices(f)
                if key is not None:
                    candidates.append((key, f))
            if not candidates:
                print(f"No checkpoints found in {self.ckpt_dir}")
                return 0, 0
            filename = os.path.join(self.ckpt_dir, max(candidates)[1])
        # map_location="cpu" so a GPU-saved checkpoint also loads on a
        # CPU-only host; load_state_dict moves tensors to the right device.
        checkpoint = torch.load(filename, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        if scaler is not None:
            scaler.load_state_dict(checkpoint["scaler"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])
        # BUG FIX: same "(unknown)" placeholder as in save().
        print(f"[checkpoint.py] Loaded checkpoint from {filename}")
        return checkpoint["epoch"], checkpoint["last_sample_idx"]