"""
Optimized checkpoint utilities for faster saving/loading.

Features:
- Async checkpoint saving (non-blocking)
- Compression (gzip) for smaller files
- Incremental checkpoints (only save changed weights)
- Checkpoint validation

Every writer serializes into a hidden temporary file in the destination
directory and moves it over the target with ``os.replace``, so an interrupted
save does not leave a truncated checkpoint at the target path.
"""

import copy
import gzip
import logging
import os
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# Global thread pool for async operations
_executor = ThreadPoolExecutor(max_workers=2)


def _gz_path(checkpoint_path: Path) -> Path:
    """Return ``checkpoint_path`` with ``.gz`` appended to its file name."""
    return checkpoint_path.with_name(checkpoint_path.name + ".gz")


class _CountingWriter:
    """Write-only file wrapper that counts the bytes passed through to ``raw``.

    ``torch.save`` only requires ``write`` and ``flush`` from a file-like
    target, so this lets us learn the uncompressed size while streaming the
    payload straight into a gzip file.
    """

    def __init__(self, raw: Any) -> None:
        self._raw = raw
        self.bytes_written = 0

    def write(self, data: Any) -> int:
        self._raw.write(data)
        size = memoryview(data).nbytes
        self.bytes_written += size
        return size

    def flush(self) -> None:
        self._raw.flush()


