# sentiment_anals/src/utils/model_loader.py
# (Provenance, from Hugging Face upload page: abdou21367 — "Upload 64 files" — commit 839c56d verified)
"""
ModelLoader - Intelligent model loading with automatic type detection
Replaces fragile path-based model detection with metadata-based approach
"""
import torch
from pathlib import Path
import json
import warnings
class ModelLoader:
    """
    Utility class for loading models with automatic type detection.

    Features:
    - Loads models using metadata from checkpoint
    - No more fragile path-based detection
    - Complete model reproducibility
    - Quick metadata inspection without loading weights

    Example:
        >>> # Quick info without loading weights
        >>> info = ModelLoader.get_model_info('model.pt')
        >>> print(f"Model type: {info['model_type']}")
        >>> # Load model automatically
        >>> model = ModelLoader.load_model('model.pt', device='cuda')
        >>> # Load with optimizer state
        >>> data = ModelLoader.load_checkpoint_full('model.pt')
        >>> model = data['model']
        >>> optimizer.load_state_dict(data['optimizer_state'])
    """

    @staticmethod
    def _metadata_path(checkpoint_path):
        """
        Return the sidecar metadata path for a checkpoint.

        For ``results/bilstm/best_model.pt`` this is
        ``results/bilstm/best_model_metadata.json``. Built with pathlib
        instead of ``str.replace('.pt', ...)``, which broke for paths
        containing '.pt' elsewhere (e.g. ``./pt_models/model.pt``).
        """
        checkpoint_path = Path(checkpoint_path)
        return checkpoint_path.with_name(checkpoint_path.stem + '_metadata.json')

    @staticmethod
    def get_model_info(checkpoint_path):
        """
        Get model information without loading weights (fast).

        Args:
            checkpoint_path: Path to checkpoint file

        Returns:
            dict: Model metadata (type, config, metrics, timestamp)

        Example:
            >>> info = ModelLoader.get_model_info('results/bilstm/best_model.pt')
            >>> print(f"Model: {info['model_type']}")
            >>> print(f"Accuracy: {info['metrics']['accuracy']:.2%}")
        """
        checkpoint_path = Path(checkpoint_path)
        # Prefer the sidecar metadata file: reading a small JSON is much
        # faster than deserializing the full checkpoint.
        metadata_path = ModelLoader._metadata_path(checkpoint_path)
        if metadata_path.exists():
            with open(metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        # Fall back to loading the checkpoint itself.
        # NOTE(security): torch.load deserializes with pickle under the
        # hood — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        keys = ('model_type', 'model_config', 'metrics', 'timestamp',
                'pytorch_version')
        return {key: checkpoint.get(key) for key in keys}

    @staticmethod
    def _pretrained_kwargs(model_config, default_model_name):
        """
        Extract the constructor kwargs shared by the HF-based wrappers
        (roberta / bertweet), applying the historical defaults.
        """
        return {
            'model_name': model_config.get('model_name', default_model_name),
            'num_classes': model_config.get('num_classes', 3),
            'dropout': model_config.get('dropout', 0.5),
            'freeze_bert': model_config.get('freeze_bert', False),
            'freeze_layers': model_config.get('freeze_layers', 0),
        }

    @staticmethod
    def _create_model(model_type, model_config):
        """
        Create model instance based on type and config.

        Args:
            model_type: Type of model ('bilstm', 'transformer', 'fasttext',
                'roberta', 'bertweet')
            model_config: Configuration dictionary

        Returns:
            Model instance

        Raises:
            ValueError: If model_type is unknown
        """
        # Imports are deferred so that only the selected backend's
        # dependencies need to be installed.
        if model_type == 'bilstm':
            from src.models.baseline.bilstm_attention import create_bilstm_model
            model, _ = create_bilstm_model(**model_config)
            return model
        if model_type == 'transformer':
            from src.models.baseline.custom_transformer import create_custom_transformer
            model, _ = create_custom_transformer(**model_config)
            return model
        if model_type == 'fasttext':
            from src.models.baseline.fasttext import create_fasttext_model
            model, _ = create_fasttext_model(**model_config)
            return model
        if model_type == 'roberta':
            from src.models.pretrained.roberta import create_roberta_model
            # RoBERTa factory returns the model directly (no config tuple).
            return create_roberta_model(
                **ModelLoader._pretrained_kwargs(model_config, 'roberta-base'))
        if model_type == 'bertweet':
            from src.models.pretrained.bertweet import create_bertweet_model
            return create_bertweet_model(
                **ModelLoader._pretrained_kwargs(model_config, 'vinai/bertweet-base'))
        raise ValueError(f"Unknown model type: {model_type}")

    @staticmethod
    def load_model(checkpoint_path, device='cpu', strict=True):
        """
        Load model with automatic type detection from metadata.

        Args:
            checkpoint_path: Path to checkpoint file
            device: Device to load model on ('cpu', 'cuda', etc.)
            strict: Whether to strictly enforce state_dict matching

        Returns:
            Loaded model (in eval mode, on specified device)

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
            ValueError: If checkpoint missing metadata or has unknown model type

        Example:
            >>> model = ModelLoader.load_model('results/bilstm/best_model.pt', device='cuda')
            >>> predictions = model(inputs)
        """
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        # NOTE(security): torch.load deserializes with pickle under the
        # hood — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Metadata written by the updated save_model(); both keys are
        # required to recreate the architecture before loading weights.
        model_type = checkpoint.get('model_type')
        model_config = checkpoint.get('model_config')
        if model_type is None:
            raise ValueError(
                f"Checkpoint at {checkpoint_path} does not contain 'model_type' metadata.\n"
                f"This checkpoint was saved with the old save_model() function.\n"
                f"Please either:\n"
                f" 1. Re-train the model with the updated save_model() function, or\n"
                f" 2. Use the legacy loading method (not recommended)"
            )
        if model_config is None:
            raise ValueError(
                f"Checkpoint at {checkpoint_path} does not contain 'model_config' metadata.\n"
                f"Cannot recreate model without configuration."
            )
        print(f"Loading {model_type} model from checkpoint...")
        model = ModelLoader._create_model(model_type, model_config)
        model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        # Move to device and switch to inference mode.
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully!")
        return model

    @staticmethod
    def load_checkpoint_full(checkpoint_path, device='cpu'):
        """
        Load full checkpoint including optimizer state and metadata.

        Unlike load_model(), the returned model is NOT switched to eval
        mode, since this entry point is intended for resuming training.

        Args:
            checkpoint_path: Path to checkpoint file
            device: Device to load on

        Returns:
            dict with keys:
                - 'model': Loaded model
                - 'optimizer_state': Optimizer state dict (or None)
                - 'epoch': Last epoch number
                - 'metrics': Saved metrics
                - 'model_type': Model type
                - 'model_config': Model configuration

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
            ValueError: If checkpoint missing required metadata

        Example:
            >>> data = ModelLoader.load_checkpoint_full('model.pt')
            >>> model = data['model']
            >>>
            >>> # Resume training
            >>> optimizer = torch.optim.Adam(model.parameters())
            >>> if data['optimizer_state']:
            ...     optimizer.load_state_dict(data['optimizer_state'])
            >>>
            >>> start_epoch = data['epoch'] + 1
        """
        checkpoint_path = Path(checkpoint_path)
        # Same early existence check as load_model(), for a clear error.
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        # NOTE(security): torch.load deserializes with pickle under the
        # hood — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_type = checkpoint.get('model_type')
        model_config = checkpoint.get('model_config')
        if model_type is None or model_config is None:
            raise ValueError("Checkpoint missing required metadata (model_type or model_config)")
        # Recreate the architecture, then restore the trained weights.
        model = ModelLoader._create_model(model_type, model_config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        return {
            'model': model,
            'optimizer_state': checkpoint.get('optimizer_state_dict'),
            'epoch': checkpoint.get('epoch', 0),
            'metrics': checkpoint.get('metrics', {}),
            'model_type': model_type,
            'model_config': model_config
        }

    @staticmethod
    def load_model_legacy(checkpoint_path, model_type, model_config, device='cpu'):
        """
        Load model the old way (for checkpoints without metadata).

        Args:
            checkpoint_path: Path to checkpoint
            model_type: Model type string (must specify manually)
            model_config: Model configuration dict (must specify manually)
            device: Device to load on

        Returns:
            Loaded model (in eval mode, on specified device)

        Example:
            >>> # For old checkpoints without metadata
            >>> config = {'vocab_size': 10000, 'num_classes': 3}
            >>> model = ModelLoader.load_model_legacy(
            ...     'old_model.pt',
            ...     model_type='bilstm',
            ...     model_config=config
            ... )
        """
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            "Using legacy loading method. Consider re-training model with new save format.",
            DeprecationWarning,
            stacklevel=2
        )
        # NOTE(security): torch.load deserializes with pickle under the
        # hood — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model = ModelLoader._create_model(model_type, model_config)
        # Handle both wrapped ({'model_state_dict': ...}) and bare
        # state_dict checkpoint formats.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        return model
