"""Checkpoint management utilities.

Canonical Checkpoint Format (new)::

    {
        'step': int,                       # Training step number
        'model_state': Dict[str, Tensor],  # Model state dict
        'optimizer_state': Dict,           # Optimizer state dict (optional)
        'config': Dict,                    # TrainingConfig as dict
        'metrics': Dict[str, float],       # Training metrics
        'global_step': int,                # (deprecated, kept for compat) same as step
        'current_epoch': int,              # (optional) current epoch number
        'best_loss': float,                # (optional) best validation loss
    }

Legacy Checkpoint Format (old, from BaseTrainer)::

    {
        'global_step': int,
        'current_epoch': int,
        'best_loss': float,
        'model_state_dict': Dict[str, Tensor],  # ← Note: uses '_dict' suffix
        'optimizer_state_dict': Dict,
        'config': Dict,
    }

The load() function auto-detects and migrates legacy format to canonical format.
save() does not write 'global_step' to disk; load() back-fills it in memory from
'step' so code written against the legacy key keeps working.
"""
from pathlib import Path
from typing import Dict, Any, Optional

import torch

from taoTrain.config import TrainingConfig


class CheckpointManager:
    """Manage model checkpoints with versioning.

    Step checkpoints are written as ``checkpoint_step_{step:06d}.pt`` and rotated
    so that at most ``keep_last_n`` of the ones saved by this instance remain on
    disk. The best model is written separately as ``best_model.pt`` and is never
    rotated.
    """

    def __init__(
        self,
        checkpoint_dir: str | Path,
        keep_last_n: int = 3,
        track_best: bool = True,
    ):
        """
        Initialize checkpoint manager.

        Args:
            checkpoint_dir: Directory to save checkpoints (created if missing).
            keep_last_n: Number of recent step checkpoints to keep. Rotation only
                covers checkpoints saved by this instance; files left by earlier
                runs are never deleted. A value <= 0 deletes each step
                checkpoint right after it is written.
            track_best: Whether to track best model.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n
        # NOTE(review): track_best, best_metric and best_metric_name are stored
        # but not read by any method of this class; the caller decides what is
        # "best" via save(..., is_best=True).
        self.track_best = track_best
        self.best_metric = None
        self.best_metric_name = None
        # (step, path) pairs in save order, oldest first; drives rotation.
        self.saved_checkpoints: list[tuple[int, Path]] = []

    def save(
        self,
        step: int,
        model_state: Dict[str, Any],
        optimizer_state: Optional[Dict[str, Any]] = None,
        config: Optional[TrainingConfig] = None,
        metrics: Optional[Dict[str, float]] = None,
        is_best: bool = False,
    ) -> Path:
        """
        Save a checkpoint in the canonical format.

        Args:
            step: Training step.
            model_state: Model state dict.
            optimizer_state: Optimizer state dict.
            config: Training config (stored via ``config.to_dict()``).
            metrics: Metrics dict.
            is_best: Whether this is the best model so far. If True the
                checkpoint is written to ``best_model.pt`` *instead of* a step
                file and is not subject to rotation.

        Returns:
            Path to saved checkpoint.
        """
        checkpoint = {
            "step": step,
            "model_state": model_state,
            "optimizer_state": optimizer_state,
            "config": config.to_dict() if config is not None else None,
            "metrics": metrics or {},
        }

        filename = "best_model.pt" if is_best else f"checkpoint_step_{step:06d}.pt"
        path = self.checkpoint_dir / filename
        self._atomic_save(checkpoint, path)

        if not is_best:
            self._track_and_rotate(step, path)

        return path

    @staticmethod
    def _atomic_save(obj: Dict[str, Any], path: Path) -> None:
        """Write ``obj`` with torch.save to a temp file, then rename it onto ``path``."""
        # Writing to a sibling temp file means a crash mid-write cannot leave a
        # truncated file at ``path`` (important for overwriting best_model.pt).
        # The ".tmp" suffix keeps it out of the "checkpoint_step_*.pt" glob.
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            torch.save(obj, tmp_path)
            # Same directory, so this is a rename that replaces any existing file.
            tmp_path.replace(path)
        except BaseException:
            # Don't leave a half-written temp file behind, then propagate.
            tmp_path.unlink(missing_ok=True)
            raise

    def _track_and_rotate(self, step: int, path: Path) -> None:
        """Record a step checkpoint and delete the oldest beyond ``keep_last_n``."""
        # Re-saving an already-tracked step overwrites the same file; drop the
        # stale entry so a later eviction doesn't unlink the file just written.
        self.saved_checkpoints[:] = [
            entry for entry in self.saved_checkpoints if entry[1] != path
        ]
        self.saved_checkpoints.append((step, path))

        # `while` (not `if`) so a keep_last_n lowered at runtime is honored.
        while len(self.saved_checkpoints) > self.keep_last_n:
            _, old_path = self.saved_checkpoints.pop(0)
            old_path.unlink(missing_ok=True)

    def load(
        self,
        checkpoint_path: str | Path,
        device: Optional[torch.device] = None,
    ) -> Dict[str, Any]:
        """
        Load a checkpoint with backward-compatible format handling.

        Auto-detects checkpoint format (canonical, legacy, or raw state_dict)
        and normalizes it to canonical format in-memory. The file on disk is
        never modified.

        Args:
            checkpoint_path: Path to checkpoint.
            device: Device to map tensors to (defaults to CPU).

        Returns:
            Checkpoint dict in canonical format with a 'model_state' key.

        Raises:
            TypeError: If the file does not contain a dict.
        """
        if device is None:
            device = torch.device("cpu")

        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        return self._normalize_checkpoint_format(checkpoint)

    def _normalize_checkpoint_format(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        """
        Normalize checkpoint to canonical format.

        - Legacy (BaseTrainer) checkpoints, detected by 'model_state_dict'
          without 'model_state', are migrated to canonical keys while keeping
          the legacy keys 'global_step', 'current_epoch' and 'best_loss'.
        - Canonical checkpoints are returned as-is, with the deprecated
          'global_step' alias back-filled from 'step' when missing.
        - Anything else is assumed to be a raw state_dict and wrapped as the
          'model_state' of a fresh canonical dict (step 0, no optimizer state).

        Args:
            checkpoint: Raw checkpoint dict.

        Returns:
            Normalized checkpoint dict with canonical keys.

        Raises:
            TypeError: If ``checkpoint`` is not a dict.
        """
        if not isinstance(checkpoint, dict):
            raise TypeError(
                "[CheckpointManager] Expected checkpoint to be a dict, "
                f"got {type(checkpoint).__name__}."
            )

        # Legacy checkpoint: has 'model_state_dict' but not 'model_state'.
        if "model_state_dict" in checkpoint and "model_state" not in checkpoint:
            global_step = checkpoint.get("global_step", 0)
            migrated = {
                "step": global_step,
                "model_state": checkpoint["model_state_dict"],
                "optimizer_state": checkpoint.get("optimizer_state_dict"),
                "config": checkpoint.get("config"),
                "metrics": {},
                # Keep legacy keys for backward compatibility in code that uses them
                "global_step": global_step,
                "current_epoch": checkpoint.get("current_epoch", 0),
                "best_loss": checkpoint.get("best_loss", float('inf')),
            }
            print("\n✓ [CheckpointManager] Detected legacy checkpoint format. Auto-migrated to canonical format.")
            return migrated

        if "model_state" in checkpoint:
            # Canonical format. Back-fill the deprecated alias documented in the
            # module docstring ("same as step") without overriding a stored one.
            if "step" in checkpoint:
                checkpoint.setdefault("global_step", checkpoint["step"])
            return checkpoint

        # Neither format detected: treat the whole dict as a raw state_dict.
        # Wrap it rather than inserting a 'model_state' key into it, which would
        # make the state dict contain itself and break load_state_dict().
        print("\n⚠ [CheckpointManager] Checkpoint format unclear. Assuming raw state_dict format.")
        return {
            "step": 0,
            "model_state": checkpoint,
            "optimizer_state": None,
            "config": None,
            "metrics": {},
            "global_step": 0,
        }

    def get_latest(self) -> Optional[Path]:
        """Get path to the most recent step checkpoint saved by this instance.

        Returns None if this instance has not saved (or has rotated away) any
        step checkpoint; files from earlier runs are not considered.
        """
        if not self.saved_checkpoints:
            return None
        return self.saved_checkpoints[-1][1]

    def get_best(self) -> Optional[Path]:
        """Get path to best checkpoint, or None if ``best_model.pt`` doesn't exist."""
        best_path = self.checkpoint_dir / "best_model.pt"
        if best_path.exists():
            return best_path
        return None

    def list_checkpoints(self) -> list[Path]:
        """List all step checkpoints on disk, ordered by step number.

        Sorting numerically rather than by filename keeps steps >= 1,000,000
        (whose 7+ digit names sort before 6-digit ones as strings) in order.
        Files whose step suffix is not an integer sort last, by name.
        """
        return sorted(
            self.checkpoint_dir.glob("checkpoint_step_*.pt"),
            key=self._checkpoint_sort_key,
        )

    @staticmethod
    def _checkpoint_sort_key(path: Path) -> tuple[bool, int, str]:
        """Sort key for step checkpoint files: parsed step first, unparseable last."""
        suffix = path.stem.removeprefix("checkpoint_step_")
        # isascii() guards against Unicode digits that isdigit() accepts but int() rejects.
        if suffix.isascii() and suffix.isdigit():
            return (False, int(suffix), path.name)
        return (True, 0, path.name)