Spaces:
Runtime error
Runtime error
| """ | |
| ModelLoader - Intelligent model loading with automatic type detection | |
| Replaces fragile path-based model detection with metadata-based approach | |
| """ | |
| import torch | |
| from pathlib import Path | |
| import json | |
| import warnings | |
class ModelLoader:
    """
    Utility class for loading models with automatic type detection

    Features:
    - Loads models using metadata from checkpoint
    - No more fragile path-based detection
    - Complete model reproducibility
    - Quick metadata inspection without loading weights

    All methods are static; call them on the class (or on an instance).

    Security note: checkpoints are read with ``torch.load``, which may
    unpickle arbitrary objects. Only load checkpoints from trusted sources.

    Example:
        >>> # Quick info without loading weights
        >>> info = ModelLoader.get_model_info('model.pt')
        >>> print(f"Model type: {info['model_type']}")
        >>> # Load model automatically
        >>> model = ModelLoader.load_model('model.pt', device='cuda')
        >>> # Load with optimizer state
        >>> data = ModelLoader.load_checkpoint_full('model.pt')
        >>> model = data['model']
        >>> optimizer.load_state_dict(data['optimizer_state'])
    """

    @staticmethod
    def _metadata_candidates(checkpoint_path):
        """Return the sidecar metadata paths to probe, preferred one first."""
        # A path with an empty name (e.g. '.') has no sensible sidecar file.
        if not checkpoint_path.name:
            return []

        # Preferred: swap only the final suffix of the *file name*, so a
        # '.pt' occurring in a directory name is left untouched.
        candidates = [
            checkpoint_path.with_name(f"{checkpoint_path.stem}_metadata.json")
        ]

        # Legacy: the name produced by the previous plain string replacement.
        # Kept as a fallback in case sidecar files were written under that
        # name (the saving code is not visible here). It is skipped when the
        # replacement is a no-op, because that would point at the checkpoint
        # itself and binary weights would be handed to the JSON parser.
        legacy = Path(str(checkpoint_path).replace('.pt', '_metadata.json'))
        if legacy != checkpoint_path and legacy not in candidates:
            candidates.append(legacy)
        return candidates

    @staticmethod
    def get_model_info(checkpoint_path):
        """
        Get model information without loading weights (fast)

        A ``<stem>_metadata.json`` file next to the checkpoint is used when
        present; otherwise the checkpoint itself is loaded on the CPU and the
        metadata keys are read from it.

        Args:
            checkpoint_path: Path to checkpoint file

        Returns:
            dict: The parsed sidecar JSON if one exists, otherwise a dict with
            the keys 'model_type', 'model_config', 'metrics', 'timestamp' and
            'pytorch_version' (each None when absent from the checkpoint)

        Example:
            >>> info = ModelLoader.get_model_info('results/bilstm/best_model.pt')
            >>> print(f"Model: {info['model_type']}")
            >>> print(f"Accuracy: {info['metrics']['accuracy']:.2%}")
        """
        checkpoint_path = Path(checkpoint_path)

        # Try to load metadata file first (much faster)
        for metadata_path in ModelLoader._metadata_candidates(checkpoint_path):
            if metadata_path.exists():
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    return json.load(f)

        # Fall back to loading checkpoint.
        # NOTE: torch.load may unpickle arbitrary objects - trusted files only.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        return {
            'model_type': checkpoint.get('model_type'),
            'model_config': checkpoint.get('model_config'),
            'metrics': checkpoint.get('metrics'),
            'timestamp': checkpoint.get('timestamp'),
            'pytorch_version': checkpoint.get('pytorch_version'),
        }

    @staticmethod
    def _pretrained_kwargs(model_config, default_model_name):
        """Collect the keyword arguments shared by the RoBERTa/BERTweet factories."""
        return {
            'model_name': model_config.get('model_name', default_model_name),
            'num_classes': model_config.get('num_classes', 3),
            'dropout': model_config.get('dropout', 0.5),
            'freeze_bert': model_config.get('freeze_bert', False),
            'freeze_layers': model_config.get('freeze_layers', 0),
        }

    @staticmethod
    def _create_model(model_type, model_config):
        """
        Create model instance based on type and config

        The factory modules are imported lazily so that only the code for the
        requested architecture has to be importable.

        Args:
            model_type: Type of model ('bilstm', 'transformer', 'fasttext', 'roberta', 'bertweet')
            model_config: Configuration dictionary

        Returns:
            Model instance

        Raises:
            ValueError: If model_type is unknown
        """
        if model_type == 'bilstm':
            from src.models.baseline.bilstm_attention import create_bilstm_model
            # Baseline factories return a (model, extra) pair; only the model is used.
            model, _ = create_bilstm_model(**model_config)
            return model

        if model_type == 'transformer':
            from src.models.baseline.custom_transformer import create_custom_transformer
            model, _ = create_custom_transformer(**model_config)
            return model

        if model_type == 'fasttext':
            from src.models.baseline.fasttext import create_fasttext_model
            model, _ = create_fasttext_model(**model_config)
            return model

        if model_type == 'roberta':
            from src.models.pretrained.roberta import create_roberta_model
            # RoBERTa models don't return config tuple in old version
            return create_roberta_model(
                **ModelLoader._pretrained_kwargs(model_config, 'roberta-base')
            )

        if model_type == 'bertweet':
            from src.models.pretrained.bertweet import create_bertweet_model
            return create_bertweet_model(
                **ModelLoader._pretrained_kwargs(model_config, 'vinai/bertweet-base')
            )

        raise ValueError(f"Unknown model type: {model_type}")

    @staticmethod
    def load_model(checkpoint_path, device='cpu', strict=True):
        """
        Load model with automatic type detection from metadata

        Args:
            checkpoint_path: Path to checkpoint file
            device: Device to load model on ('cpu', 'cuda', etc.)
            strict: Whether to strictly enforce state_dict matching

        Returns:
            Loaded model (in eval mode, on specified device)

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
            ValueError: If checkpoint missing metadata or has unknown model type
            KeyError: If the checkpoint has no 'model_state_dict' entry

        Example:
            >>> model = ModelLoader.load_model('results/bilstm/best_model.pt', device='cuda')
            >>> predictions = model(inputs)
        """
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # Load checkpoint.
        # NOTE: torch.load may unpickle arbitrary objects - trusted files only.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Get model type and config from metadata
        model_type = checkpoint.get('model_type')
        model_config = checkpoint.get('model_config')

        if model_type is None:
            raise ValueError(
                f"Checkpoint at {checkpoint_path} does not contain 'model_type' metadata.\n"
                f"This checkpoint was saved with the old save_model() function.\n"
                f"Please either:\n"
                f" 1. Re-train the model with the updated save_model() function, or\n"
                f" 2. Use the legacy loading method (not recommended)"
            )
        if model_config is None:
            raise ValueError(
                f"Checkpoint at {checkpoint_path} does not contain 'model_config' metadata.\n"
                f"Cannot recreate model without configuration."
            )

        # Create model from config
        print(f"Loading {model_type} model from checkpoint...")
        model = ModelLoader._create_model(model_type, model_config)

        # Load weights
        model.load_state_dict(checkpoint['model_state_dict'], strict=strict)

        # Move to device and set to eval mode
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully!")
        return model

    @staticmethod
    def load_checkpoint_full(checkpoint_path, device='cpu'):
        """
        Load full checkpoint including optimizer state and metadata

        Unlike ``load_model``, the returned model is NOT switched to eval
        mode here.

        Args:
            checkpoint_path: Path to checkpoint file
            device: Device to load on

        Returns:
            dict with keys:
            - 'model': Loaded model
            - 'optimizer_state': Optimizer state dict (or None)
            - 'epoch': Last epoch number (0 when not stored)
            - 'metrics': Saved metrics ({} when not stored)
            - 'model_type': Model type
            - 'model_config': Model configuration

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
            ValueError: If 'model_type' or 'model_config' is missing, or the
                model type is unknown

        Example:
            >>> data = ModelLoader.load_checkpoint_full('model.pt')
            >>> model = data['model']
            >>>
            >>> # Resume training
            >>> optimizer = torch.optim.Adam(model.parameters())
            >>> if data['optimizer_state']:
            ...     optimizer.load_state_dict(data['optimizer_state'])
            >>>
            >>> start_epoch = data['epoch'] + 1
        """
        checkpoint_path = Path(checkpoint_path)
        # Same up-front check and message as load_model, so both entry points
        # fail identically on a bad path.
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # NOTE: torch.load may unpickle arbitrary objects - trusted files only.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model_type = checkpoint.get('model_type')
        model_config = checkpoint.get('model_config')
        if model_type is None or model_config is None:
            raise ValueError("Checkpoint missing required metadata (model_type or model_config)")

        # Create and load model
        model = ModelLoader._create_model(model_type, model_config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        return {
            'model': model,
            'optimizer_state': checkpoint.get('optimizer_state_dict'),
            'epoch': checkpoint.get('epoch', 0),
            'metrics': checkpoint.get('metrics', {}),
            'model_type': model_type,
            'model_config': model_config,
        }

    @staticmethod
    def load_model_legacy(checkpoint_path, model_type, model_config, device='cpu'):
        """
        Load model the old way (for checkpoints without metadata)

        Emits a DeprecationWarning on every call.

        Args:
            checkpoint_path: Path to checkpoint
            model_type: Model type string (must specify manually)
            model_config: Model configuration dict (must specify manually)
            device: Device to load on

        Returns:
            Loaded model (in eval mode, on specified device)

        Raises:
            ValueError: If model_type is unknown

        Example:
            >>> # For old checkpoints without metadata
            >>> config = {'vocab_size': 10000, 'num_classes': 3}
            >>> model = ModelLoader.load_model_legacy(
            ...     'old_model.pt',
            ...     model_type='bilstm',
            ...     model_config=config
            ... )
        """
        warnings.warn(
            "Using legacy loading method. Consider re-training model with new save format.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )

        # NOTE: torch.load may unpickle arbitrary objects - trusted files only.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Create model
        model = ModelLoader._create_model(model_type, model_config)

        # Load weights (handle both old and new checkpoint formats): either a
        # wrapper dict holding 'model_state_dict', or the bare state dict.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()
        return model
| # Convenience function for backward compatibility | |
def load_model(checkpoint_path, device='cpu'):
    """
    Module-level shortcut that delegates to ``ModelLoader.load_model``.

    Kept so callers can import a plain function instead of the class.

    Args:
        checkpoint_path: Path to checkpoint
        device: Device to load on

    Returns:
        Whatever ``ModelLoader.load_model`` returns for these arguments

    Example:
        >>> from src.utils.model_loader import load_model
        >>> model = load_model('results/bilstm/best_model.pt', device='cuda')
    """
    loaded_model = ModelLoader.load_model(checkpoint_path, device)
    return loaded_model
if __name__ == "__main__":
    # Running the file as a script only prints a short usage banner;
    # no model is loaded.
    rule = "=" * 80
    banner_lines = (
        rule,
        "MODELLOADER UTILITY",
        rule,
        "\nThis module provides intelligent model loading with automatic type detection.",
        "\nUsage:",
        " 1. Quick info: ModelLoader.get_model_info('model.pt')",
        " 2. Load model: ModelLoader.load_model('model.pt', device='cuda')",
        " 3. Full load: ModelLoader.load_checkpoint_full('model.pt')",
        "\nFeatures:",
        " ✅ Automatic model type detection from metadata",
        " ✅ No more fragile path-based detection",
        " ✅ Complete reproducibility",
        " ✅ Quick metadata inspection",
        " ✅ Legacy checkpoint support",
        "\n" + rule,
    )
    for banner_line in banner_lines:
        print(banner_line)