# Convenience function for backward compatibility
def load_model(checkpoint_path, device='cpu', strict=True):
    """
    Convenience function to load model (uses ModelLoader internally).

    Args:
        checkpoint_path: Path to checkpoint
        device: Device to load on
        strict: Whether to strictly enforce state_dict matching
            (forwarded to ModelLoader.load_model; defaults to True,
            matching the previous behavior)

    Returns:
        Loaded model

    Example:
        >>> from src.utils.model_loader import load_model
        >>> model = load_model('results/bilstm/best_model.pt', device='cuda')
    """
    return ModelLoader.load_model(checkpoint_path, device=device, strict=strict)
if __name__ == "__main__":
print("="*80)
print("MODELLOADER UTILITY")
print("="*80)
print("\nThis module provides intelligent model loading with automatic type detection.")
print("\nUsage:")
print(" 1. Quick info: ModelLoader.get_model_info('model.pt')")
print(" 2. Load model: ModelLoader.load_model('model.pt', device='cuda')")
print(" 3. Full load: ModelLoader.load_checkpoint_full('model.pt')")
print("\nFeatures:")
print(" ✅ Automatic model type detection from metadata")
print(" ✅ No more fragile path-based detection")
print(" ✅ Complete reproducibility")
print(" ✅ Quick metadata inspection")
print(" ✅ Legacy checkpoint support")
print("\n" + "="*80)