"""
Model quantization utilities for faster inference and lower memory usage.

Supports FP16, INT8, and dynamic quantization.
"""

import copy
import inspect
import io
import itertools
import logging
import time
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def quantize_fp16(model: nn.Module) -> nn.Module:
    """
    Convert model to FP16 (half precision).

    ``nn.Module.half()`` converts parameters and buffers in place, so the
    caller's model is modified and the same object is returned.

    Args:
        model: Model to quantize

    Returns:
        FP16 quantized model
    """
    model = model.half()
    logger.info("Model quantized to FP16")
    return model


def quantize_dynamic_int8(
    model: nn.Module,
    quantizable_modules: Optional[list] = None,
) -> nn.Module:
    """
    Apply dynamic INT8 quantization to model.

    The input model is not modified on success (``quantize_dynamic`` works on
    a deep copy). If quantization fails, the function falls back to
    :func:`quantize_fp16`, which DOES convert ``model`` in place.

    Args:
        model: Model to quantize
        quantizable_modules: Module types to quantize (default: Linear, Conv2d).
            A list, set, or ``{module_type: qconfig}`` dict is accepted.
            Note: PyTorch's default dynamic-quantization mapping does not
            include ``nn.Conv2d``, so Conv2d layers stay in float precision.

    Returns:
        INT8 quantized model (or the FP16 fallback on failure)
    """
    if quantizable_modules is None:
        quantizable_modules = [torch.nn.Linear, torch.nn.Conv2d]

    # quantize_dynamic only understands a set of module types or a
    # {type: qconfig} dict. A plain list fails inside PyTorch (it calls
    # ``.get`` on the spec), which previously forced the FP16 fallback every time.
    if isinstance(quantizable_modules, (set, dict)):
        qconfig_spec = quantizable_modules
    else:
        qconfig_spec = set(quantizable_modules)

    try:
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            qconfig_spec,
            dtype=torch.qint8,
        )
    except Exception as e:
        # Best-effort by design: keep the caller running with a cheaper model.
        logger.error(f"INT8 quantization failed: {e}")
        logger.warning("Falling back to FP16 quantization")
        return quantize_fp16(model)

    logger.info(f"Model quantized to INT8 (modules: {quantizable_modules})")
    return quantized_model


def quantize_static_int8(
    model: nn.Module,
    calibration_data,
    quantizable_modules: Optional[list] = None,
    num_calibration_batches: int = 100,
    backend: str = "fbgemm",
) -> nn.Module:
    """
    Apply static INT8 quantization with calibration data.

    The caller's model is left untouched: observers are inserted into a deep
    copy, which is calibrated and then converted.

    Eager-mode static quantization generally expects the model to wrap its
    inputs/outputs with ``QuantStub``/``DeQuantStub``; this function does not
    insert them.

    Args:
        model: Model to quantize
        calibration_data: DataLoader, list, or other iterable of inputs for
            calibration. Items that are lists/tuples use their first element
            as the model input (e.g. ``(inputs, targets)`` batches).
        quantizable_modules: Currently unused; eager-mode static quantization
            applies ``model.qconfig`` to every supported submodule. Kept for
            API symmetry with :func:`quantize_dynamic_int8`.
        num_calibration_batches: Maximum number of calibration batches to run
            (must be >= 0).
        backend: Backend name passed to ``get_default_qconfig``.

    Returns:
        INT8 quantized model
    """
    # Work on a copy so the caller's model is not left in eval mode with a
    # qconfig and observer modules attached.
    prepared = copy.deepcopy(model)
    prepared.eval()

    # Prepare model for quantization
    prepared.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(prepared, inplace=True)

    # Calibrate with data
    logger.info("Calibrating model for static quantization...")
    num_batches = 0
    with torch.no_grad():
        # iter() covers both regular iterables and legacy __getitem__-only
        # sequences; islice caps the number of batches exactly (the old loop
        # ran one extra batch).
        for data in itertools.islice(iter(calibration_data), num_calibration_batches):
            inputs = data[0] if isinstance(data, (list, tuple)) else data
            prepared(inputs)
            num_batches += 1

    if num_batches == 0:
        logger.warning(
            "No calibration batches were run; observers have not seen any data, "
            "so quantization parameters will not reflect real activations"
        )

    # Convert to quantized (in place on our private copy)
    quantized_model = torch.quantization.convert(prepared, inplace=True)
    logger.info("Model quantized to static INT8")
    return quantized_model


