"""
Advanced model quantization and optimization for EMOTIA.

Supports INT8 quantization (static, dynamic, and quantization-aware training),
structured and unstructured pruning, ONNX export, FP16 TensorRT compilation,
and edge-deployment optimizations.
"""

import copy
import json
import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization as quant
from torch.quantization import DeQuantStub, QuantStub
from torch.utils.data import DataLoader

# Module-wide logger used by every class and helper below.
logger = logging.getLogger(__name__)
# Configure the root logger at import time (INFO level; a default handler is
# attached only if the root logger has none yet).
logging.basicConfig(level=logging.INFO)


class AdvancedQuantizer:
    """Eager-mode quantization helpers (QAT, static PTQ, dynamic) for EMOTIA models.

    The QAT and static paths modify ``self.model`` in place (fusion, stub
    insertion, observer / fake-quant insertion). ``convert_to_quantized`` and
    ``quantize_dynamic`` return a new quantized module, which is also stored in
    ``self.quantized_model``.

    Recognized (optional) config keys:
        backend: qconfig backend name, default ``'fbgemm'``.
        fusion_patterns: list of submodule-name lists to fuse, default
            ``DEFAULT_FUSION_PATTERNS``.
    """

    # Conv/BN submodule pairs fused when the config does not override them.
    DEFAULT_FUSION_PATTERNS = (
        ['conv1', 'bn1'],
        ['conv2', 'bn2'],
        ['conv3', 'bn3'],
    )

    def __init__(self, model: nn.Module, config: Dict):
        self.model = model
        self.config = config
        self.quantized_model = None
        # NOTE(review): not read by any method of this class.
        self.calibration_data = []

    def _backend(self) -> str:
        """Return the configured quantization backend (default ``'fbgemm'``)."""
        return (self.config or {}).get('backend', 'fbgemm')

    def prepare_for_quantization(self) -> nn.Module:
        """Prepare the model for quantization-aware training (QAT).

        Fuses known submodule patterns, attaches Quant/DeQuant stubs, assigns
        the QAT qconfig and inserts fake-quantization modules in place.

        Returns:
            The prepared model (``self.model``), left in training mode so it
            can be fine-tuned before ``convert_to_quantized``.
        """
        self.model = self._fuse_modules()
        self.model = self._insert_quant_stubs()
        self.model.qconfig = quant.get_default_qat_qconfig(self._backend())

        # prepare_qat asserts the model is in training mode; a model loaded for
        # inference (after .eval()) would otherwise fail here.
        self.model.train()
        quant.prepare_qat(self.model, inplace=True)

        logger.info("Model prepared for quantization-aware training")
        return self.model

    def _fuse_modules(self) -> nn.Module:
        """Fuse submodule patterns in place for better quantization.

        Fusion is best-effort: a pattern that cannot be fused (e.g. missing
        submodule, or Conv/BN fusion attempted outside eval mode) is logged
        as a warning and skipped.

        Returns:
            ``self.model`` (modified in place).
        """
        patterns = (self.config or {}).get('fusion_patterns', self.DEFAULT_FUSION_PATTERNS)
        for pattern in patterns:
            pattern = list(pattern)
            try:
                quant.fuse_modules(self.model, pattern, inplace=True)
            except Exception as e:  # deliberate best-effort; logged below
                logger.warning(f"Could not fuse {pattern}: {e}")
            else:
                logger.info(f"Fused modules: {pattern}")
        return self.model

    def _insert_quant_stubs(self) -> nn.Module:
        """Attach ``quant`` / ``dequant`` stub submodules to the model.

        NOTE(review): the stubs only have an effect if the model's forward()
        routes its inputs/outputs through ``self.quant`` / ``self.dequant`` —
        confirm this for the wrapped model.
        """
        self.model.quant = QuantStub()
        self.model.dequant = DeQuantStub()
        return self.model

    def calibrate(self, calibration_loader: DataLoader, num_batches: int = 100):
        """Run forward passes so observers can record activation ranges.

        Args:
            calibration_loader: Iterable of batches. A batch is either the
                model input itself or a tuple/list whose first element is the
                input (e.g. ``(inputs, labels)``).
            num_batches: Maximum number of batches to run.
        """
        logger.info("Starting quantization calibration...")

        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(calibration_loader):
                if i >= num_batches:
                    break
                inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
                self.model(inputs)

                if i % 20 == 0:
                    logger.info(f"Calibration progress: {i}/{num_batches}")

        logger.info("Calibration completed")

    def convert_to_quantized(self) -> nn.Module:
        """Convert the prepared (and calibrated/trained) model to a quantized copy.

        ``self.model`` is switched to eval mode and converted with
        ``inplace=False``, so it keeps its observers.

        Returns:
            The quantized module, also stored in ``self.quantized_model``.
        """
        logger.info("Converting to quantized model...")
        self.quantized_model = quant.convert(self.model.eval(), inplace=False)
        logger.info("Model quantized successfully")
        return self.quantized_model

    def quantize_static(self, calibration_loader: DataLoader) -> nn.Module:
        """Post-training static quantization: prepare, calibrate, convert.

        NOTE(review): unlike ``prepare_for_quantization`` this does not fuse
        modules or insert Quant/DeQuant stubs.

        Args:
            calibration_loader: Batches passed to ``calibrate``.

        Returns:
            The quantized module.
        """
        self.model.qconfig = quant.get_default_qconfig(self._backend())
        quant.prepare(self.model, inplace=True)
        self.calibrate(calibration_loader)
        return self.convert_to_quantized()

    def quantize_dynamic(self) -> nn.Module:
        """Dynamically quantize Linear/LSTM/GRU weights to int8.

        Works on a copy (``inplace=False``); ``self.model`` is left unchanged.

        Returns:
            The quantized module, also stored in ``self.quantized_model``.
        """
        logger.info("Performing dynamic quantization...")
        self.quantized_model = quant.quantize_dynamic(
            self.model,
            {nn.Linear, nn.LSTM, nn.GRU},
            dtype=torch.qint8,
            inplace=False,
        )
        logger.info("Dynamic quantization completed")
        return self.quantized_model


