# Graph Machine Learning
# PyTorch
# English
# eosdis_hetero_gnn
# gnn
# earth
# nasa
# 1.0.0
# edgraph-gnn-graphsage / model_utils.py
# arminmehrabian's picture
# Add GNN training and utility scripts
# a9d7940
"""
Model Loading and Management Utilities
This module provides utilities for loading trained EOSDIS GNN models,
managing checkpoints, and handling model metadata. Designed for use
in HuggingFace model repositories and production inference pipelines.
Key features:
- Model loading from various checkpoint formats
- Automatic configuration detection
- Model validation and compatibility checks
- Device management and optimization
- Metadata handling
"""
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
import os
import json
import logging
import warnings
from pathlib import Path
from typing import Dict, Any, Optional, Union, Tuple, List
import pickle
from datetime import datetime
from model_architecture import EOSDIS_HeteroGNN, ModelValidationError
logger = logging.getLogger(__name__)
class ModelLoadError(Exception):
    """Raised when a model checkpoint cannot be located, parsed, or loaded."""
class CheckpointManager:
    """
    Manages model checkpoints and metadata.

    Handles various checkpoint formats, validation, and provides
    utilities for model versioning and compatibility checks.
    """

    # Extensions torch.load understands (torch.save serializes via pickle,
    # so the pickle extensions are loadable too).
    SUPPORTED_FORMATS = ['.pt', '.pth', '.pkl', '.pickle']
    CONFIG_FILENAME = 'model_config.json'
    METADATA_FILENAME = 'model_metadata.json'

    def __init__(self, checkpoint_dir: Optional[str] = None):
        """
        Initialize checkpoint manager.

        Args:
            checkpoint_dir: Directory containing model checkpoints.
        """
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

    def save_checkpoint(self,
                        model: EOSDIS_HeteroGNN,
                        filepath: Union[str, Path],
                        metadata: Optional[Dict[str, Any]] = None,
                        save_config: bool = True) -> Dict[str, str]:
        """
        Save model checkpoint with configuration and metadata.

        Args:
            model: Model instance to save.
            filepath: Path to save the checkpoint.
            metadata: Additional metadata to save.
            save_config: Whether to save model configuration separately.

        Returns:
            Dictionary with paths to saved files.
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Bundle weights, config, and provenance in a single file so the
        # checkpoint is self-describing.
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'model_config': model.get_config(),
            'save_timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__
        }
        if metadata:
            checkpoint_data['metadata'] = metadata

        # Save main checkpoint
        torch.save(checkpoint_data, filepath)
        saved_files = {'checkpoint': str(filepath)}

        # Save configuration separately so it can be read without torch.
        if save_config:
            config_path = filepath.parent / self.CONFIG_FILENAME
            with open(config_path, 'w') as f:
                json.dump(model.get_config(), f, indent=2)
            saved_files['config'] = str(config_path)

        # Mirror metadata as plain JSON next to the checkpoint.
        if metadata:
            metadata_path = filepath.parent / self.METADATA_FILENAME
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            saved_files['metadata'] = str(metadata_path)

        logger.info("Saved model checkpoint to %s", filepath)
        return saved_files

    def load_checkpoint(self,
                        filepath: Union[str, Path],
                        map_location: Optional[str] = None) -> Dict[str, Any]:
        """
        Load checkpoint data from file.

        Args:
            filepath: Path to checkpoint file.
            map_location: Device to map the checkpoint to.

        Returns:
            Dictionary containing checkpoint data.

        Raises:
            ModelLoadError: If checkpoint cannot be loaded.
        """
        filepath = Path(filepath)

        if not filepath.exists():
            raise ModelLoadError(f"Checkpoint file not found: {filepath}")

        if filepath.suffix not in self.SUPPORTED_FORMATS:
            raise ModelLoadError(
                f"Unsupported checkpoint format: {filepath.suffix}. "
                f"Supported formats: {self.SUPPORTED_FORMATS}"
            )

        try:
            # torch.load accepts map_location=None, so no branching is needed.
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            checkpoint_data = torch.load(filepath, map_location=map_location)
            logger.info("Loaded checkpoint from %s", filepath)
            return checkpoint_data
        except Exception as e:
            # Chain the original exception so the root cause is not lost.
            raise ModelLoadError(f"Failed to load checkpoint from {filepath}: {str(e)}") from e

    def validate_checkpoint(self, checkpoint_data: Dict[str, Any]) -> bool:
        """
        Validate checkpoint data structure and contents.

        Args:
            checkpoint_data: Checkpoint data dictionary.

        Returns:
            True if checkpoint is valid.

        Raises:
            ModelLoadError: If checkpoint is invalid.
        """
        required_keys = ['model_state_dict', 'model_config']
        missing_keys = [key for key in required_keys if key not in checkpoint_data]
        if missing_keys:
            raise ModelLoadError(f"Checkpoint missing required keys: {missing_keys}")

        # Validate model config. Raise directly: the previous implementation
        # wrapped these checks in a broad try/except that caught and re-wrapped
        # its own ModelLoadError, double-prefixing every error message.
        config = checkpoint_data['model_config']
        if not isinstance(config, dict):
            raise ModelLoadError("model_config must be a dictionary")

        required_config_keys = ['metadata', 'hidden_channels', 'num_layers']
        missing_config_keys = [key for key in required_config_keys if key not in config]
        if missing_config_keys:
            raise ModelLoadError(f"Model config missing required keys: {missing_config_keys}")

        return True
class ModelLoader:
    """
    High-level interface for loading EOSDIS GNN models.

    Provides methods for loading models from various sources including
    local files, HuggingFace repositories, and custom checkpoint formats.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize model loader.

        Args:
            device: Target device for loaded models ('cpu', 'cuda', etc.).
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_manager = CheckpointManager()
        logger.info("ModelLoader initialized with device: %s", self.device)

    def load_from_checkpoint(self,
                             checkpoint_path: Union[str, Path],
                             strict: bool = True,
                             eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from a checkpoint file.

        Args:
            checkpoint_path: Path to checkpoint file.
            strict: Whether to strictly enforce state dict matching.
            eval_mode: Whether to set model to evaluation mode.

        Returns:
            Loaded model instance.

        Raises:
            ModelLoadError: If model cannot be loaded.
        """
        checkpoint_path = Path(checkpoint_path)

        # Load and validate checkpoint data (both raise ModelLoadError on failure).
        checkpoint_data = self.checkpoint_manager.load_checkpoint(
            checkpoint_path, map_location=self.device
        )
        self.checkpoint_manager.validate_checkpoint(checkpoint_data)

        # Rebuild the architecture from the stored config, then restore weights.
        config = checkpoint_data['model_config']
        try:
            model = EOSDIS_HeteroGNN.from_config(config)
            model.to(self.device)
            model.load_state_dict(checkpoint_data['model_state_dict'], strict=strict)
            if eval_mode:
                model.eval()
            self._log_model_info(model, checkpoint_data)
            return model
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ModelLoadError(f"Failed to create model from checkpoint: {str(e)}") from e

    def load_from_directory(self,
                            model_dir: Union[str, Path],
                            checkpoint_name: str = 'model.pt',
                            strict: bool = True,
                            eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from a directory containing checkpoints and configuration.

        Args:
            model_dir: Directory containing model files.
            checkpoint_name: Name of the checkpoint file.
            strict: Whether to strictly enforce state dict matching.
            eval_mode: Whether to set model to evaluation mode.

        Returns:
            Loaded model instance.

        Raises:
            ModelLoadError: If the directory or a checkpoint cannot be found.
        """
        model_dir = Path(model_dir)

        if not model_dir.exists() or not model_dir.is_dir():
            raise ModelLoadError(f"Model directory not found: {model_dir}")

        checkpoint_path = model_dir / checkpoint_name
        if not checkpoint_path.exists():
            # Fall back to any checkpoint file in the directory. Sort the
            # candidates so the choice is deterministic rather than dependent
            # on filesystem glob order.
            checkpoint_files = sorted(
                f for ext in CheckpointManager.SUPPORTED_FORMATS
                for f in model_dir.glob(f"*{ext}")
            )
            if not checkpoint_files:
                raise ModelLoadError(f"No checkpoint files found in {model_dir}")
            checkpoint_path = checkpoint_files[0]
            logger.info("Using checkpoint file: %s", checkpoint_path)

        return self.load_from_checkpoint(checkpoint_path, strict=strict, eval_mode=eval_mode)

    def load_from_config(self,
                         config_path: Union[str, Path],
                         checkpoint_path: Optional[Union[str, Path]] = None,
                         strict: bool = True,
                         eval_mode: bool = True) -> EOSDIS_HeteroGNN:
        """
        Load model from separate configuration and checkpoint files.

        Args:
            config_path: Path to configuration JSON file.
            checkpoint_path: Path to checkpoint file (optional).
            strict: Whether to strictly enforce state dict matching.
            eval_mode: Whether to set model to evaluation mode.

        Returns:
            Loaded model instance.

        Raises:
            ModelLoadError: If the config is missing or the model cannot be built.
        """
        config_path = Path(config_path)

        if not config_path.exists():
            raise ModelLoadError(f"Configuration file not found: {config_path}")

        try:
            model = EOSDIS_HeteroGNN.load_config(config_path)
            model.to(self.device)

            # Load weights if a checkpoint was supplied; a missing checkpoint
            # is deliberately non-fatal (the model keeps its fresh weights).
            if checkpoint_path:
                checkpoint_path = Path(checkpoint_path)
                if checkpoint_path.exists():
                    checkpoint_data = torch.load(checkpoint_path, map_location=self.device)
                    if 'model_state_dict' in checkpoint_data:
                        checkpoint_data = checkpoint_data['model_state_dict']
                    # Otherwise assume the file contains just the state dict.
                    model.load_state_dict(checkpoint_data, strict=strict)
                else:
                    logger.warning(f"Checkpoint file not found: {checkpoint_path}")

            if eval_mode:
                model.eval()
            return model
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ModelLoadError(f"Failed to load model from config: {str(e)}") from e

    def _log_model_info(self, model: EOSDIS_HeteroGNN, checkpoint_data: Dict[str, Any]):
        """Log parameter count, device, graph schema, and checkpoint provenance."""
        total_params = sum(p.numel() for p in model.parameters())

        logger.info("Model loaded successfully:")
        logger.info("  - Parameters: %s", f"{total_params:,}")
        logger.info("  - Device: %s", self.device)
        logger.info("  - Node types: %d", len(model.node_types))
        logger.info("  - Edge types: %d", len(model.edge_types))

        if 'save_timestamp' in checkpoint_data:
            logger.info("  - Saved: %s", checkpoint_data['save_timestamp'])

        if 'metadata' in checkpoint_data:
            metadata = checkpoint_data['metadata']
            if 'training_epochs' in metadata:
                logger.info("  - Training epochs: %s", metadata['training_epochs'])
            if 'best_val_loss' in metadata:
                logger.info("  - Best validation loss: %.4f", metadata['best_val_loss'])
def load_model(checkpoint_path: Union[str, Path],
               device: Optional[str] = None,
               eval_mode: bool = True) -> EOSDIS_HeteroGNN:
    """
    Convenience function to load a model from a checkpoint.

    Args:
        checkpoint_path: Path to checkpoint file or directory
        device: Target device ('cpu', 'cuda', etc.)
        eval_mode: Whether to set model to evaluation mode

    Returns:
        Loaded model instance
    """
    loader = ModelLoader(device=device)
    target = Path(checkpoint_path)
    # Directories get scanned for a checkpoint; files are loaded directly.
    if target.is_dir():
        return loader.load_from_directory(target, eval_mode=eval_mode)
    return loader.load_from_checkpoint(target, eval_mode=eval_mode)
def save_model(model: EOSDIS_HeteroGNN,
               save_path: Union[str, Path],
               metadata: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
    """
    Convenience function to save a model with configuration and metadata.

    Thin wrapper around CheckpointManager.save_checkpoint.

    Args:
        model: Model instance to save
        save_path: Path to save the model
        metadata: Additional metadata to save

    Returns:
        Dictionary with paths to saved files
    """
    return CheckpointManager().save_checkpoint(model, save_path, metadata=metadata)
def validate_model_compatibility(model: EOSDIS_HeteroGNN,
                                 data: HeteroData) -> bool:
    """
    Validate that a model is compatible with given graph data.

    Checks that every node type and edge type present in the data is known
    to the model, and that node feature widths match the model's input
    dimension.

    Args:
        model: Model instance to validate
        data: Graph data to check compatibility with

    Returns:
        True if compatible

    Raises:
        ModelValidationError: If model is not compatible
    """
    graph_node_types = set(data.node_types)

    # Every node type in the data must be covered by the model.
    missing_types = graph_node_types - set(model.node_types)
    if missing_types:
        raise ModelValidationError(f"Model missing node types: {missing_types}")

    # Same requirement for edge types.
    missing_types = set(data.edge_types) - set(model.edge_types)
    if missing_types:
        raise ModelValidationError(f"Model missing edge types: {missing_types}")

    # Each featured node type must match the model's expected input width.
    for node_type in graph_node_types:
        if node_type not in data.x_dict:
            continue
        data_dim = data.x_dict[node_type].size(-1)
        if data_dim != model.input_dim:
            raise ModelValidationError(
                f"Input dimension mismatch for node type '{node_type}': "
                f"data has {data_dim}, model expects {model.input_dim}"
            )

    return True
def get_model_summary(model: EOSDIS_HeteroGNN) -> Dict[str, Any]:
    """
    Get a comprehensive summary of model architecture and parameters.

    Args:
        model: Model instance to summarize

    Returns:
        Dictionary containing model summary
    """
    from model_architecture import get_model_info

    summary = get_model_info(model)

    # Describe each convolution layer individually.
    summary['layer_details'] = [
        {
            'layer_index': idx,
            'layer_type': type(conv).__name__,
            'conv_type': conv.conv_type,
            'hidden_channels': conv.hidden_channels,
            'dropout': conv.dropout_prob,
        }
        for idx, conv in enumerate(model.conv_layers)
    ]

    # Rough memory estimate: 4 bytes per parameter (float32).
    param_count = summary['total_parameters']
    summary['estimated_memory_mb'] = round((param_count * 4) / (1024 * 1024), 2)

    return summary
def optimize_model_for_inference(model: EOSDIS_HeteroGNN) -> EOSDIS_HeteroGNN:
    """
    Optimize model for inference by applying various optimizations.

    Sets the model to evaluation mode, freezes all parameters so no
    gradients are tracked, and enables cuDNN autotuning. TorchScript
    compilation is intentionally NOT applied: torch.jit.script does not
    support HeteroConv-based models (the previous implementation carried a
    dead try/except around commented-out jit code, which is removed here).

    Args:
        model: Model to optimize

    Returns:
        The same model instance, modified in place
    """
    # Disable dropout / switch norm layers to inference statistics.
    model.eval()

    # Freeze parameters so autograd never tracks them.
    for param in model.parameters():
        param.requires_grad_(False)

    # NOTE(review): torch.jit.script was considered but is incompatible with
    # HeteroConv; revisit if PyG adds scripting support.

    if hasattr(torch.backends, 'cudnn'):
        # Let cuDNN benchmark kernels for fixed input shapes.
        # This is a process-global setting, not model-local.
        torch.backends.cudnn.benchmark = True

    logger.info("Model optimized for inference")
    return model