Spaces:
Running
Running
| """ | |
| Model Optimization Script for Production Deployment | |
| Reduces model size and improves inference speed through: | |
| 1. Quantization (INT8) | |
| 2. TorchScript compilation | |
| 3. Model pruning (optional) | |
| 4. State dict optimization | |
| """ | |
| import torch | |
| import os | |
| import argparse | |
| from pathlib import Path | |
| # Model classes must be imported BEFORE a checkpoint is loaded: unpickling needs | |
| # the class definitions to be resolvable when saved objects are restored. | |
| # Note: resnet_train.py has module-level code that loads COCO data, which may fail | |
| # if training files aren't present, so the imports are done inside each function | |
| # below rather than here at module level. | |
def quantize_model(checkpoint_path, output_path, model_type='resnet'):
    """
    Apply dynamic INT8 quantization to a captioning model and save the result.

    The checkpoint is loaded on CPU, the model is rebuilt from the project's
    training modules, ``torch.quantization.quantize_dynamic`` is applied with
    ``dtype=torch.qint8``, and the quantized state dict(s) are written to
    ``output_path``. The on-disk sizes before and after are printed.
    INT8 weights are expected to cost a small amount of accuracy; that is not
    measured here.

    Args:
        checkpoint_path: Path to the trained checkpoint. For 'resnet' it must
            provide 'encoder', 'decoder' and a non-None 'vocab' entry; for
            'efficientnet' it either has a 'model_state' entry or is itself
            the model state dict.
        output_path: Destination file for the quantized checkpoint.
        model_type: 'resnet' or 'efficientnet'.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If ``model_type`` is not supported, or a ResNet checkpoint
            has no usable 'vocab' entry.
    """
    # Reject unknown model types up front. Without this check neither branch
    # below builds `quantized_checkpoint`, and the function would only fail at
    # save time with a NameError after an expensive checkpoint load.
    if model_type not in ('resnet', 'efficientnet'):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected 'resnet' or 'efficientnet'."
        )

    print(f"Quantizing {model_type} model...")
    device = torch.device('cpu')  # Quantization typically done on CPU

    # The model classes must be importable before torch.load so that pickled
    # objects stored in the checkpoint can be resolved.
    if model_type == 'resnet':
        # Import the module itself so its module-level `vocab` can be replaced
        # once the checkpoint has been read.
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary
        # Also expose Vocabulary on __main__, for checkpoints that were pickled
        # while the class lived in a script executed as __main__.
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    else:
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer

    # NOTE(security): weights_only=False performs unrestricted unpickling and
    # can execute code embedded in the file. Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # The vocab has to be installed on the resnet_train module before the
        # decoder is built: per the original author's note, DecoderRNN.__init__
        # sizes itself from resnet_train.vocab.
        vocab = checkpoint.get('vocab')
        if vocab is None:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")
        resnet_train.vocab = vocab
        print(f" Updated vocab size: {len(vocab.word2idx)}")

        encoder = EncoderCNN()
        decoder = DecoderRNN()
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        encoder.eval()
        decoder.eval()

        # NOTE(review): dynamic quantization appears to convert only module
        # types present in its default dynamic mapping (e.g. Linear and RNN
        # variants); Conv2d listed here is likely left in float - confirm
        # against the quantize_dynamic documentation.
        encoder_quantized = torch.quantization.quantize_dynamic(
            encoder, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )
        # Only Linear layers are requested for the decoder; Embedding layers
        # would need a dedicated quantization config.
        decoder_quantized = torch.quantization.quantize_dynamic(
            decoder, {torch.nn.Linear}, dtype=torch.qint8
        )

        quantized_checkpoint = {
            'encoder': encoder_quantized.state_dict(),
            'decoder': decoder_quantized.state_dict(),
            'vocab': vocab,
            'quantized': True
        }
    else:
        # Rebuild the tokenizer only to obtain the vocabulary size the decoder
        # is constructed with (GPT-2 tokens plus the two added special tokens).
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder)

        # Accept both a wrapped checkpoint ('model_state' key) and a bare
        # state dict.
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()

        model_quantized = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )
        quantized_checkpoint = {
            'model_state': model_quantized.state_dict(),
            'quantized': True
        }

    torch.save(quantized_checkpoint, output_path)

    # Report the on-disk size change in MB.
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
    quantized_size = os.path.getsize(output_path) / (1024 * 1024)
    reduction = (1 - quantized_size / original_size) * 100
    print("✓ Quantization complete!")
    print(f" Original size: {original_size:.2f} MB")
    print(f" Quantized size: {quantized_size:.2f} MB")
    print(f" Size reduction: {reduction:.1f}%")
    return output_path