class AdvancedPruner:
    """Magnitude-based pruning helpers built on ``torch.nn.utils.prune``.

    All methods modify ``self.model`` in place and return it.
    """

    def __init__(self, model: nn.Module, config: Dict):
        self.model = model
        self.config = config
        # NOTE(review): never assigned by any method of this class.
        self.pruned_model = None

    def apply_structured_pruning(self, amount: float = 0.3):
        """Prune whole output channels of every ``nn.Conv2d`` ranked by L2 norm.

        Args:
            amount: Fraction (float in [0, 1]) or absolute count (int) of
                output channels to prune in each Conv2d.

        Returns:
            The model, with ``weight_orig`` / ``weight_mask`` pruning
            re-parametrization attached to each Conv2d.
        """
        logger.info(f"Applying structured pruning with amount: {amount}")

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # dim=0 -> output channels; n=2 -> rank channels by L2 norm.
                prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)
                logger.info(f"Pruned Conv2d layer: {name}")

        return self.model

    def apply_unstructured_pruning(self, amount: float = 0.2):
        """Zero the smallest-|w| (L1) weights of every Conv2d and Linear layer.

        Args:
            amount: Fraction (float in [0, 1]) or absolute count (int) of
                weights to prune in each layer.

        Returns:
            The model with pruning re-parametrization attached.
        """
        logger.info(f"Applying unstructured pruning with amount: {amount}")

        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                prune.l1_unstructured(module, name='weight', amount=amount)
                logger.info(f"Pruned layer: {name}")

        return self.model

    def remove_pruning_masks(self):
        """Make pruning permanent by folding each mask into its ``weight``.

        Only Conv2d/Linear modules whose ``weight`` was actually pruned are
        touched. Structured pruning prunes Conv2d layers only, and
        ``prune.remove`` raises ``ValueError`` on an unpruned Linear.

        Returns:
            The model with plain ``weight`` parameters again.
        """
        logger.info("Removing pruning masks...")

        for module in self.model.modules():
            # 'weight_orig' is the parameter torch.nn.utils.prune creates when
            # it re-parametrizes 'weight'.
            if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'weight_orig'):
                prune.remove(module, 'weight')

        logger.info("Pruning masks removed")
        return self.model


