""" Model Loading and Management Utilities This module provides utilities for loading trained EOSDIS GNN models, managing checkpoints, and handling model metadata. Designed for use in HuggingFace model repositories and production inference pipelines. Key features: - Model loading from various checkpoint formats - Automatic configuration detection - Model validation and compatibility checks - Device management and optimization - Metadata handling """ import torch import torch.nn as nn from torch_geometric.data import HeteroData import os import json import logging import warnings from pathlib import Path from typing import Dict, Any, Optional, Union, Tuple, List import pickle from datetime import datetime from model_architecture import EOSDIS_HeteroGNN, ModelValidationError logger = logging.getLogger(__name__) class ModelLoadError(Exception): """Custom exception for model loading errors.""" pass class CheckpointManager: """ Manages model checkpoints and metadata. Handles various checkpoint formats, validation, and provides utilities for model versioning and compatibility checks. """ SUPPORTED_FORMATS = ['.pt', '.pth', '.pkl', '.pickle'] CONFIG_FILENAME = 'model_config.json' METADATA_FILENAME = 'model_metadata.json' def __init__(self, checkpoint_dir: Optional[str] = None): """ Initialize checkpoint manager. Args: checkpoint_dir: Directory containing model checkpoints """ self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None def save_checkpoint(self, model: EOSDIS_HeteroGNN, filepath: Union[str, Path], metadata: Optional[Dict[str, Any]] = None, save_config: bool = True) -> Dict[str, str]: """ Save model checkpoint with configuration and metadata. Args: model: Model instance to save filepath: Path to save the checkpoint metadata: Additional metadata to save save_config: Whether to save model configuration separately Returns: Dictionary with paths to saved files """ filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) # Prepare checkpoint data checkpoint_data = { 'model_state_dict': model.state_dict(), 'model_config': model.get_config(), 'save_timestamp': datetime.now().isoformat(), 'pytorch_version': torch.__version__ } if metadata: checkpoint_data['metadata'] = metadata # Save main checkpoint torch.save(checkpoint_data, filepath) saved_files = {'checkpoint': str(filepath)} # Save configuration separately for easy access if save_config: config_path = filepath.parent / self.CONFIG_FILENAME with open(config_path, 'w') as f: json.dump(model.get_config(), f, indent=2) saved_files['config'] = str(config_path) # Save metadata separately if metadata: metadata_path = filepath.parent / self.METADATA_FILENAME with open(metadata_path, 'w') as f: json.dump(metadata, f, indent=2) saved_files['metadata'] = str(metadata_path) logger.info(f"Saved model checkpoint to {filepath}") return saved_files def load_checkpoint(self, filepath: Union[str, Path], map_location: Optional[str] = None) -> Dict[str, Any]: """ Load checkpoint data from file. Args: filepath: Path to checkpoint file map_location: Device to map the checkpoint to Returns: Dictionary containing checkpoint data Raises: ModelLoadError: If checkpoint cannot be loaded """ filepath = Path(filepath) if not filepath.exists(): raise ModelLoadError(f"Checkpoint file not found: {filepath}") if filepath.suffix not in self.SUPPORTED_FORMATS: raise ModelLoadError( f"Unsupported checkpoint format: {filepath.suffix}. " f"Supported formats: {self.SUPPORTED_FORMATS}" ) try: if map_location: checkpoint_data = torch.load(filepath, map_location=map_location) else: checkpoint_data = torch.load(filepath) logger.info(f"Loaded checkpoint from {filepath}") return checkpoint_data except Exception as e: raise ModelLoadError(f"Failed to load checkpoint from {filepath}: {str(e)}") def validate_checkpoint(self, checkpoint_data: Dict[str, Any]) -> bool: """ Validate checkpoint data structure and contents. Args: checkpoint_data: Checkpoint data dictionary Returns: True if checkpoint is valid Raises: ModelLoadError: If checkpoint is invalid """ required_keys = ['model_state_dict', 'model_config'] missing_keys = [key for key in required_keys if key not in checkpoint_data] if missing_keys: raise ModelLoadError(f"Checkpoint missing required keys: {missing_keys}") # Validate model config try: config = checkpoint_data['model_config'] if not isinstance(config, dict): raise ModelLoadError("model_config must be a dictionary") # Check required config fields required_config_keys = ['metadata', 'hidden_channels', 'num_layers'] missing_config_keys = [key for key in required_config_keys if key not in config] if missing_config_keys: raise ModelLoadError(f"Model config missing required keys: {missing_config_keys}") except Exception as e: raise ModelLoadError(f"Invalid model configuration: {str(e)}") return True class ModelLoader: """ High-level interface for loading EOSDIS GNN models. Provides methods for loading models from various sources including local files, HuggingFace repositories, and custom checkpoint formats. """ def __init__(self, device: Optional[str] = None): """ Initialize model loader. Args: device: Target device for loaded models ('cpu', 'cuda', etc.) """ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') self.checkpoint_manager = CheckpointManager() logger.info(f"ModelLoader initialized with device: {self.device}") def load_from_checkpoint(self, checkpoint_path: Union[str, Path], strict: bool = True, eval_mode: bool = True) -> EOSDIS_HeteroGNN: """ Load model from a checkpoint file. Args: checkpoint_path: Path to checkpoint file strict: Whether to strictly enforce state dict matching eval_mode: Whether to set model to evaluation mode Returns: Loaded model instance Raises: ModelLoadError: If model cannot be loaded """ checkpoint_path = Path(checkpoint_path) # Load checkpoint data checkpoint_data = self.checkpoint_manager.load_checkpoint( checkpoint_path, map_location=self.device ) # Validate checkpoint self.checkpoint_manager.validate_checkpoint(checkpoint_data) # Extract configuration and create model config = checkpoint_data['model_config'] try: model = EOSDIS_HeteroGNN.from_config(config) model.to(self.device) # Load state dict model.load_state_dict(checkpoint_data['model_state_dict'], strict=strict) if eval_mode: model.eval() # Log model info self._log_model_info(model, checkpoint_data) return model except Exception as e: raise ModelLoadError(f"Failed to create model from checkpoint: {str(e)}") def load_from_directory(self, model_dir: Union[str, Path], checkpoint_name: str = 'model.pt', strict: bool = True, eval_mode: bool = True) -> EOSDIS_HeteroGNN: """ Load model from a directory containing checkpoints and configuration. Args: model_dir: Directory containing model files checkpoint_name: Name of the checkpoint file strict: Whether to strictly enforce state dict matching eval_mode: Whether to set model to evaluation mode Returns: Loaded model instance """ model_dir = Path(model_dir) if not model_dir.exists() or not model_dir.is_dir(): raise ModelLoadError(f"Model directory not found: {model_dir}") # Look for checkpoint file checkpoint_path = model_dir / checkpoint_name if not checkpoint_path.exists(): # Try to find any checkpoint file checkpoint_files = [] for ext in CheckpointManager.SUPPORTED_FORMATS: checkpoint_files.extend(model_dir.glob(f"*{ext}")) if not checkpoint_files: raise ModelLoadError(f"No checkpoint files found in {model_dir}") checkpoint_path = checkpoint_files[0] logger.info(f"Using checkpoint file: {checkpoint_path}") return self.load_from_checkpoint(checkpoint_path, strict=strict, eval_mode=eval_mode) def load_from_config(self, config_path: Union[str, Path], checkpoint_path: Optional[Union[str, Path]] = None, strict: bool = True, eval_mode: bool = True) -> EOSDIS_HeteroGNN: """ Load model from separate configuration and checkpoint files. Args: config_path: Path to configuration JSON file checkpoint_path: Path to checkpoint file (optional) strict: Whether to strictly enforce state dict matching eval_mode: Whether to set model to evaluation mode Returns: Loaded model instance """ config_path = Path(config_path) if not config_path.exists(): raise ModelLoadError(f"Configuration file not found: {config_path}") # Load configuration try: model = EOSDIS_HeteroGNN.load_config(config_path) model.to(self.device) # Load checkpoint if provided if checkpoint_path: checkpoint_path = Path(checkpoint_path) if checkpoint_path.exists(): checkpoint_data = torch.load(checkpoint_path, map_location=self.device) if 'model_state_dict' in checkpoint_data: model.load_state_dict(checkpoint_data['model_state_dict'], strict=strict) else: # Assume the file contains just the state dict model.load_state_dict(checkpoint_data, strict=strict) else: logger.warning(f"Checkpoint file not found: {checkpoint_path}") if eval_mode: model.eval() return model except Exception as e: raise ModelLoadError(f"Failed to load model from config: {str(e)}") def _log_model_info(self, model: EOSDIS_HeteroGNN, checkpoint_data: Dict[str, Any]): """Log information about the loaded model.""" total_params = sum(p.numel() for p in model.parameters()) logger.info(f"Model loaded successfully:") logger.info(f" - Parameters: {total_params:,}") logger.info(f" - Device: {self.device}") logger.info(f" - Node types: {len(model.node_types)}") logger.info(f" - Edge types: {len(model.edge_types)}") if 'save_timestamp' in checkpoint_data: logger.info(f" - Saved: {checkpoint_data['save_timestamp']}") if 'metadata' in checkpoint_data: metadata = checkpoint_data['metadata'] if 'training_epochs' in metadata: logger.info(f" - Training epochs: {metadata['training_epochs']}") if 'best_val_loss' in metadata: logger.info(f" - Best validation loss: {metadata['best_val_loss']:.4f}") def load_model(checkpoint_path: Union[str, Path], device: Optional[str] = None, eval_mode: bool = True) -> EOSDIS_HeteroGNN: """ Convenience function to load a model from a checkpoint. Args: checkpoint_path: Path to checkpoint file or directory device: Target device ('cpu', 'cuda', etc.) eval_mode: Whether to set model to evaluation mode Returns: Loaded model instance """ loader = ModelLoader(device=device) checkpoint_path = Path(checkpoint_path) if checkpoint_path.is_dir(): return loader.load_from_directory(checkpoint_path, eval_mode=eval_mode) else: return loader.load_from_checkpoint(checkpoint_path, eval_mode=eval_mode) def save_model(model: EOSDIS_HeteroGNN, save_path: Union[str, Path], metadata: Optional[Dict[str, Any]] = None) -> Dict[str, str]: """ Convenience function to save a model with configuration and metadata. Args: model: Model instance to save save_path: Path to save the model metadata: Additional metadata to save Returns: Dictionary with paths to saved files """ manager = CheckpointManager() return manager.save_checkpoint(model, save_path, metadata=metadata) def validate_model_compatibility(model: EOSDIS_HeteroGNN, data: HeteroData) -> bool: """ Validate that a model is compatible with given graph data. Args: model: Model instance to validate data: Graph data to check compatibility with Returns: True if compatible Raises: ModelValidationError: If model is not compatible """ # Check node types data_node_types = set(data.node_types) model_node_types = set(model.node_types) if not data_node_types.issubset(model_node_types): missing_types = data_node_types - model_node_types raise ModelValidationError(f"Model missing node types: {missing_types}") # Check edge types data_edge_types = set(data.edge_types) model_edge_types = set(model.edge_types) if not data_edge_types.issubset(model_edge_types): missing_types = data_edge_types - model_edge_types raise ModelValidationError(f"Model missing edge types: {missing_types}") # Check input dimensions for node_type in data_node_types: if node_type in data.x_dict: data_dim = data.x_dict[node_type].size(-1) if data_dim != model.input_dim: raise ModelValidationError( f"Input dimension mismatch for node type '{node_type}': " f"data has {data_dim}, model expects {model.input_dim}" ) return True def get_model_summary(model: EOSDIS_HeteroGNN) -> Dict[str, Any]: """ Get a comprehensive summary of model architecture and parameters. Args: model: Model instance to summarize Returns: Dictionary containing model summary """ from model_architecture import get_model_info summary = get_model_info(model) # Add layer-specific information layer_info = [] for i, layer in enumerate(model.conv_layers): layer_info.append({ 'layer_index': i, 'layer_type': layer.__class__.__name__, 'conv_type': layer.conv_type, 'hidden_channels': layer.hidden_channels, 'dropout': layer.dropout_prob }) summary['layer_details'] = layer_info # Add memory usage estimate (rough) total_params = summary['total_parameters'] memory_mb = (total_params * 4) / (1024 * 1024) # Assuming float32 summary['estimated_memory_mb'] = round(memory_mb, 2) return summary def optimize_model_for_inference(model: EOSDIS_HeteroGNN) -> EOSDIS_HeteroGNN: """ Optimize model for inference by applying various optimizations. Args: model: Model to optimize Returns: Optimized model instance """ # Set to evaluation mode model.eval() # Disable gradient computation for param in model.parameters(): param.requires_grad_(False) # Try to apply torch.jit.script if compatible try: # Note: This might not work with all model configurations # Keeping it optional logger.info("Attempting to apply TorchScript optimization...") # model = torch.jit.script(model) # Commented out as it may not work with HeteroConv except Exception as e: logger.warning(f"Could not apply TorchScript optimization: {e}") # Apply other optimizations if hasattr(torch.backends, 'cudnn'): torch.backends.cudnn.benchmark = True logger.info("Model optimized for inference") return model