def optimize_state_dict(checkpoint_path, output_path):
    """
    Write a slimmed-down copy of a checkpoint without training-only entries.

    The checkpoint at ``checkpoint_path`` is loaded on CPU, the top-level
    entries 'optimizer', 'scheduler', 'epoch', 'loss' and 'metrics' are
    dropped, and everything else is saved to ``output_path``. The on-disk
    sizes before and after are printed.

    Args:
        checkpoint_path: Path to the source checkpoint (a dict-like object).
        output_path: Destination file for the reduced checkpoint.

    Returns:
        ``output_path``, unchanged.
    """
    print("Optimizing state dict...")

    # Top-level entries that are only useful for resuming training.
    training_only_keys = ('optimizer', 'scheduler', 'epoch', 'loss', 'metrics')

    # Best effort: make the Vocabulary class resolvable for unpickling, both
    # from its own module and from __main__. Checkpoints that do not need it
    # still load when the training module is unavailable.
    try:
        from resnet_train import Vocabulary
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            setattr(__main__, 'Vocabulary', Vocabulary)
    except ImportError:
        pass

    loaded = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Keep every entry except the training-only bookkeeping.
    slimmed = {
        name: entry
        for name, entry in loaded.items()
        if name not in training_only_keys
    }

    # NOTE(review): this flag selects the zipfile-based container format;
    # whether it adds any compression is not shown here - confirm.
    torch.save(slimmed, output_path, _use_new_zipfile_serialization=True)

    bytes_per_mb = 1024 * 1024
    size_before = os.path.getsize(checkpoint_path) / bytes_per_mb
    size_after = os.path.getsize(output_path) / bytes_per_mb
    shrink_pct = (1 - size_after / size_before) * 100

    print("✓ State dict optimized!")
    print(f" Original: {size_before:.2f} MB")
    print(f" Optimized: {size_after:.2f} MB")
    print(f" Reduction: {shrink_pct:.1f}%")
    return output_path
def create_torchscript(checkpoint_path, output_path, model_type='resnet'):
    """
    Trace the image encoder of a captioning model and save it as TorchScript.

    Only the encoder is exported: it is traced with a random
    ``(1, 3, 224, 224)`` example input. The decoder is not traced.

    The traced encoder is written to ``output_path`` with a trailing ``.pth``
    suffix rewritten to ``_encoder.pt``; any other path is used as given.

    Args:
        checkpoint_path: Path to the trained checkpoint. For 'resnet' it must
            provide 'encoder', 'decoder' and a non-None 'vocab' entry; for
            'efficientnet' it either has a 'model_state' entry or is itself
            the model state dict.
        output_path: Requested output path (str or path-like).
        model_type: 'resnet' or 'efficientnet'.

    Returns:
        The path the traced encoder was actually written to.

    Raises:
        ValueError: If ``model_type`` is not supported, or a ResNet checkpoint
            has no usable 'vocab' entry.
    """
    # Reject unknown model types up front; previously nothing was exported for
    # them, yet a success message was printed.
    if model_type not in ('resnet', 'efficientnet'):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected 'resnet' or 'efficientnet'."
        )

    print("Creating TorchScript model...")
    device = torch.device('cpu')

    # Derive the real destination once. Only a trailing '.pth' is rewritten,
    # so a '.pth' occurring elsewhere in the path (e.g. a directory name) is
    # left alone.
    output_path = os.fspath(output_path)
    if output_path.endswith('.pth'):
        encoder_path = output_path[:-len('.pth')] + '_encoder.pt'
    else:
        encoder_path = output_path

    # The model classes must be importable before torch.load so that pickled
    # objects stored in the checkpoint can be resolved.
    if model_type == 'resnet':
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary
        # Also expose Vocabulary on __main__, for checkpoints that were pickled
        # while the class lived in a script executed as __main__.
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    else:
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer

    # NOTE(security): weights_only=False performs unrestricted unpickling and
    # can execute code embedded in the file. Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # Install the checkpoint's vocab on the training module before the
        # decoder is constructed, so it is built with the matching vocab size.
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            resnet_train.vocab = checkpoint['vocab']
            print(f" Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")

        encoder = EncoderCNN().eval()
        decoder = DecoderRNN().eval()
        encoder.load_state_dict(checkpoint['encoder'])
        # The decoder weights are loaded (which also checks they match the
        # architecture) even though the decoder is not exported.
        decoder.load_state_dict(checkpoint['decoder'])

        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(encoder, dummy_image)

        print(" ⚠ TorchScript for RNN decoder may require manual scripting")
        print(" ✓ Encoder traced successfully")
        torch.jit.save(encoder_traced, encoder_path)
    else:
        # Rebuild the tokenizer only to obtain the vocabulary size the decoder
        # is constructed with (GPT-2 tokens plus the two added special tokens).
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder).eval()

        # Accept both a wrapped checkpoint ('model_state' key) and a bare
        # state dict, matching what quantize_model accepts.
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)

        # Trace the encoder only; the decoder has dynamic inputs.
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(model.encoder, dummy_image)
        torch.jit.save(encoder_traced, encoder_path)
        print(" ✓ Encoder traced successfully")

    # Report and return the file that was really written.
    print(f"✓ TorchScript saved to {encoder_path}")
    return encoder_path