class ModelOptimizer:
    """Config-driven optimization pipeline: prune, quantize, export.

    The JSON config may contain a ``"model"`` section (passed to the model
    constructor) and optional stage sections ``"pruning"``,
    ``"quantization"``, ``"onnx"`` and ``"tensorrt"``. Each stage runs only
    when its section has ``"enabled": true``.
    """

    def __init__(self, model_path: str, config_path: str):
        self.model_path = Path(model_path)
        self.config = self._load_config(config_path)
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _load_config(self, config_path: str) -> Dict:
        """Read and return the JSON optimization config."""
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_model(self):
        """Build ``AdvancedFusionModel`` from config and load checkpoint weights.

        The checkpoint must be a dict with a ``'model_state_dict'`` entry. The
        model is moved to ``self.device`` and put in eval mode.

        Returns:
            The loaded model (also stored in ``self.model``).
        """
        logger.info(f"Loading model from {self.model_path}")

        # Project-local import, deferred until a model is actually loaded.
        from models.advanced.advanced_fusion import AdvancedFusionModel

        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model = AdvancedFusionModel(self.config['model'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        logger.info("Model loaded successfully")
        return self.model

    def optimize_pipeline(self, output_dir: str = 'optimized_models',
                          calibration_loader: Optional[DataLoader] = None):
        """Run every enabled stage: pruning -> quantization -> ONNX -> TensorRT.

        Args:
            output_dir: Directory for artifacts; created (with parents) if missing.
            calibration_loader: Batches used to calibrate static quantization.
                Required when ``quantization.type == "static"``; otherwise ignored.

        Raises:
            RuntimeError: If no model is loaded (call ``load_model`` first).
        """
        if self.model is None:
            raise RuntimeError("No model loaded; call load_model() before optimize_pipeline()")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if self.config.get('pruning', {}).get('enabled', False):
            self._run_pruning(output_dir)

        if self.config.get('quantization', {}).get('enabled', False):
            self._run_quantization(output_dir, calibration_loader)

        if self.config.get('onnx', {}).get('enabled', False):
            self._export_onnx(output_dir / 'model.onnx')

        if self.config.get('tensorrt', {}).get('enabled', False):
            self._optimize_tensorrt(output_dir)

        logger.info("Optimization pipeline completed")

    def _run_pruning(self, output_dir: Path):
        """Prune ``self.model`` per ``config["pruning"]`` and save the result."""
        prune_cfg = self.config['pruning']
        pruner = AdvancedPruner(self.model, prune_cfg)
        amount = prune_cfg['amount']
        # Anything other than "structured" (including a missing type) is unstructured.
        if prune_cfg.get('type') == 'structured':
            self.model = pruner.apply_structured_pruning(amount)
        else:
            self.model = pruner.apply_unstructured_pruning(amount)
        pruner.remove_pruning_masks()

        self._save_model(self.model, output_dir / 'pruned_model.pth')

    def _run_quantization(self, output_dir: Path, calibration_loader):
        """Quantize ``self.model`` per ``config["quantization"]`` and save the result.

        Unknown types, and static quantization without calibration data, are
        skipped with a warning instead of saving an unquantized model.
        """
        quant_cfg = self.config['quantization']
        qtype = quant_cfg.get('type')
        if qtype not in ('static', 'dynamic', 'qat'):
            logger.warning(f"Unknown quantization type {qtype!r}; skipping quantization")
            return
        if qtype == 'static' and calibration_loader is None:
            logger.warning("Static quantization needs a calibration_loader; skipping quantization")
            return

        # PyTorch eager-mode quantized kernels (fbgemm/qnnpack) run on CPU only.
        self.model = self.model.cpu()
        self.device = torch.device('cpu')
        quantizer = AdvancedQuantizer(self.model, quant_cfg)

        if qtype == 'static':
            self.model = quantizer.quantize_static(calibration_loader)
        elif qtype == 'dynamic':
            self.model = quantizer.quantize_dynamic()
        else:  # 'qat'
            self.model = quantizer.prepare_for_quantization()
            # NOTE(review): no QAT fine-tuning happens between prepare and
            # convert here, so fake-quant observers have seen no data.
            logger.warning("QAT: no fine-tuning step in this pipeline; converting immediately")
            self.model = quantizer.convert_to_quantized()

        self._save_model(self.model, output_dir / 'quantized_model.pth')

    def _model_device(self) -> torch.device:
        """Device of the model's first parameter (or buffer); CPU if it has neither."""
        ref = next(self.model.parameters(), None)
        if ref is None:
            ref = next(self.model.buffers(), None)
        return ref.device if ref is not None else torch.device('cpu')

    def _save_model(self, model: nn.Module, path: Path):
        """Save ``model``'s state dict together with the config and run metadata."""
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': self.config,
            'optimization_info': {
                'timestamp': time.time(),
                'device': str(self.device),
                'torch_version': torch.__version__,
            },
        }, path)
        logger.info(f"Model saved to {path}")

    def _export_onnx(self, output_path: Path):
        """Export ``self.model`` to ONNX with a dynamic batch dimension.

        Optional ``config["onnx"]`` keys: ``input_shape`` (default
        ``[1, 3, 224, 224]``) and ``opset_version`` (default 11).
        """
        logger.info("Exporting to ONNX...")

        onnx_cfg = self.config.get('onnx', {})
        input_shape = tuple(onnx_cfg.get('input_shape', (1, 3, 224, 224)))
        # Build the dummy input where the model actually lives rather than
        # assuming self.device.
        dummy_input = torch.randn(input_shape, device=self._model_device())

        torch.onnx.export(
            self.model,
            dummy_input,
            str(output_path),
            export_params=True,
            opset_version=onnx_cfg.get('opset_version', 11),
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        )

        logger.info(f"ONNX model exported to {output_path}")

    def _optimize_tensorrt(self, output_dir: Path):
        """Compile ``self.model`` with Torch-TensorRT (FP16) and save it.

        Skipped with a warning when the model is not on CUDA or when
        ``torch_tensorrt`` is not installed. Optional
        ``config["tensorrt"]["input_shape"]`` defaults to ``[1, 3, 224, 224]``.
        """
        logger.info("Optimizing for TensorRT...")

        if self._model_device().type != 'cuda':
            logger.warning("TensorRT requires a CUDA model, skipping optimization")
            return
        try:
            import torch_tensorrt
        except ImportError:
            logger.warning("TensorRT not available, skipping optimization")
            return

        input_shape = tuple(self.config.get('tensorrt', {}).get('input_shape', (1, 3, 224, 224)))
        trt_model = torch_tensorrt.compile(
            self.model,
            # TorchScript frontend so the result is a ScriptModule for torch.jit.save.
            ir="torchscript",
            inputs=[torch_tensorrt.Input(input_shape)],
            enabled_precisions={torch_tensorrt.dtype.f16},
        )
        torch.jit.save(trt_model, str(output_dir / 'tensorrt_model.pth'))

        logger.info("TensorRT optimization completed")


