| | """ |
| | Model Loading and Management Utilities |
| | |
| | This module provides utilities for loading trained EOSDIS GNN models, |
| | managing checkpoints, and handling model metadata. Designed for use |
| | in HuggingFace model repositories and production inference pipelines. |
| | |
| | Key features: |
| | - Model loading from various checkpoint formats |
| | - Automatic configuration detection |
| | - Model validation and compatibility checks |
| | - Device management and optimization |
| | - Metadata handling |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch_geometric.data import HeteroData |
| | import os |
| | import json |
| | import logging |
| | import warnings |
| | from pathlib import Path |
| | from typing import Dict, Any, Optional, Union, Tuple, List |
| | import pickle |
| | from datetime import datetime |
| |
|
| | from model_architecture import EOSDIS_HeteroGNN, ModelValidationError |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class ModelLoadError(Exception):
    """Raised when a model checkpoint cannot be located, parsed, or applied."""
| |
|
| |
|
class CheckpointManager:
    """
    Manages model checkpoints and metadata.

    Handles various checkpoint formats, validation, and provides
    utilities for model versioning and compatibility checks.
    """

    # All formats are loaded via torch.load, which handles both torch-native
    # and plain-pickle payloads.
    SUPPORTED_FORMATS = ['.pt', '.pth', '.pkl', '.pickle']
    CONFIG_FILENAME = 'model_config.json'
    METADATA_FILENAME = 'model_metadata.json'

    def __init__(self, checkpoint_dir: Optional[str] = None):
        """
        Initialize checkpoint manager.

        Args:
            checkpoint_dir: Directory containing model checkpoints
        """
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

    def save_checkpoint(self,
                        model: EOSDIS_HeteroGNN,
                        filepath: Union[str, Path],
                        metadata: Optional[Dict[str, Any]] = None,
                        save_config: bool = True) -> Dict[str, str]:
        """
        Save model checkpoint with configuration and metadata.

        Args:
            model: Model instance to save
            filepath: Path to save the checkpoint
            metadata: Additional metadata to save
            save_config: Whether to save model configuration separately

        Returns:
            Dictionary with paths to saved files
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Bundle weights with config and provenance so the checkpoint is
        # self-describing when reloaded later.
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'model_config': model.get_config(),
            'save_timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__
        }

        if metadata:
            checkpoint_data['metadata'] = metadata

        torch.save(checkpoint_data, filepath)
        saved_files = {'checkpoint': str(filepath)}

        # Optionally mirror the config as JSON next to the checkpoint so it
        # can be inspected without deserializing the checkpoint itself.
        if save_config:
            config_path = filepath.parent / self.CONFIG_FILENAME
            with open(config_path, 'w') as f:
                json.dump(model.get_config(), f, indent=2)
            saved_files['config'] = str(config_path)

        if metadata:
            metadata_path = filepath.parent / self.METADATA_FILENAME
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            saved_files['metadata'] = str(metadata_path)

        logger.info("Saved model checkpoint to %s", filepath)
        return saved_files

    def load_checkpoint(self,
                        filepath: Union[str, Path],
                        map_location: Optional[str] = None) -> Dict[str, Any]:
        """
        Load checkpoint data from file.

        Args:
            filepath: Path to checkpoint file
            map_location: Device to map the checkpoint to

        Returns:
            Dictionary containing checkpoint data

        Raises:
            ModelLoadError: If checkpoint cannot be loaded
        """
        filepath = Path(filepath)

        if not filepath.exists():
            raise ModelLoadError(f"Checkpoint file not found: {filepath}")

        if filepath.suffix not in self.SUPPORTED_FORMATS:
            raise ModelLoadError(
                f"Unsupported checkpoint format: {filepath.suffix}. "
                f"Supported formats: {self.SUPPORTED_FORMATS}"
            )

        try:
            # map_location=None lets torch.load keep tensors on their saved
            # device, so a single call covers both branches.
            # SECURITY: torch.load deserializes via pickle — only load
            # checkpoints from trusted sources.
            checkpoint_data = torch.load(filepath, map_location=map_location)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ModelLoadError(f"Failed to load checkpoint from {filepath}: {str(e)}") from e

        logger.info("Loaded checkpoint from %s", filepath)
        return checkpoint_data

    def validate_checkpoint(self, checkpoint_data: Dict[str, Any]) -> bool:
        """
        Validate checkpoint data structure and contents.

        Args:
            checkpoint_data: Checkpoint data dictionary

        Returns:
            True if checkpoint is valid

        Raises:
            ModelLoadError: If checkpoint is invalid
        """
        required_keys = ['model_state_dict', 'model_config']
        missing_keys = [key for key in required_keys if key not in checkpoint_data]

        if missing_keys:
            raise ModelLoadError(f"Checkpoint missing required keys: {missing_keys}")

        # NOTE(review): the old try/except here caught its *own* ModelLoadError
        # raises and re-wrapped them as "Invalid model configuration: ...",
        # obscuring the specific message. Validation errors now propagate
        # unaltered.
        config = checkpoint_data['model_config']
        if not isinstance(config, dict):
            raise ModelLoadError("model_config must be a dictionary")

        required_config_keys = ['metadata', 'hidden_channels', 'num_layers']
        missing_config_keys = [key for key in required_config_keys if key not in config]
        if missing_config_keys:
            raise ModelLoadError(f"Model config missing required keys: {missing_config_keys}")

        return True