def main():
    """
    Command-line entry point: optimize the selected checkpoints.

    Parses the CLI options, creates the output directory, collects the
    requested checkpoints that exist on disk (warning about missing ones),
    and runs the selected optimization method(s) on each. Returns None; when
    no checkpoint is found it prints an error and returns early.
    """
    cli = argparse.ArgumentParser(description='Optimize models for production deployment')
    cli.add_argument('--model', type=str, choices=['resnet', 'efficientnet', 'both'],
                     default='both', help='Model to optimize')
    cli.add_argument('--method', type=str, choices=['quantize', 'optimize', 'torchscript', 'all'],
                     default='all', help='Optimization method')
    cli.add_argument('--resnet-path', type=str, default='resnet_best_model.pth',
                     help='Path to ResNet checkpoint')
    cli.add_argument('--efficientnet-path', type=str, default='efficient_best_model.pth',
                     help='Path to EfficientNet checkpoint')
    cli.add_argument('--output-dir', type=str, default='optimized_models',
                     help='Output directory for optimized models')
    args = cli.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # (model type, checkpoint path, display name), in processing order.
    candidates = (
        ('resnet', args.resnet_path, 'ResNet'),
        ('efficientnet', args.efficientnet_path, 'EfficientNet'),
    )
    models_to_process = []
    for kind, ckpt, label in candidates:
        if args.model not in (kind, 'both'):
            continue
        if os.path.exists(ckpt):
            models_to_process.append((kind, ckpt))
        else:
            print(f"⚠ Warning: {ckpt} not found, skipping {label}")

    if not models_to_process:
        print("❌ No models found to optimize!")
        return

    banner = '=' * 60
    run_all = args.method == 'all'

    for model_type, model_path in models_to_process:
        print(f"\n{banner}")
        print(f"Processing {model_type.upper()} model")
        print(banner)

        # Output files are named "<type>_<checkpoint stem>_<method>.<ext>".
        stem = Path(model_path).stem
        output_base = os.path.join(args.output_dir, f"{model_type}_{stem}")

        if run_all or args.method == 'quantize':
            quantize_model(model_path, f"{output_base}_quantized.pth", model_type)
        if run_all or args.method == 'optimize':
            optimize_state_dict(model_path, f"{output_base}_optimized.pth")
        if run_all or args.method == 'torchscript':
            create_torchscript(model_path, f"{output_base}_torchscript.pt", model_type)

    print(f"\n{banner}")
    print("✓ Optimization complete!")
    print(f"Optimized models saved to: {args.output_dir}")
    print(banner)


if __name__ == '__main__':
    main()