class EdgeDeploymentOptimizer:
    """Platform-specific optimizations for edge deployment targets."""

    def __init__(self, model: nn.Module, target_platform: str):
        self.model = model
        # NOTE(review): stored but not read by any method of this class.
        self.target_platform = target_platform

    def optimize_for_mobile(self, calibration_loader=None, num_batches: int = 100):
        """Statically quantize the model in place with the ``qnnpack`` qconfig.

        Args:
            calibration_loader: Optional batches (in the format accepted by
                ``AdvancedQuantizer.calibrate``) run between prepare and
                convert. Without it the observers see no data and the
                resulting quantization parameters are not meaningful, so a
                warning is logged.
            num_batches: Maximum number of calibration batches.

        Returns:
            The converted model (also stored in ``self.model``).
        """
        logger.info("Optimizing for mobile deployment...")

        self.model.eval()
        self.model.qconfig = quant.get_default_qconfig('qnnpack')
        quant.prepare(self.model, inplace=True)

        if calibration_loader is None:
            logger.warning("No calibration data given; quantization parameters are uncalibrated")
        else:
            AdvancedQuantizer(self.model, {}).calibrate(calibration_loader, num_batches)

        self.model = quant.convert(self.model, inplace=True)
        return self.model

    def optimize_for_web(self):
        """Placeholder for web deployment (ONNX.js, WebGL) optimization.

        Currently applies no transformation and returns ``self.model`` unchanged.
        """
        logger.info("Optimizing for web deployment...")
        # TODO: no web-specific optimization is implemented yet.
        return self.model

    def optimize_for_embedded(self):
        """Prune (50% unstructured) and then dynamically quantize for embedded targets.

        Pruning runs first. AdvancedPruner only prunes ``nn.Conv2d`` /
        ``nn.Linear`` instances, and dynamic quantization replaces Linear
        layers with quantized modules that are not ``nn.Linear``. Quantizing
        first therefore left every Linear layer unpruned. The work is done on a
        CPU deep copy (dynamic quantization is CPU-only), so the model
        originally passed in is not modified.

        Returns:
            The pruned, dynamically quantized model (also stored in ``self.model``).
        """
        logger.info("Optimizing for embedded deployment...")

        model = copy.deepcopy(self.model).cpu()

        pruner = AdvancedPruner(model, {'type': 'unstructured', 'amount': 0.5})
        model = pruner.apply_unstructured_pruning(0.5)
        pruner.remove_pruning_masks()

        quantizer = AdvancedQuantizer(model, {'type': 'dynamic'})
        self.model = quantizer.quantize_dynamic()

        return self.model