def _write_atomic(
    data: Any,
    path: Path,
    *,
    compress: bool,
    compression_level: int = 9,
) -> int:
    """Serialize ``data`` with ``torch.save`` and move it over ``path``.

    The payload is written to a hidden temp file next to ``path`` (unique per
    process and thread, so concurrent saves do not collide) and then moved into
    place with ``os.replace``. The temp file is deleted if anything fails.

    Returns:
        Size in bytes of the uncompressed serialized payload.
    """
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        if compress:
            with gzip.open(tmp_path, "wb", compresslevel=compression_level) as gz_file:
                counter = _CountingWriter(gz_file)
                torch.save(data, counter)
            raw_size = counter.bytes_written
        else:
            torch.save(data, tmp_path)
            raw_size = tmp_path.stat().st_size
        os.replace(tmp_path, path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    return raw_size


def _load_file(
    path: Path,
    map_location: Any = "cpu",
    compressed: Optional[bool] = None,
) -> Any:
    """Load a ``torch.save`` file, gunzipping it when it is compressed.

    ``compressed=None`` infers compression from a ``.gz`` suffix.
    """
    if compressed is None:
        compressed = path.suffix == ".gz"
    if compressed:
        with gzip.open(path, "rb") as gz_file:
            return torch.load(gz_file, map_location=map_location)
    return torch.load(path, map_location=map_location)


def _tensor_key(tensor: torch.Tensor) -> Any:
    """Key identifying the memory a tensor views, used to de-duplicate snapshot copies."""
    if tensor.layout == torch.strided and tensor.numel() > 0:
        # Same type/device/dtype/start address/shape/stride => same view of the
        # same memory (e.g. tied weights returned as separate state-dict tensors).
        return (
            "view",
            type(tensor),
            str(tensor.device),
            tensor.dtype,
            tensor.data_ptr(),
            tuple(tensor.shape),
            tensor.stride(),
        )
    return ("id", id(tensor))


def _snapshot(obj: Any, memo: Optional[Dict[Any, Any]] = None) -> Any:
    """Return a copy of ``obj`` in which every tensor is a detached CPU copy.

    Dicts, lists and tuples are copied recursively. Dicts are copied with
    ``copy.copy`` first, so their class and instance attributes (such as a
    state dict's ``_metadata``) are kept. Any other object is shared by
    reference. Tensors viewing the same memory map to a single copy, so
    serialization still stores them once. Meta tensors have no data to copy
    and are shared as-is.
    """
    if memo is None:
        memo = {}
    if isinstance(obj, torch.Tensor):
        if obj.device.type == "meta":
            return obj
        key = _tensor_key(obj)
        if key not in memo:
            clone = obj.detach().to("cpu", copy=True)
            if isinstance(obj, nn.Parameter):
                clone = nn.Parameter(clone, requires_grad=obj.requires_grad)
            memo[key] = clone
        return memo[key]
    if isinstance(obj, dict):
        out = copy.copy(obj)
        for dict_key, value in obj.items():
            out[dict_key] = _snapshot(value, memo)
        return out
    if isinstance(obj, list):
        return [_snapshot(item, memo) for item in obj]
    if isinstance(obj, tuple):
        items = [_snapshot(item, memo) for item in obj]
        # namedtuples take positional fields; plain tuples take an iterable.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    return obj


def save_checkpoint_async(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compress: bool = True,
    validate: bool = True,
    snapshot: bool = True,
) -> "Future[None]":
    """
    Save checkpoint asynchronously (non-blocking).

    The write runs on a module-level thread pool. By default, every tensor in
    ``checkpoint_data`` is first copied to CPU on the calling thread. The file
    then reflects the state at call time even if training keeps updating the
    model or optimizer while the write is in flight.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint (``.gz`` is appended to the
            file name when compressing)
        compress: Whether to compress checkpoint (gzip)
        validate: Whether to validate checkpoint after saving (reloads it on CPU)
        snapshot: Copy tensors before returning. Pass False only if the caller
            guarantees ``checkpoint_data`` is not mutated until the returned
            future completes.

    Returns:
        Future that completes when the save (and optional validation) is done.
        If saving failed, the Future holds the exception. Errors are also
        logged, so callers may ignore the return value.
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    target_path = _gz_path(checkpoint_path) if compress else checkpoint_path
    data = _snapshot(checkpoint_data) if snapshot else checkpoint_data

    def _save() -> None:
        try:
            _write_atomic(data, target_path, compress=compress)
            if compress:
                logger.debug(f"Saved compressed checkpoint to {checkpoint_path}.gz")
            else:
                logger.debug(f"Saved checkpoint to {checkpoint_path}")

            if validate:
                # Validate by loading; CPU avoids touching GPU memory from this thread.
                _load_file(target_path, map_location="cpu", compressed=compress)
                logger.debug(f"Validated checkpoint: {checkpoint_path}")
        except Exception as e:
            logger.exception(f"Error saving checkpoint asynchronously: {e}")
            raise  # surfaced through the returned Future

    # Submit to thread pool (non-blocking)
    return _executor.submit(_save)


def save_checkpoint_compressed(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compression_level: int = 6,
) -> Path:
    """
    Save checkpoint with compression.

    Logs the compressed size and the space saved relative to the uncompressed
    serialized payload.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint
        compression_level: Gzip compression level (0-9)

    Returns:
        Path to saved checkpoint (with .gz extension)
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    compressed_path = _gz_path(checkpoint_path)

    original_size = _write_atomic(
        checkpoint_data,
        compressed_path,
        compress=True,
        compression_level=compression_level,
    )
    compressed_size = compressed_path.stat().st_size
    compression_ratio = (1 - compressed_size / original_size) * 100 if original_size > 0 else 0

    logger.info(
        f"Saved compressed checkpoint: {compressed_path} "
        f"({compressed_size / 1024 / 1024:.2f} MB, "
        f"{compression_ratio:.1f}% compression)"
    )
    return compressed_path


def load_checkpoint_compressed(checkpoint_path: Path) -> Dict[str, Any]:
    """
    Load compressed checkpoint.

    The ``.gz`` file is tried first, then the uncompressed file. Both are
    loaded onto CPU.

    Args:
        checkpoint_path: Path to checkpoint (with or without .gz extension)

    Returns:
        Checkpoint data dict

    Raises:
        FileNotFoundError: If neither file exists.
    """
    if checkpoint_path.suffix == ".gz":
        compressed_path = checkpoint_path
    else:
        compressed_path = _gz_path(checkpoint_path)

    if compressed_path.exists():
        checkpoint = _load_file(compressed_path, map_location="cpu", compressed=True)
        logger.info(f"Loaded compressed checkpoint from {compressed_path}")
        return checkpoint

    # Fallback to uncompressed
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        logger.info(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint

    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")


def _full_checkpoint_data(
    model: nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
) -> Dict[str, Any]:
    """Build the dict stored by a full (non-incremental) checkpoint."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": True,
    }


def _same_entry(current: Any, base: Any) -> bool:
    """Return True if a state-dict entry is unchanged relative to the base entry."""
    current_is_tensor = isinstance(current, torch.Tensor)
    base_is_tensor = isinstance(base, torch.Tensor)
    if current_is_tensor and base_is_tensor:
        if current.shape != base.shape or current.dtype != base.dtype:
            return False
        # The base was loaded on CPU. Compare on CPU so a model living on
        # another device does not make torch.equal fail with a device mismatch.
        return torch.equal(current.detach().cpu(), base.cpu())
    if current_is_tensor or base_is_tensor:
        return False
    try:
        return bool(current == base)
    except (RuntimeError, TypeError, ValueError):
        # Ambiguous comparison (e.g. containers of tensors): treat as changed.
        return False


def save_incremental_checkpoint(
    model: nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    checkpoint_path: Path,
    base_checkpoint_path: Optional[Path] = None,
    save_full_every: int = 10,
) -> Path:
    """
    Save incremental checkpoint (only changed weights).

    A full checkpoint is written in either of these cases:

    - ``epoch`` is a multiple of ``save_full_every``.
    - There is no usable base: the base is ``None``, is not an existing file,
      or is the same file as ``checkpoint_path``.

    Otherwise only model entries that are missing from, or differ from, the
    base's stored ``model_state_dict`` are written. Optimizer and scheduler
    state are always stored in full. If the base is itself incremental, its
    stored state is partial, so every key it lacks is written too. The result
    can therefore be restored from its immediate base alone.

    Args:
        model: Model to save
        optimizer: Optimizer state
        scheduler: Scheduler state
        epoch: Current epoch
        loss: Current loss
        checkpoint_path: Path to save checkpoint
        base_checkpoint_path: Path to base checkpoint (for diff); may be ``.gz``
        save_full_every: Save full checkpoint every N epochs (values <= 0
            disable the periodic full save)

    Returns:
        Path to saved checkpoint
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    periodic_full = save_full_every > 0 and epoch % save_full_every == 0
    usable_base = (
        base_checkpoint_path is not None
        and base_checkpoint_path.is_file()
        # A checkpoint naming itself as its base could never be restored.
        and base_checkpoint_path.resolve() != checkpoint_path.resolve()
    )

    if periodic_full or not usable_base:
        checkpoint_data = _full_checkpoint_data(model, optimizer, scheduler, epoch, loss)
        _write_atomic(checkpoint_data, checkpoint_path, compress=False)
        logger.info(f"Saved full checkpoint to {checkpoint_path}")
        return checkpoint_path

    # Save incremental checkpoint (diff from base)
    base_checkpoint = _load_file(base_checkpoint_path, map_location="cpu")
    base_state = base_checkpoint.get("model_state_dict", {})
    current_state = model.state_dict()

    # Only save changed (or new) parameters
    diff_state = {
        key: value
        for key, value in current_state.items()
        if key not in base_state or not _same_entry(value, base_state[key])
    }

    checkpoint_data = {
        "epoch": epoch,
        "model_state_dict": diff_state,  # Only differences
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": False,
        "base_checkpoint": str(base_checkpoint_path),
    }
    _write_atomic(checkpoint_data, checkpoint_path, compress=False)
    logger.info(
        f"Saved incremental checkpoint to {checkpoint_path} "
        f"({len(diff_state)}/{len(current_state)} parameters changed)"
    )
    return checkpoint_path


def load_incremental_checkpoint(
    model: nn.Module,
    checkpoint_path: Path,
    device: str = "cpu",
) -> Dict[str, Any]:
    """
    Load incremental checkpoint (applies diff to base).

    Full checkpoints are loaded strictly. For incremental ones, the base named
    in ``base_checkpoint`` is loaded, overlaid with the stored diff, and loaded
    strictly. If no base is recorded, the base file is missing, or the
    checkpoint names itself as its base, a warning is logged and only the diff
    is loaded with ``strict=False``.

    Args:
        model: Model to load weights into
        checkpoint_path: Path to incremental checkpoint (may be ``.gz``)
        device: Device to load on

    Returns:
        Checkpoint data dict
    """
    checkpoint = _load_file(checkpoint_path, map_location=device)

    if checkpoint.get("is_full", True):
        # Full checkpoint
        model.load_state_dict(checkpoint["model_state_dict"])
        logger.info(f"Loaded full checkpoint from {checkpoint_path}")
        return checkpoint

    # Incremental checkpoint - need to load base first. An absent or empty
    # reference must not become Path(""), which means "." and always exists.
    base_ref = checkpoint.get("base_checkpoint")
    base_checkpoint_path = Path(base_ref) if base_ref else None
    if (
        base_checkpoint_path is None
        or not base_checkpoint_path.is_file()
        or base_checkpoint_path.resolve() == checkpoint_path.resolve()
    ):
        logger.warning(
            f"Base checkpoint not found or unusable: {base_checkpoint_path}. "
            "Loading incremental checkpoint as-is."
        )
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        return checkpoint

    base_checkpoint = _load_file(base_checkpoint_path, map_location=device)
    base_state = base_checkpoint.get("model_state_dict", {})

    # Apply diff. copy.copy (unlike OrderedDict.copy) also keeps instance
    # attributes such as the state dict's ``_metadata``.
    full_state = copy.copy(base_state)
    full_state.update(checkpoint["model_state_dict"])
    model.load_state_dict(full_state)

    logger.info(
        f"Loaded incremental checkpoint from {checkpoint_path} "
        f"(applied to base: {base_checkpoint_path})"
    )
    return checkpoint


def validate_checkpoint(checkpoint_path: Path) -> bool:
    """
    Validate checkpoint file integrity.

    The checkpoint is valid if it loads (``.gz`` is gunzipped), is a dict,
    contains ``epoch`` and ``model_state_dict``, and ``model_state_dict`` is a
    dict.

    Args:
        checkpoint_path: Path to checkpoint

    Returns:
        True if valid, False otherwise
    """
    try:
        checkpoint = _load_file(checkpoint_path, map_location="cpu")
    except Exception as e:
        logger.error(f"Checkpoint validation failed: {e}")
        return False

    if not isinstance(checkpoint, dict):
        logger.error(f"Checkpoint is not a dict: {type(checkpoint).__name__}")
        return False

    # Check required keys
    required_keys = ["epoch", "model_state_dict"]
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        logger.error(f"Checkpoint missing required keys: {missing_keys}")
        return False

    # Check state dict is valid
    if not isinstance(checkpoint["model_state_dict"], dict):
        logger.error("Checkpoint model_state_dict is not a dict")
        return False

    logger.info(f"Checkpoint validated: {checkpoint_path}")
    return True


def get_checkpoint_size(checkpoint_path: Path) -> Dict[str, float]:
    """
    Get checkpoint file size information.

    ``checkpoint_path`` may name either the uncompressed file or its ``.gz``
    sibling; both files are looked up.

    Args:
        checkpoint_path: Path to checkpoint (with or without .gz extension)

    Returns:
        Dict with ``compressed_bytes``/``compressed_mb`` and/or
        ``uncompressed_bytes``/``uncompressed_mb`` for whichever files exist.
        ``compression_ratio`` (percent saved) is included when both exist and
        the uncompressed file is non-empty.
    """
    if checkpoint_path.suffix == ".gz":
        compressed_path = checkpoint_path
        uncompressed_path = checkpoint_path.with_suffix("")
    else:
        compressed_path = _gz_path(checkpoint_path)
        uncompressed_path = checkpoint_path

    sizes: Dict[str, float] = {}

    # Check compressed version
    if compressed_path.exists():
        sizes["compressed_bytes"] = compressed_path.stat().st_size
        sizes["compressed_mb"] = sizes["compressed_bytes"] / 1024 / 1024

    # Check uncompressed version
    if uncompressed_path.exists():
        sizes["uncompressed_bytes"] = uncompressed_path.stat().st_size
        sizes["uncompressed_mb"] = sizes["uncompressed_bytes"] / 1024 / 1024

        # Guard against an empty uncompressed file (division by zero).
        if "compressed_bytes" in sizes and sizes["uncompressed_bytes"] > 0:
            sizes["compression_ratio"] = (
                1 - sizes["compressed_bytes"] / sizes["uncompressed_bytes"]
            ) * 100

    return sizes