def save_quantized_model(
    model: nn.Module,
    output_path: Path,
    quantization_type: str = "fp16",
):
    """
    Save quantized model.

    For ``'fp16'`` only the state_dict is saved, so loading requires the model
    architecture. For any other type, the whole module object is pickled,
    because INT8 modules carry quantization state that a plain float
    architecture cannot receive.

    Args:
        model: Quantized model
        output_path: Path to save model (parent directories are created)
        quantization_type: Type of quantization ('fp16', 'int8')
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if quantization_type == "fp16":
        torch.save(model.state_dict(), output_path)
    else:
        # For INT8, save the full model (quantization state needed)
        torch.save(model, output_path)

    logger.info(f"Quantized model saved to {output_path}")


def _torch_load(path: Path, map_location, weights_only: bool):
    """Call ``torch.load``, forwarding ``weights_only`` only on torch versions that accept it."""
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=map_location, weights_only=weights_only)
    return torch.load(path, map_location=map_location)


def load_quantized_model(
    model: nn.Module,
    checkpoint_path: Path,
    quantization_type: str = "fp16",
    device: str = "cuda",
) -> nn.Module:
    """
    Load quantized model.

    Args:
        model: Base model architecture (used only for 'fp16'; it is modified
            in place and returned)
        checkpoint_path: Path to quantized checkpoint
        quantization_type: Type of quantization ('fp16' loads a state_dict;
            anything else loads a full pickled module)
        device: Device to load on. Eager-mode INT8 kernels in PyTorch are
            CPU-only, so use ``'cpu'`` for INT8 checkpoints.

    Returns:
        Loaded quantized model
    """
    checkpoint_path = Path(checkpoint_path)

    if quantization_type == "fp16":
        # A state_dict contains only tensors, so the restricted unpickler suffices.
        state_dict = _torch_load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        # load_state_dict copies values into the existing parameters without
        # moving the module, so honour ``device`` explicitly.
        model = model.to(device).half()
    else:
        # For INT8, load the full pickled module. PyTorch >= 2.6 defaults to
        # weights_only=True, which rejects whole modules, so opt out explicitly.
        # SECURITY: this unpickles arbitrary objects - only load trusted files.
        model = _torch_load(checkpoint_path, map_location=device, weights_only=False)

    logger.info(f"Quantized model loaded from {checkpoint_path}")
    return model


def _serialized_size_bytes(model: nn.Module) -> int:
    """Return the number of bytes ``torch.save`` writes for ``model.state_dict()``."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes


def compare_model_sizes(
    model_fp32: nn.Module,
    model_quantized: nn.Module,
) -> Dict[str, float]:
    """
    Compare model sizes between FP32 and quantized versions.

    Sizes are measured as the serialized size of each model's state_dict.
    Counting ``parameters()``/``buffers()`` is not enough: quantized modules
    keep their INT8 weights in packed params that neither iterator reports,
    which made quantized models look almost empty.

    Args:
        model_fp32: Original FP32 model
        model_quantized: Quantized model

    Returns:
        Dict with size comparisons (``fp32_size_mb``, ``quantized_size_mb``,
        ``reduction_percent``)
    """
    size_fp32 = _serialized_size_bytes(model_fp32)
    size_quantized = _serialized_size_bytes(model_quantized)

    # Guard against division by zero for a degenerate baseline.
    reduction = (1 - size_quantized / size_fp32) * 100 if size_fp32 else 0.0

    return {
        "fp32_size_mb": size_fp32 / 1024 / 1024,
        "quantized_size_mb": size_quantized / 1024 / 1024,
        "reduction_percent": reduction,
    }


def benchmark_quantized_model(
    model: nn.Module,
    sample_input,
    num_runs: int = 100,
    device: str = "cuda",
    num_warmup: int = 10,
) -> Dict[str, float]:
    """
    Benchmark quantized model inference speed.

    Args:
        model: Model to benchmark (put in eval mode and moved to ``device``)
        sample_input: Sample input tensor, or a list of tensors. A list is
            passed to ``model.inference(...)`` rather than ``model(...)``.
        num_runs: Number of timed inference runs (must be >= 1)
        device: Device to run on. Eager-mode INT8 models need ``'cpu'``.
        num_warmup: Number of untimed warmup runs

    Returns:
        Dict with timing statistics

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    model.eval()
    model = model.to(device)

    is_multi_input = isinstance(sample_input, list)
    if is_multi_input:
        sample_input = [x.to(device) for x in sample_input]
    else:
        sample_input = sample_input.to(device)

    def run_once():
        if is_multi_input:
            return model.inference(sample_input)
        return model(sample_input)

    # CUDA kernels run asynchronously, so wait for them before reading the
    # clock; on CPU there is nothing to wait for (and CUDA may be absent).
    needs_cuda_sync = torch.device(device).type == "cuda"

    def synchronize():
        if needs_cuda_sync:
            torch.cuda.synchronize(device)

    with torch.no_grad():
        # Warmup
        for _ in range(num_warmup):
            run_once()
        synchronize()

        # Benchmark
        start_time = time.perf_counter()
        for _ in range(num_runs):
            run_once()
        synchronize()
        end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time = total_time / num_runs
    fps = 1.0 / avg_time if avg_time > 0 else float("inf")

    return {
        "avg_inference_time_ms": avg_time * 1000,
        "fps": fps,
        "total_time_s": total_time,
    }