""" Checkpoint Manager for Mamba Swarm Handles saving, loading, and managing model checkpoints """ import os import json import time import shutil import logging import torch import threading from typing import Dict, List, Any, Optional, Tuple from dataclasses import dataclass, asdict from pathlib import Path from datetime import datetime import pickle import hashlib @dataclass class CheckpointMetadata: checkpoint_id: str timestamp: float epoch: int step: int loss: float model_config: Dict[str, Any] training_config: Dict[str, Any] metrics: Dict[str, float] file_path: str file_size: int checksum: str class CheckpointManager: """Manages model checkpoints for Mamba Swarm""" def __init__(self, checkpoint_dir: str = "./checkpoints", max_checkpoints: int = 10, save_interval: int = 1000, best_metric: str = "loss", best_metric_mode: str = "min"): self.checkpoint_dir = Path(checkpoint_dir) self.max_checkpoints = max_checkpoints self.save_interval = save_interval self.best_metric = best_metric self.best_metric_mode = best_metric_mode self.logger = logging.getLogger(__name__) self.lock = threading.Lock() # Create checkpoint directory self.checkpoint_dir.mkdir(parents=True, exist_ok=True) # Metadata storage self.metadata_file = self.checkpoint_dir / "metadata.json" self.checkpoints_metadata: Dict[str, CheckpointMetadata] = {} # Best checkpoint tracking self.best_checkpoint_id: Optional[str] = None self.best_metric_value: Optional[float] = None # Load existing metadata self._load_metadata() def save_checkpoint(self, model_state: Dict[str, Any], optimizer_state: Optional[Dict[str, Any]] = None, scheduler_state: Optional[Dict[str, Any]] = None, epoch: int = 0, step: int = 0, loss: float = 0.0, metrics: Optional[Dict[str, float]] = None, model_config: Optional[Dict[str, Any]] = None, training_config: Optional[Dict[str, Any]] = None, force_save: bool = False) -> str: """Save a checkpoint""" # Check if we should save based on interval if not force_save and step % self.save_interval != 0: return None # Generate checkpoint ID checkpoint_id = self._generate_checkpoint_id(epoch, step) # Prepare checkpoint data checkpoint_data = { "model_state": model_state, "optimizer_state": optimizer_state, "scheduler_state": scheduler_state, "epoch": epoch, "step": step, "loss": loss, "metrics": metrics or {}, "model_config": model_config or {}, "training_config": training_config or {}, "timestamp": time.time() } # Save checkpoint file checkpoint_path = self.checkpoint_dir / f"{checkpoint_id}.pt" with self.lock: try: torch.save(checkpoint_data, checkpoint_path) # Calculate file size and checksum file_size = checkpoint_path.stat().st_size checksum = self._calculate_checksum(checkpoint_path) # Create metadata metadata = CheckpointMetadata( checkpoint_id=checkpoint_id, timestamp=checkpoint_data["timestamp"], epoch=epoch, step=step, loss=loss, model_config=model_config or {}, training_config=training_config or {}, metrics=metrics or {}, file_path=str(checkpoint_path), file_size=file_size, checksum=checksum ) # Store metadata self.checkpoints_metadata[checkpoint_id] = metadata # Update best checkpoint self._update_best_checkpoint(checkpoint_id, metrics or {"loss": loss}) # Clean up old checkpoints self._cleanup_old_checkpoints() # Save metadata self._save_metadata() self.logger.info(f"Saved checkpoint {checkpoint_id} at step {step}") return checkpoint_id except Exception as e: self.logger.error(f"Failed to save checkpoint: {e}") # Clean up partial file if checkpoint_path.exists(): checkpoint_path.unlink() raise def load_checkpoint(self, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]: """Load a checkpoint""" # Use best checkpoint if none specified if checkpoint_id is None: checkpoint_id = self.best_checkpoint_id if checkpoint_id is None or checkpoint_id not in self.checkpoints_metadata: self.logger.warning(f"Checkpoint {checkpoint_id} not found") return None metadata = self.checkpoints_metadata[checkpoint_id] checkpoint_path = Path(metadata.file_path) if not checkpoint_path.exists(): self.logger.error(f"Checkpoint file {checkpoint_path} does not exist") return None try: # Verify checksum if not self._verify_checksum(checkpoint_path, metadata.checksum): self.logger.error(f"Checkpoint {checkpoint_id} failed checksum verification") return None # Load checkpoint checkpoint_data = torch.load(checkpoint_path, map_location='cpu') self.logger.info(f"Loaded checkpoint {checkpoint_id} from step {metadata.step}") return checkpoint_data except Exception as e: self.logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}") return None def load_best_checkpoint(self) -> Optional[Dict[str, Any]]: """Load the best checkpoint""" return self.load_checkpoint(self.best_checkpoint_id) def load_latest_checkpoint(self) -> Optional[Dict[str, Any]]: """Load the most recent checkpoint""" if not self.checkpoints_metadata: return None # Find latest checkpoint by timestamp latest_id = max(self.checkpoints_metadata.keys(), key=lambda x: self.checkpoints_metadata[x].timestamp) return self.load_checkpoint(latest_id) def list_checkpoints(self, sort_by: str = "timestamp") -> List[CheckpointMetadata]: """List all available checkpoints""" checkpoints = list(self.checkpoints_metadata.values()) if sort_by == "timestamp": checkpoints.sort(key=lambda x: x.timestamp, reverse=True) elif sort_by == "step": checkpoints.sort(key=lambda x: x.step, reverse=True) elif sort_by == "loss": checkpoints.sort(key=lambda x: x.loss) return checkpoints def delete_checkpoint(self, checkpoint_id: str) -> bool: """Delete a specific checkpoint""" if checkpoint_id not in self.checkpoints_metadata: self.logger.warning(f"Checkpoint {checkpoint_id} not found") return False metadata = self.checkpoints_metadata[checkpoint_id] checkpoint_path = Path(metadata.file_path) with self.lock: try: # Remove file if checkpoint_path.exists(): checkpoint_path.unlink() # Remove from metadata del self.checkpoints_metadata[checkpoint_id] # Update best checkpoint if needed if checkpoint_id == self.best_checkpoint_id: self._find_new_best_checkpoint() # Save metadata self._save_metadata() self.logger.info(f"Deleted checkpoint {checkpoint_id}") return True except Exception as e: self.logger.error(f"Failed to delete checkpoint {checkpoint_id}: {e}") return False def get_checkpoint_info(self, checkpoint_id: str) -> Optional[CheckpointMetadata]: """Get information about a specific checkpoint""" return self.checkpoints_metadata.get(checkpoint_id) def export_checkpoint(self, checkpoint_id: str, export_path: str) -> bool: """Export a checkpoint to a different location""" if checkpoint_id not in self.checkpoints_metadata: self.logger.error(f"Checkpoint {checkpoint_id} not found") return False metadata = self.checkpoints_metadata[checkpoint_id] source_path = Path(metadata.file_path) export_path = Path(export_path) try: # Copy checkpoint file shutil.copy2(source_path, export_path) # Copy metadata metadata_export_path = export_path.with_suffix('.json') with open(metadata_export_path, 'w') as f: json.dump(asdict(metadata), f, indent=2) self.logger.info(f"Exported checkpoint {checkpoint_id} to {export_path}") return True except Exception as e: self.logger.error(f"Failed to export checkpoint {checkpoint_id}: {e}") return False def import_checkpoint(self, checkpoint_path: str, metadata_path: Optional[str] = None) -> Optional[str]: """Import a checkpoint from external location""" checkpoint_path = Path(checkpoint_path) if not checkpoint_path.exists(): self.logger.error(f"Checkpoint file {checkpoint_path} does not exist") return None try: # Load metadata if provided if metadata_path: with open(metadata_path, 'r') as f: metadata_dict = json.load(f) metadata = CheckpointMetadata(**metadata_dict) else: # Try to extract metadata from checkpoint checkpoint_data = torch.load(checkpoint_path, map_location='cpu') metadata = CheckpointMetadata( checkpoint_id=self._generate_checkpoint_id( checkpoint_data.get("epoch", 0), checkpoint_data.get("step", 0) ), timestamp=checkpoint_data.get("timestamp", time.time()), epoch=checkpoint_data.get("epoch", 0), step=checkpoint_data.get("step", 0), loss=checkpoint_data.get("loss", 0.0), model_config=checkpoint_data.get("model_config", {}), training_config=checkpoint_data.get("training_config", {}), metrics=checkpoint_data.get("metrics", {}), file_path="", # Will be set below file_size=0, # Will be set below checksum="" # Will be set below ) # Copy to checkpoint directory new_checkpoint_path = self.checkpoint_dir / f"{metadata.checkpoint_id}.pt" shutil.copy2(checkpoint_path, new_checkpoint_path) # Update metadata metadata.file_path = str(new_checkpoint_path) metadata.file_size = new_checkpoint_path.stat().st_size metadata.checksum = self._calculate_checksum(new_checkpoint_path) with self.lock: self.checkpoints_metadata[metadata.checkpoint_id] = metadata self._update_best_checkpoint(metadata.checkpoint_id, metadata.metrics) self._save_metadata() self.logger.info(f"Imported checkpoint {metadata.checkpoint_id}") return metadata.checkpoint_id except Exception as e: self.logger.error(f"Failed to import checkpoint: {e}") return None def _generate_checkpoint_id(self, epoch: int, step: int) -> str: """Generate unique checkpoint ID""" timestamp = int(time.time()) return f"checkpoint_epoch_{epoch}_step_{step}_{timestamp}" def _calculate_checksum(self, file_path: Path) -> str: """Calculate MD5 checksum of file""" hash_md5 = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() def _verify_checksum(self, file_path: Path, expected_checksum: str) -> bool: """Verify file checksum""" actual_checksum = self._calculate_checksum(file_path) return actual_checksum == expected_checksum def _update_best_checkpoint(self, checkpoint_id: str, metrics: Dict[str, float]): """Update best checkpoint based on metrics""" if self.best_metric not in metrics: return metric_value = metrics[self.best_metric] if self.best_metric_value is None: # First checkpoint self.best_checkpoint_id = checkpoint_id self.best_metric_value = metric_value else: # Compare with current best is_better = False if self.best_metric_mode == "min": is_better = metric_value < self.best_metric_value elif self.best_metric_mode == "max": is_better = metric_value > self.best_metric_value if is_better: self.best_checkpoint_id = checkpoint_id self.best_metric_value = metric_value self.logger.info(f"New best checkpoint: {checkpoint_id} ({self.best_metric}: {metric_value})") def _find_new_best_checkpoint(self): """Find new best checkpoint after deletion""" if not self.checkpoints_metadata: self.best_checkpoint_id = None self.best_metric_value = None return best_id = None best_value = None for checkpoint_id, metadata in self.checkpoints_metadata.items(): if self.best_metric in metadata.metrics: metric_value = metadata.metrics[self.best_metric] if best_value is None: best_id = checkpoint_id best_value = metric_value else: is_better = False if self.best_metric_mode == "min": is_better = metric_value < best_value elif self.best_metric_mode == "max": is_better = metric_value > best_value if is_better: best_id = checkpoint_id best_value = metric_value self.best_checkpoint_id = best_id self.best_metric_value = best_value def _cleanup_old_checkpoints(self): """Remove old checkpoints to maintain max_checkpoints limit""" if len(self.checkpoints_metadata) <= self.max_checkpoints: return # Sort by timestamp (oldest first) sorted_checkpoints = sorted( self.checkpoints_metadata.items(), key=lambda x: x[1].timestamp ) # Calculate how many to remove num_to_remove = len(sorted_checkpoints) - self.max_checkpoints for i in range(num_to_remove): checkpoint_id, metadata = sorted_checkpoints[i] # Don't delete the best checkpoint if checkpoint_id == self.best_checkpoint_id: continue # Delete checkpoint checkpoint_path = Path(metadata.file_path) if checkpoint_path.exists(): checkpoint_path.unlink() del self.checkpoints_metadata[checkpoint_id] self.logger.info(f"Cleaned up old checkpoint: {checkpoint_id}") def _load_metadata(self): """Load checkpoint metadata from file""" if not self.metadata_file.exists(): return try: with open(self.metadata_file, 'r') as f: data = json.load(f) # Load checkpoint metadata for checkpoint_id, metadata_dict in data.get("checkpoints", {}).items(): metadata = CheckpointMetadata(**metadata_dict) self.checkpoints_metadata[checkpoint_id] = metadata # Load best checkpoint info self.best_checkpoint_id = data.get("best_checkpoint_id") self.best_metric_value = data.get("best_metric_value") self.logger.info(f"Loaded metadata for {len(self.checkpoints_metadata)} checkpoints") except Exception as e: self.logger.error(f"Failed to load metadata: {e}") def _save_metadata(self): """Save checkpoint metadata to file""" try: data = { "checkpoints": { checkpoint_id: asdict(metadata) for checkpoint_id, metadata in self.checkpoints_metadata.items() }, "best_checkpoint_id": self.best_checkpoint_id, "best_metric_value": self.best_metric_value, "last_updated": time.time() } # Write to temporary file first temp_file = self.metadata_file.with_suffix('.tmp') with open(temp_file, 'w') as f: json.dump(data, f, indent=2) # Atomic rename temp_file.replace(self.metadata_file) except Exception as e: self.logger.error(f"Failed to save metadata: {e}") def get_storage_usage(self) -> Dict[str, Any]: """Get storage usage statistics""" total_size = 0 checkpoint_count = len(self.checkpoints_metadata) for metadata in self.checkpoints_metadata.values(): total_size += metadata.file_size return { "total_size_bytes": total_size, "total_size_mb": total_size / (1024 * 1024), "total_size_gb": total_size / (1024 * 1024 * 1024), "checkpoint_count": checkpoint_count, "average_size_mb": (total_size / checkpoint_count / (1024 * 1024)) if checkpoint_count > 0 else 0, "checkpoint_directory": str(self.checkpoint_dir) } def cleanup_all_checkpoints(self): """Remove all checkpoints (dangerous operation)""" with self.lock: for metadata in self.checkpoints_metadata.values(): checkpoint_path = Path(metadata.file_path) if checkpoint_path.exists(): checkpoint_path.unlink() self.checkpoints_metadata.clear() self.best_checkpoint_id = None self.best_metric_value = None # Remove metadata file if self.metadata_file.exists(): self.metadata_file.unlink() self.logger.info("Cleaned up all checkpoints") # Example usage and testing if __name__ == "__main__": # Create checkpoint manager checkpoint_manager = CheckpointManager( checkpoint_dir="./test_checkpoints", max_checkpoints=5, save_interval=100 ) # Simulate saving checkpoints for step in range(0, 1000, 100): model_state = {"layer_weights": torch.randn(10, 10)} optimizer_state = {"param_groups": [{"lr": 0.001}]} metrics = { "loss": 1.0 - step / 1000.0, # Decreasing loss "accuracy": step / 1000.0 # Increasing accuracy } checkpoint_id = checkpoint_manager.save_checkpoint( model_state=model_state, optimizer_state=optimizer_state, step=step, loss=metrics["loss"], metrics=metrics, force_save=True ) print(f"Saved checkpoint: {checkpoint_id}") # List checkpoints print("\nAvailable checkpoints:") for metadata in checkpoint_manager.list_checkpoints(): print(f" {metadata.checkpoint_id}: step {metadata.step}, loss {metadata.loss:.3f}") # Load best checkpoint best_checkpoint = checkpoint_manager.load_best_checkpoint() print(f"\nLoaded best checkpoint: {checkpoint_manager.best_checkpoint_id}") # Get storage usage usage = checkpoint_manager.get_storage_usage() print(f"\nStorage usage: {usage['total_size_mb']:.2f} MB ({usage['checkpoint_count']} checkpoints)") # Cleanup checkpoint_manager.cleanup_all_checkpoints() print("Cleaned up test checkpoints")