"""
Model Optimization Script for Production Deployment
Reduces model size and improves inference speed through:
1. Quantization (INT8)
2. TorchScript compilation
3. Model pruning (optional; not implemented in this script)
4. State dict optimization
"""

import torch
import os
import argparse
from pathlib import Path

# Import model classes BEFORE loading checkpoints (needed for unpickling)
# This ensures PyTorch can find the class definitions when loading saved objects
# Note: resnet_train.py has module-level code that loads COCO data, which may fail
# if training files aren't present. We'll handle this in the functions.

def quantize_model(checkpoint_path, output_path, model_type='resnet'):
    """
    Quantize model to INT8 for 4x size reduction and faster inference.
    Note: Slight accuracy loss (usually <1%)
    """
    print(f"Quantizing {model_type} model...")
    
    device = torch.device('cpu')  # Quantization typically done on CPU
    
    # Import classes before loading (required for unpickling)
    # resnet_train.py now handles missing training data gracefully
    if model_type == 'resnet':
        # Import the module itself so we can update vocab later
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary
        
        # Make Vocabulary available in __main__ for unpickling
        # This handles cases where checkpoint was saved with Vocabulary from __main__
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer
    
    # Load checkpoint (now all classes are available for unpickling)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    if model_type == 'resnet':
        # For ResNet, quantize encoder and decoder separately
        
        # IMPORTANT: Update vocab from checkpoint before creating DecoderRNN
        # The decoder uses len(vocab.word2idx) in its __init__, so we need the full vocab
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            # Update the vocab in resnet_train module (DecoderRNN.__init__ references resnet_train.vocab)
            resnet_train.vocab = checkpoint['vocab']
            print(f"  Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")
        
        encoder = EncoderCNN()
        decoder = DecoderRNN()  # Now uses the correct vocab size from checkpoint
        
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        
        # Set to eval mode
        encoder.eval()
        decoder.eval()
        
        # Quantize encoder. Dynamic quantization rewrites supported modules in
        # place rather than tracing the graph, so no example input is needed.
        # Note: quantize_dynamic only converts Linear (and recurrent) layers;
        # Conv2d is not supported dynamically, so the conv backbone stays FP32.
        encoder_quantized = torch.quantization.quantize_dynamic(
            encoder, {torch.nn.Linear}, dtype=torch.qint8
        )
        
        # Quantize decoder. LSTM/GRU are listed alongside Linear so a recurrent
        # decoder is also quantized (a no-op if those modules are absent).
        # nn.Embedding needs a special float_qparams qconfig and is left FP32.
        decoder_quantized = torch.quantization.quantize_dynamic(
            decoder, {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU}, dtype=torch.qint8
        )
        
        # Save quantized model. NOTE: quantized modules store packed INT8
        # parameters, so reloading requires re-applying the same
        # quantize_dynamic() calls to freshly built float models first
        # (see load_quantized_resnet below).
        quantized_checkpoint = {
            'encoder': encoder_quantized.state_dict(),
            'decoder': decoder_quantized.state_dict(),
            'vocab': checkpoint.get('vocab'),
            'quantized': True
        }
        
    elif model_type == 'efficientnet':
        # Classes already imported above before loading checkpoint
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)
        
        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder)
        
        # Load state dict - handle both 'model_state' key and direct state dict
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)
        
        model.eval()
        
        # Quantize the full model. As above, only Linear layers are converted
        # dynamically; the EfficientNet conv backbone remains FP32.
        model_quantized = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        
        quantized_checkpoint = {
            'model_state': model_quantized.state_dict(),
            'quantized': True
        }
    
    torch.save(quantized_checkpoint, output_path)
    
    # Compare sizes
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)  # MB
    quantized_size = os.path.getsize(output_path) / (1024 * 1024)  # MB
    reduction = (1 - quantized_size / original_size) * 100
    
    print(f"✓ Quantization complete!")
    print(f"  Original size: {original_size:.2f} MB")
    print(f"  Quantized size: {quantized_size:.2f} MB")
    print(f"  Size reduction: {reduction:.1f}%")
    
    return output_path
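

# The quantized state dicts saved above cannot be loaded into plain FP32
# modules: quantized layers store packed INT8 parameters under different keys.
# Below is a minimal sketch of the reload path (an addition, not exercised by
# main()); it mirrors the resnet branch of quantize_model() and assumes the
# same resnet_train interfaces used there.
def load_quantized_resnet(quantized_path):
    """Rebuild and reload a quantized ResNet captioning checkpoint (sketch)."""
    import resnet_train
    from resnet_train import EncoderCNN, DecoderRNN, Vocabulary
    import __main__
    if not hasattr(__main__, 'Vocabulary'):
        __main__.Vocabulary = Vocabulary  # needed to unpickle the vocab

    checkpoint = torch.load(quantized_path, map_location='cpu', weights_only=False)
    resnet_train.vocab = checkpoint['vocab']  # DecoderRNN reads the module-level vocab

    # Re-apply the exact quantize_dynamic() configuration used when saving;
    # only then do the quantized state dicts load cleanly.
    encoder = torch.quantization.quantize_dynamic(
        EncoderCNN().eval(), {torch.nn.Linear}, dtype=torch.qint8
    )
    decoder = torch.quantization.quantize_dynamic(
        DecoderRNN().eval(), {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
        dtype=torch.qint8
    )
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    return encoder, decoder, checkpoint['vocab']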


def optimize_state_dict(checkpoint_path, output_path):
    """
    Remove unnecessary metadata and optimize state dict for smaller size.
    """
    print(f"Optimizing state dict...")
    
    # Import classes before loading (required for unpickling)
    try:
        from resnet_train import Vocabulary
        # Make Vocabulary available in __main__ for unpickling
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    except ImportError:
        pass
    
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    
    # Create optimized checkpoint with only essential data
    optimized = {}
    for key, value in checkpoint.items():
        if key not in ['optimizer', 'scheduler', 'epoch', 'loss', 'metrics']:
            optimized[key] = value
    
    # Save in the default zipfile format; the size reduction comes from
    # dropping the optimizer/scheduler state above, not from extra compression
    torch.save(optimized, output_path)
    
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
    optimized_size = os.path.getsize(output_path) / (1024 * 1024)
    reduction = (1 - optimized_size / original_size) * 100
    
    print(f"✓ State dict optimized!")
    print(f"  Original: {original_size:.2f} MB")
    print(f"  Optimized: {optimized_size:.2f} MB")
    print(f"  Reduction: {reduction:.1f}%")
    
    return output_path
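

# Diagnostic sketch (an addition, not called by main()): report what a
# checkpoint actually stores, to preview what optimize_state_dict() will strip.
# As in optimize_state_dict(), import Vocabulary first if the checkpoint
# pickles one.
def report_checkpoint_keys(checkpoint_path):
    """Print each top-level checkpoint key with a rough tensor-size estimate."""
    def tensor_mb(obj):
        # Recursively sum tensor storage (in MB) across nested dicts/lists
        if torch.is_tensor(obj):
            return obj.numel() * obj.element_size() / (1024 * 1024)
        if isinstance(obj, dict):
            return sum(tensor_mb(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(tensor_mb(v) for v in obj)
        return 0.0

    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    for key, value in checkpoint.items():
        print(f"  {key}: {type(value).__name__}, ~{tensor_mb(value):.2f} MB of tensors")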


def create_torchscript(checkpoint_path, output_path, model_type='resnet'):
    """
    Convert model to TorchScript for faster loading and inference.
    Note: Requires example input for tracing.
    """
    print(f"Creating TorchScript model...")
    
    device = torch.device('cpu')
    
    # Import classes before loading (required for unpickling)
    if model_type == 'resnet':
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary
        
        # Make Vocabulary available in __main__ for unpickling
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer
    
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    if model_type == 'resnet':
        # Update vocab from checkpoint before creating DecoderRNN
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            resnet_train.vocab = checkpoint['vocab']
            print(f"  Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")
        
        encoder = EncoderCNN().eval()
        decoder = DecoderRNN().eval()  # Now uses the correct vocab size
        
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        
        # Trace encoder with a fixed-size example input
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(encoder, dummy_image)
        
        # The decoder's autoregressive loop has data-dependent control flow,
        # which tracing cannot capture; torch.jit.script would be needed
        print("  ⚠ TorchScript for RNN decoder may require manual scripting")
        print("  ✓ Encoder traced successfully")
        
        # Derive the encoder filename robustly: main() passes a .pt path, so
        # the old replace('.pth', ...) was a silent no-op
        encoder_path = str(Path(output_path).with_suffix('')) + '_encoder.pt'
        torch.jit.save(encoder_traced, encoder_path)
        
    elif model_type == 'efficientnet':
        # Classes already imported above
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)
        
        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder).eval()
        
        # Handle both a wrapped checkpoint and a bare state dict, as in
        # quantize_model() above
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)
        
        # Trace encoder only (the decoder takes variable-length token inputs)
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(model.encoder, dummy_image)
        
        encoder_path = str(Path(output_path).with_suffix('')) + '_encoder.pt'
        torch.jit.save(encoder_traced, encoder_path)
        print("  ✓ Encoder traced successfully")
    
    print(f"✓ TorchScript saved to {output_path}")
    return output_path
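

# Hedged usage sketch for the traced encoder saved above. The path is
# illustrative, built from the default arguments in main() and the
# "<output stem>_encoder.pt" convention used by create_torchscript():
#
#   encoder = torch.jit.load('optimized_models/resnet_resnet_best_model_torchscript_encoder.pt')
#   with torch.no_grad():
#       features = encoder(torch.randn(1, 3, 224, 224))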


def main():
    parser = argparse.ArgumentParser(description='Optimize models for production deployment')
    parser.add_argument('--model', type=str, choices=['resnet', 'efficientnet', 'both'], 
                       default='both', help='Model to optimize')
    parser.add_argument('--method', type=str, choices=['quantize', 'optimize', 'torchscript', 'all'],
                       default='all', help='Optimization method')
    parser.add_argument('--resnet-path', type=str, default='resnet_best_model.pth',
                       help='Path to ResNet checkpoint')
    parser.add_argument('--efficientnet-path', type=str, default='efficient_best_model.pth',
                       help='Path to EfficientNet checkpoint')
    parser.add_argument('--output-dir', type=str, default='optimized_models',
                       help='Output directory for optimized models')
    
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    models_to_process = []
    if args.model in ['resnet', 'both']:
        if os.path.exists(args.resnet_path):
            models_to_process.append(('resnet', args.resnet_path))
        else:
            print(f"⚠ Warning: {args.resnet_path} not found, skipping ResNet")
    
    if args.model in ['efficientnet', 'both']:
        if os.path.exists(args.efficientnet_path):
            models_to_process.append(('efficientnet', args.efficientnet_path))
        else:
            print(f"⚠ Warning: {args.efficientnet_path} not found, skipping EfficientNet")
    
    if not models_to_process:
        print("❌ No models found to optimize!")
        return
    
    for model_type, model_path in models_to_process:
        print(f"\n{'='*60}")
        print(f"Processing {model_type.upper()} model")
        print(f"{'='*60}")
        
        base_name = Path(model_path).stem
        output_base = os.path.join(args.output_dir, f"{model_type}_{base_name}")
        
        if args.method in ['quantize', 'all']:
            quantized_path = f"{output_base}_quantized.pth"
            quantize_model(model_path, quantized_path, model_type)
        
        if args.method in ['optimize', 'all']:
            optimized_path = f"{output_base}_optimized.pth"
            optimize_state_dict(model_path, optimized_path)
        
        if args.method in ['torchscript', 'all']:
            torchscript_path = f"{output_base}_torchscript.pt"
            create_torchscript(model_path, torchscript_path, model_type)
    
    print(f"\n{'='*60}")
    print("✓ Optimization complete!")
    print(f"Optimized models saved to: {args.output_dir}")
    print(f"{'='*60}")


if __name__ == '__main__':
    main()