| |
|
| |
|
class ModelLoader:
    """
    High-level interface for loading EOSDIS GNN models.

    Provides methods for loading models from various sources including
    local files, HuggingFace repositories, and custom checkpoint formats.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize model loader.

        Args:
            device: Target device for loaded models ('cpu', 'cuda', etc.).
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_manager = CheckpointManager()

        logger.info("ModelLoader initialized with device: %s", self.device)

    def load_from_checkpoint(self,
                             checkpoint_path: Union[str, Path],
                             strict: bool = True,
                             eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from a checkpoint file.

        Args:
            checkpoint_path: Path to checkpoint file
            strict: Whether to strictly enforce state dict matching
            eval_mode: Whether to set model to evaluation mode

        Returns:
            Loaded model instance

        Raises:
            ModelLoadError: If model cannot be loaded
        """
        checkpoint_path = Path(checkpoint_path)

        # Both calls raise ModelLoadError with a specific message on failure;
        # let those propagate unwrapped.
        checkpoint_data = self.checkpoint_manager.load_checkpoint(
            checkpoint_path, map_location=self.device
        )
        self.checkpoint_manager.validate_checkpoint(checkpoint_data)

        config = checkpoint_data['model_config']

        try:
            model = EOSDIS_HeteroGNN.from_config(config)
            model.to(self.device)
            model.load_state_dict(checkpoint_data['model_state_dict'], strict=strict)
        except Exception as e:
            # Chain so the underlying architecture/state-dict error is kept.
            raise ModelLoadError(f"Failed to create model from checkpoint: {str(e)}") from e

        if eval_mode:
            model.eval()

        self._log_model_info(model, checkpoint_data)
        return model

    def load_from_directory(self,
                            model_dir: Union[str, Path],
                            checkpoint_name: str = 'model.pt',
                            strict: bool = True,
                            eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from a directory containing checkpoints and configuration.

        Args:
            model_dir: Directory containing model files
            checkpoint_name: Name of the checkpoint file
            strict: Whether to strictly enforce state dict matching
            eval_mode: Whether to set model to evaluation mode

        Returns:
            Loaded model instance

        Raises:
            ModelLoadError: If the directory or any checkpoint file is missing
        """
        model_dir = Path(model_dir)

        if not model_dir.exists() or not model_dir.is_dir():
            raise ModelLoadError(f"Model directory not found: {model_dir}")

        checkpoint_path = model_dir / checkpoint_name

        if not checkpoint_path.exists():
            # Fall back to any checkpoint file with a supported extension.
            checkpoint_files = []
            for ext in CheckpointManager.SUPPORTED_FORMATS:
                checkpoint_files.extend(model_dir.glob(f"*{ext}"))

            if not checkpoint_files:
                raise ModelLoadError(f"No checkpoint files found in {model_dir}")

            # Sort so the fallback choice is deterministic — glob order is
            # filesystem-dependent.
            checkpoint_path = sorted(checkpoint_files)[0]
            logger.info("Using checkpoint file: %s", checkpoint_path)

        return self.load_from_checkpoint(checkpoint_path, strict=strict, eval_mode=eval_mode)

    def load_from_config(self,
                         config_path: Union[str, Path],
                         checkpoint_path: Optional[Union[str, Path]] = None,
                         strict: bool = True,
                         eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from separate configuration and checkpoint files.

        Args:
            config_path: Path to configuration JSON file
            checkpoint_path: Path to checkpoint file (optional; when omitted
                or missing, the model keeps its freshly initialized weights)
            strict: Whether to strictly enforce state dict matching
            eval_mode: Whether to set model to evaluation mode

        Returns:
            Loaded model instance

        Raises:
            ModelLoadError: If the config is missing or the model cannot be built
        """
        config_path = Path(config_path)

        if not config_path.exists():
            raise ModelLoadError(f"Configuration file not found: {config_path}")

        try:
            model = EOSDIS_HeteroGNN.load_config(config_path)
            model.to(self.device)

            if checkpoint_path:
                checkpoint_path = Path(checkpoint_path)
                if checkpoint_path.exists():
                    checkpoint_data = torch.load(checkpoint_path, map_location=self.device)
                    if 'model_state_dict' in checkpoint_data:
                        model.load_state_dict(checkpoint_data['model_state_dict'], strict=strict)
                    else:
                        # Raw state dict saved without the wrapper dictionary.
                        model.load_state_dict(checkpoint_data, strict=strict)
                else:
                    # Missing weights is non-fatal here by design: the caller
                    # still gets a config-built model.
                    logger.warning(f"Checkpoint file not found: {checkpoint_path}")

            if eval_mode:
                model.eval()

            return model

        except Exception as e:
            raise ModelLoadError(f"Failed to load model from config: {str(e)}") from e

    def _log_model_info(self, model: EOSDIS_HeteroGNN, checkpoint_data: Dict[str, Any]):
        """Log information about the loaded model."""
        total_params = sum(p.numel() for p in model.parameters())

        # Lazy %-style args avoid formatting cost when INFO is disabled.
        logger.info("Model loaded successfully:")
        logger.info("  - Parameters: %s", f"{total_params:,}")
        logger.info("  - Device: %s", self.device)
        logger.info("  - Node types: %d", len(model.node_types))
        logger.info("  - Edge types: %d", len(model.edge_types))

        if 'save_timestamp' in checkpoint_data:
            logger.info("  - Saved: %s", checkpoint_data['save_timestamp'])

        if 'metadata' in checkpoint_data:
            metadata = checkpoint_data['metadata']
            if 'training_epochs' in metadata:
                logger.info("  - Training epochs: %s", metadata['training_epochs'])
            if 'best_val_loss' in metadata:
                logger.info("  - Best validation loss: %.4f", metadata['best_val_loss'])
| |
|
| |
|
def load_model(checkpoint_path: Union[str, Path],
               device: Optional[str] = None,
               eval_mode: bool = True) -> EOSDIS_HeteroGNN:
    """
    Convenience function to load a model from a checkpoint.

    Dispatches to directory-based or file-based loading depending on
    whether ``checkpoint_path`` points at a directory.

    Args:
        checkpoint_path: Path to checkpoint file or directory
        device: Target device ('cpu', 'cuda', etc.)
        eval_mode: Whether to set model to evaluation mode

    Returns:
        Loaded model instance
    """
    path = Path(checkpoint_path)
    loader = ModelLoader(device=device)

    if path.is_dir():
        return loader.load_from_directory(path, eval_mode=eval_mode)
    return loader.load_from_checkpoint(path, eval_mode=eval_mode)
| |
|
| |
|
def save_model(model: EOSDIS_HeteroGNN,
               save_path: Union[str, Path],
               metadata: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
    """
    Convenience function to save a model with configuration and metadata.

    Thin wrapper around ``CheckpointManager.save_checkpoint``.

    Args:
        model: Model instance to save
        save_path: Path to save the model
        metadata: Additional metadata to save

    Returns:
        Dictionary with paths to saved files
    """
    return CheckpointManager().save_checkpoint(model, save_path, metadata=metadata)
| |
|
| |
|
def validate_model_compatibility(model: EOSDIS_HeteroGNN,
                                 data: HeteroData) -> bool:
    """
    Validate that a model is compatible with given graph data.

    Checks that every node/edge type present in the data is handled by the
    model, and that node feature widths match the model's input dimension.

    Args:
        model: Model instance to validate
        data: Graph data to check compatibility with

    Returns:
        True if compatible

    Raises:
        ModelValidationError: If model is not compatible
    """
    # The data's type sets must be covered by the model's; extra model types
    # are fine, extra data types are not.
    missing_node_types = set(data.node_types) - set(model.node_types)
    if missing_node_types:
        raise ModelValidationError(f"Model missing node types: {missing_node_types}")

    missing_edge_types = set(data.edge_types) - set(model.edge_types)
    if missing_edge_types:
        raise ModelValidationError(f"Model missing edge types: {missing_edge_types}")

    # Every node type that actually carries features must match the model's
    # expected input width.
    expected_dim = model.input_dim
    for node_type in set(data.node_types):
        if node_type not in data.x_dict:
            continue
        actual_dim = data.x_dict[node_type].size(-1)
        if actual_dim != expected_dim:
            raise ModelValidationError(
                f"Input dimension mismatch for node type '{node_type}': "
                f"data has {actual_dim}, model expects {expected_dim}"
            )

    return True
| |
|
| |
|
def get_model_summary(model: EOSDIS_HeteroGNN) -> Dict[str, Any]:
    """
    Get a comprehensive summary of model architecture and parameters.

    Args:
        model: Model instance to summarize

    Returns:
        Dictionary containing model summary, including per-layer details
        and an estimated memory footprint.
    """
    # Local import to avoid a hard module-level dependency cycle.
    from model_architecture import get_model_info

    summary = get_model_info(model)

    # Describe each message-passing layer by index.
    summary['layer_details'] = [
        {
            'layer_index': idx,
            'layer_type': type(conv).__name__,
            'conv_type': conv.conv_type,
            'hidden_channels': conv.hidden_channels,
            'dropout': conv.dropout_prob,
        }
        for idx, conv in enumerate(model.conv_layers)
    ]

    # Rough footprint assuming 4-byte (float32) parameters.
    param_count = summary['total_parameters']
    summary['estimated_memory_mb'] = round(param_count * 4 / (1024 * 1024), 2)

    return summary
| |
|
| |
|
def optimize_model_for_inference(model: EOSDIS_HeteroGNN) -> EOSDIS_HeteroGNN:
    """
    Optimize model for inference by applying various optimizations.

    Switches the model to evaluation mode, freezes all parameters, and
    enables cuDNN autotuning. The model is modified in place and also
    returned for convenience.

    Args:
        model: Model to optimize

    Returns:
        Optimized model instance (same object as ``model``)
    """
    # Evaluation mode: disables dropout and other train-time behavior.
    model.eval()

    # Inference never needs gradients; freezing saves autograd bookkeeping.
    for param in model.parameters():
        param.requires_grad_(False)

    # NOTE(review): removed a dead try/except that logged "Attempting to apply
    # TorchScript optimization..." without ever scripting the model — the
    # message was misleading. TorchScript support for heterogeneous GNNs would
    # need model-specific work before it can be enabled here.

    # Let cuDNN benchmark kernels for fixed input shapes. This is a
    # process-wide setting, not specific to this model instance.
    if hasattr(torch.backends, 'cudnn'):
        torch.backends.cudnn.benchmark = True

    logger.info("Model optimized for inference")
    return model