"""Checkpoint management for training.

Provides :class:`CheckpointManager`, which writes/reads checkpoint directories
named ``checkpoint-<name>`` under a root directory, plus two module-level
convenience wrappers, :func:`save_checkpoint` and :func:`load_checkpoint`.
"""

import json
import logging
import re
import shutil
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

logger = logging.getLogger(__name__)


def _checkpoint_sort_key(path: Path) -> Tuple[int, str]:
    """Return a sort key ordering checkpoint directories by step number.

    The last run of digits in the directory name is used as the step, so
    ``checkpoint-100`` and ``checkpoint-step-100`` both map to ``100``. Names
    without any digits (e.g. ``checkpoint-best``) sort first; the name itself
    breaks ties so the ordering is deterministic.
    """
    numbers = re.findall(r"\d+", path.name)
    step = int(numbers[-1]) if numbers else -1
    return step, path.name


@dataclass
class CheckpointMetadata:
    """Metadata stored next to a checkpoint as ``metadata.json``."""

    step: int
    epoch: int
    global_step: int
    metrics: Dict[str, float] = field(default_factory=dict)
    config: Dict[str, Any] = field(default_factory=dict)
    model_name: str = "zenith"
    timestamp: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict keyed by field name."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CheckpointMetadata":
        """Build metadata from a dict, ignoring keys that are not fields.

        Unknown keys are dropped so that metadata files carrying extra entries
        can still be read. Missing required fields still raise ``TypeError``.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})


class CheckpointManager:
    """Manages saving and loading of checkpoints inside one directory."""

    def __init__(
        self,
        checkpoint_dir: str,
        save_total_limit: int = 5,
        save_best_only: bool = False,
        metric_for_best: str = "eval_loss",
        greater_is_better: bool = False,
    ):
        """Create the checkpoint directory if needed and index what is there.

        Args:
            checkpoint_dir: Root directory holding ``checkpoint-*`` folders.
            save_total_limit: Maximum number of checkpoints to keep; values
                ``<= 0`` disable pruning.
            save_best_only: Stored on the instance.
                NOTE(review): not consulted by any method in this module.
            metric_for_best: Metric name used by :meth:`get_best_checkpoint`.
            greater_is_better: Whether a larger metric value is better.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.save_total_limit = save_total_limit
        self.save_best_only = save_best_only
        self.metric_for_best = metric_for_best
        self.greater_is_better = greater_is_better

        self.best_metric = None
        self.checkpoints: List[Path] = []

        # Load existing checkpoints
        self._scan_checkpoints()

    def _scan_checkpoints(self):
        """Rebuild ``self.checkpoints`` from the directories on disk."""
        # Start from scratch: rescanning after deletions must not leave stale
        # or duplicated entries behind.
        self.checkpoints = sorted(
            (p for p in self.checkpoint_dir.glob("checkpoint-*") if p.is_dir()),
            key=_checkpoint_sort_key,
        )

    def save_checkpoint(
        self,
        state: Dict[str, Any],
        name: str,
        metrics: Optional[Dict[str, float]] = None,
    ) -> Path:
        """Save a checkpoint to ``<checkpoint_dir>/checkpoint-<name>``.

        Args:
            state: Must contain ``"model_state_dict"``. Optional keys:
                ``"optimizer_state_dict"``, ``"scheduler_state_dict"``,
                ``"scaler_state_dict"``, ``"step"``, ``"epoch"``,
                ``"global_step"``, ``"config"``, ``"timestamp"``.
            name: Suffix of the checkpoint directory name.
            metrics: Metric values recorded in the metadata file.

        Returns:
            Path of the checkpoint directory.

        Raises:
            KeyError: If ``state`` has no ``"model_state_dict"`` entry.
        """
        checkpoint_path = self.checkpoint_dir / f"checkpoint-{name}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        # Save model state
        torch.save(state["model_state_dict"], checkpoint_path / "pytorch_model.bin")

        # Save optimizer and scheduler states
        if "optimizer_state_dict" in state:
            torch.save(state["optimizer_state_dict"], checkpoint_path / "optimizer.pt")

        # Scheduler/scaler states are written only when present and non-empty.
        if state.get("scheduler_state_dict"):
            torch.save(state["scheduler_state_dict"], checkpoint_path / "scheduler.pt")

        if state.get("scaler_state_dict"):
            torch.save(state["scaler_state_dict"], checkpoint_path / "scaler.pt")

        # Save metadata
        metadata = CheckpointMetadata(
            step=state.get("step", 0),
            epoch=state.get("epoch", 0),
            global_step=state.get("global_step", 0),
            metrics=metrics or {},
            config=state.get("config", {}),
            timestamp=state.get("timestamp", ""),
        )
        with open(checkpoint_path / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata.to_dict(), f, indent=2)

        logger.info("Checkpoint saved: %s", checkpoint_path)

        # Update checkpoint list
        if checkpoint_path not in self.checkpoints:
            self.checkpoints.append(checkpoint_path)
            self.checkpoints.sort(key=_checkpoint_sort_key)

        # Enforce limit. Loop so a directory that already held more than the
        # limit is brought down to it, not trimmed by one per save.
        if self.save_total_limit > 0:
            while len(self.checkpoints) > self.save_total_limit:
                self._remove_oldest_checkpoint()

        return checkpoint_path

    def load_checkpoint(
        self,
        checkpoint_path: Union[str, Path],
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ) -> CheckpointMetadata:
        """Load a checkpoint directory into the given objects.

        Args:
            checkpoint_path: Checkpoint directory to read.
            model: Module receiving ``pytorch_model.bin`` (a warning is logged
                if that file is absent).
            optimizer: If given, receives ``optimizer.pt`` when it exists.
            scheduler: If given, receives ``scheduler.pt`` when it exists.
            scaler: If given, receives ``scaler.pt`` when it exists.

        Returns:
            The stored metadata, or zeroed metadata if ``metadata.json`` is
            missing.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # SECURITY NOTE: torch.load unpickles data; only load checkpoints from
        # trusted sources (consider ``weights_only=True`` where supported).

        # Load model
        model_path = checkpoint_path / "pytorch_model.bin"
        if model_path.exists():
            state_dict = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state_dict)
            logger.info("Loaded model from %s", model_path)
        else:
            logger.warning("Model weights not found at %s", model_path)

        # Load optimizer, scheduler and scaler: same procedure for each.
        optional_parts = (
            ("optimizer", optimizer, "optimizer.pt"),
            ("scheduler", scheduler, "scheduler.pt"),
            ("scaler", scaler, "scaler.pt"),
        )
        for label, target, filename in optional_parts:
            if target is None:
                continue
            part_path = checkpoint_path / filename
            if part_path.exists():
                target.load_state_dict(torch.load(part_path, map_location="cpu"))
                logger.info("Loaded %s from %s", label, part_path)

        # Load metadata
        meta_path = checkpoint_path / "metadata.json"
        if meta_path.exists():
            with open(meta_path, "r", encoding="utf-8") as f:
                metadata = CheckpointMetadata.from_dict(json.load(f))
            logger.info(
                "Loaded metadata: epoch=%s, step=%s", metadata.epoch, metadata.step
            )
        else:
            metadata = CheckpointMetadata(step=0, epoch=0, global_step=0)

        return metadata

    def get_latest_checkpoint(self) -> Optional[Path]:
        """Return the checkpoint with the highest step number, or ``None``."""
        if self.checkpoints:
            return self.checkpoints[-1]
        return None

    def get_best_checkpoint(self) -> Optional[Path]:
        """Return the checkpoint with the best ``metric_for_best`` value.

        Checkpoints without ``metadata.json`` or without the metric are
        skipped. Returns ``None`` when no checkpoint qualifies.
        """
        best_path = None
        best_value = None

        for path in self.checkpoints:
            meta_path = path / "metadata.json"
            if not meta_path.exists():
                continue
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = CheckpointMetadata.from_dict(json.load(f))
            if self.metric_for_best not in meta.metrics:
                continue

            value = meta.metrics[self.metric_for_best]
            if best_value is None:
                improved = True
            elif self.greater_is_better:
                improved = value > best_value
            else:
                improved = value < best_value

            if improved:
                best_value = value
                best_path = path

        return best_path

    def _remove_oldest_checkpoint(self):
        """Delete the first (lowest-step) checkpoint if over the limit."""
        if len(self.checkpoints) > self.save_total_limit:
            oldest = self.checkpoints.pop(0)
            if oldest.exists():
                shutil.rmtree(oldest)
                logger.info("Removed old checkpoint: %s", oldest)

    def cleanup(self, keep: Optional[List[Path]] = None):
        """Delete every tracked checkpoint except those listed in ``keep``."""
        # Normalise to Path so string entries in ``keep`` are honoured too.
        keep_paths = {Path(p) for p in keep} if keep else set()

        for path in self.checkpoints:
            if path not in keep_paths and path.exists():
                shutil.rmtree(path)
                logger.info("Removed checkpoint: %s", path)

        self._scan_checkpoints()


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[Any],
    scaler: Optional[torch.cuda.amp.GradScaler],
    checkpoint_dir: str,
    epoch: int,
    global_step: int,
    metrics: Optional[Dict[str, float]] = None,
    config: Optional[Dict[str, Any]] = None,
    save_optimizer: bool = True,
    save_scheduler: bool = True,
) -> Path:
    """Convenience function to save a checkpoint.

    Writes ``<checkpoint_dir>/checkpoint-step-<global_step>`` without pruning
    older checkpoints (the manager is created with ``save_total_limit=0``).

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer saved when ``save_optimizer`` is true.
        scheduler: Scheduler saved when given and ``save_scheduler`` is true.
        scaler: Gradient scaler whose state is saved when given.
        checkpoint_dir: Root directory for checkpoints.
        epoch: Epoch recorded in the metadata.
        global_step: Global step recorded in the metadata and used in the name.
        metrics: Metric values recorded in the metadata.
        config: Configuration recorded in the metadata.
        save_optimizer: Whether to save the optimizer state.
        save_scheduler: Whether to save the scheduler state.

    Returns:
        Path of the saved checkpoint directory.
    """
    manager = CheckpointManager(checkpoint_dir, save_total_limit=0)

    state = {
        "model_state_dict": model.state_dict(),
        "global_step": global_step,
        "epoch": epoch,
        "config": config or {},
        "timestamp": "",
    }

    if save_optimizer:
        state["optimizer_state_dict"] = optimizer.state_dict()

    if save_scheduler and scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()

    # The scaler is accepted by this function, so persist it as well; the
    # manager skips writing it when the state dict is empty.
    if scaler is not None:
        state["scaler_state_dict"] = scaler.state_dict()

    return manager.save_checkpoint(state, f"step-{global_step}", metrics)


def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> int:
    """Convenience function to load a checkpoint.

    Args:
        checkpoint_path: Checkpoint directory to read.
        model: Module receiving the saved weights.
        optimizer: Optional optimizer to restore.
        scheduler: Optional scheduler to restore.
        scaler: Optional gradient scaler to restore.

    Returns:
        The ``global_step`` stored in the checkpoint metadata (``0`` when the
        metadata file is missing).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    manager = CheckpointManager(str(Path(checkpoint_path).parent))
    metadata = manager.load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler)
    return metadata.global_step