def benchmark_model(model: nn.Module, input_shape: Tuple, num_runs: int = 100,
                    num_warmup: int = 10):
    """Measure forward-pass latency of ``model`` on a random input tensor.

    Args:
        model: Module to benchmark; it is switched to eval mode.
        input_shape: Shape of the random input, e.g. ``(1, 3, 224, 224)``.
        num_runs: Number of timed forward passes (must be >= 1).
        num_warmup: Untimed forward passes run before timing.

    Returns:
        Dict with ``avg_inference_time`` / ``std_inference_time`` (seconds),
        ``fps`` (1 / average time) and ``model_size_mb``.

    Raises:
        ValueError: If ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    logger.info("Benchmarking model performance...")

    model.eval()
    # Quantized modules may expose no parameters: fall back to buffers, then CPU.
    ref = next(model.parameters(), None)
    if ref is None:
        ref = next(model.buffers(), None)
    device = ref.device if ref is not None else torch.device('cpu')

    def _sync():
        # CUDA launches are asynchronous; wait so timings cover the real work.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    dummy_input = torch.randn(input_shape).to(device)
    times = []
    with torch.no_grad():
        for _ in range(num_warmup):
            model(dummy_input)
        _sync()  # keep queued warm-up kernels out of the first timing

        for _ in range(num_runs):
            start_time = time.perf_counter()
            model(dummy_input)
            _sync()
            times.append(time.perf_counter() - start_time)

    avg_time = float(np.mean(times))
    std_time = float(np.std(times))
    fps = 1.0 / avg_time if avg_time > 0 else float('inf')

    logger.info(f"Average inference time: {avg_time:.4f}s")
    logger.info(f"Std inference time: {std_time:.4f}s")
    logger.info(f"Throughput: {fps:.2f} FPS")

    return {
        'avg_inference_time': avg_time,
        'std_inference_time': std_time,
        'fps': fps,
        'model_size_mb': calculate_model_size(model),
    }


def calculate_model_size(model: nn.Module) -> float:
    """Return the memory footprint of ``model``'s parameters and buffers in MiB.

    Each tensor contributes ``nelement() * element_size()`` bytes, summed over
    ``model.parameters()`` and ``model.buffers()``, then divided by 1024 twice.

    NOTE(review): tensors that are neither parameters nor buffers (e.g. packed
    weights held by eager-mode quantized modules) are not counted, so sizes of
    quantized models are likely under-reported — confirm before comparing.
    """
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.nelement() * b.element_size() for b in model.buffers())
    return total_bytes / 1024 / 1024


def main():
    """Command-line entry point: load a model, run the pipeline, optionally benchmark."""
    import argparse

    parser = argparse.ArgumentParser(description='EMOTIA Model Optimization')
    parser.add_argument('--model_path', required=True, help='Path to trained model')
    parser.add_argument('--config_path', required=True, help='Path to optimization config')
    parser.add_argument('--output_dir', default='optimized_models', help='Output directory')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmarking')
    parser.add_argument('--input_shape', type=int, nargs='+', default=[1, 3, 224, 224],
                        help='Input tensor shape for benchmarking (default: 1 3 224 224)')

    args = parser.parse_args()

    optimizer = ModelOptimizer(args.model_path, args.config_path)
    optimizer.load_model()

    # Also creates args.output_dir, which the benchmark results are written into.
    optimizer.optimize_pipeline(args.output_dir)

    if args.benchmark:
        results = benchmark_model(optimizer.model, tuple(args.input_shape))
        results_path = Path(args.output_dir) / 'benchmark_results.json'
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        logger.info("Benchmarking completed")


if __name__ == '__main__':
    main()