|
|
""" |
|
|
Model quantization utilities for faster inference and lower memory usage.

Supports FP16 conversion and both dynamic and static INT8 quantization.
|
|
""" |
|
|
|
|
|
import copy
import logging
import time
from itertools import islice
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def quantize_fp16(model: nn.Module) -> nn.Module:
    """
    Cast a model to half precision (FP16).

    The conversion is delegated to ``nn.Module.half()``, which casts the
    module it is called on and returns that same module, so the caller's
    object is the one that ends up in FP16.

    Args:
        model: Model to convert

    Returns:
        The model after the FP16 cast
    """
    half_precision_model = model.half()
    logger.info("Model quantized to FP16")
    return half_precision_model
|
|
|
|
|
|
|
|
def quantize_dynamic_int8(
    model: nn.Module,
    quantizable_modules: Optional[list] = None,
) -> nn.Module:
    """
    Apply dynamic INT8 quantization to model.

    The quantized model is a new object returned by
    ``torch.quantization.quantize_dynamic``. If quantization raises, the error
    is logged and the function falls back to ``quantize_fp16``, which converts
    the *passed-in* model to half precision.

    Args:
        model: Model to quantize
        quantizable_modules: Module types (or submodule names) to quantize
            (default: Linear, Conv2d). Any iterable is accepted; a dict that
            maps types/names to qconfigs is forwarded unchanged.
            NOTE(review): PyTorch's default dynamic mapping may not swap
            Conv2d modules, in which case they are left in floating point -
            confirm whether Conv2d belongs in the default.

    Returns:
        INT8 quantized model, or the FP16 model if INT8 quantization failed
    """
    if quantizable_modules is None:
        quantizable_modules = [torch.nn.Linear, torch.nn.Conv2d]

    # quantize_dynamic only understands a dict or a set as its spec. A plain
    # list is passed through untouched and later fails on ``.get``, which made
    # every call silently fall back to FP16.
    if isinstance(quantizable_modules, dict):
        qconfig_spec = quantizable_modules
    else:
        qconfig_spec = set(quantizable_modules)

    try:
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            qconfig_spec,
            dtype=torch.qint8,
        )
    except Exception as e:
        # Deliberate best-effort: any quantization failure degrades to FP16.
        logger.error("INT8 quantization failed: %s", e)
        logger.warning("Falling back to FP16 quantization")
        return quantize_fp16(model)

    logger.info("Model quantized to INT8 (modules: %s)", quantizable_modules)
    return quantized_model
|
|
|
|
|
|
|
|
def quantize_static_int8(
    model: nn.Module,
    calibration_data,
    quantizable_modules: Optional[list] = None,
    max_calibration_batches: int = 100,
) -> nn.Module:
    """
    Apply static INT8 quantization with calibration data.

    The model is copied, switched to eval mode, given the default "fbgemm"
    qconfig, instrumented with observers, run over the calibration inputs
    under ``torch.no_grad()`` and finally converted. The caller's model is
    left untouched.

    NOTE(review): eager-mode static quantization normally expects the model
    to mark its quantized region with QuantStub/DeQuantStub - confirm that
    callers provide such models.

    Args:
        model: Model to quantize
        calibration_data: DataLoader or list of inputs for calibration. When
            an item of an iterable is a list or tuple, only its first element
            is fed to the model. Objects without ``__iter__`` are sliced and
            their items are fed to the model as-is.
        quantizable_modules: Accepted for interface compatibility but
            currently not consulted; the qconfig is attached to the whole
            model.
        max_calibration_batches: Upper bound on the number of calibration
            forward passes (default: 100)

    Returns:
        INT8 quantized model (a new object, distinct from ``model``)
    """
    # Work on a private copy so the caller's FP32 model keeps its train/eval
    # mode and is not left with a qconfig and observers attached.
    prepared = copy.deepcopy(model)
    prepared.eval()

    prepared.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(prepared, inplace=True)

    logger.info("Calibrating model for static quantization...")
    with torch.no_grad():
        if hasattr(calibration_data, "__iter__"):
            # islice stops after exactly max_calibration_batches items; the
            # previous "check after the forward pass" ran one batch too many.
            for batch in islice(calibration_data, max_calibration_batches):
                inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
                prepared(inputs)
        else:
            for inputs in calibration_data[:max_calibration_batches]:
                prepared(inputs)

    # The copy is private to this function, so it can be converted in place.
    quantized_model = torch.quantization.convert(prepared, inplace=True)
    logger.info("Model quantized to static INT8")
    return quantized_model
|
|
|
|
|
|
|
|
def save_quantized_model(
    model: nn.Module,
    output_path: Path,
    quantization_type: str = "fp16",
):
    """
    Write a quantized model to disk.

    Missing parent directories of the target file are created first.

    Args:
        model: Quantized model
        output_path: Destination file for the checkpoint
        quantization_type: 'fp16' stores only ``model.state_dict()``; any
            other value stores the whole model object
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # FP16 checkpoints carry weights only; every other type serializes the
    # complete module.
    payload = model.state_dict() if quantization_type == "fp16" else model
    torch.save(payload, output_path)

    logger.info(f"Quantized model saved to {output_path}")
|
|
|
|
|
|
|
|
def load_quantized_model(
    model: nn.Module,
    checkpoint_path: Path,
    quantization_type: str = "fp16",
    device: str = "cuda",
) -> nn.Module:
    """
    Load quantized model.

    Args:
        model: Base model architecture. Only used for 'fp16', where the
            checkpoint is a state dict loaded into this model; for any other
            type the checkpoint already contains the whole model and this
            argument is ignored.
        checkpoint_path: Path to quantized checkpoint
        quantization_type: Type of quantization ('fp16' or anything else for
            a fully serialized model)
        device: Device to load on

    Returns:
        Loaded quantized model
    """
    checkpoint_path = Path(checkpoint_path)

    if quantization_type == "fp16":
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        # load_state_dict copies values into the existing parameters and does
        # not relocate the module, so move it to the requested device here.
        model = model.half().to(device)
    else:
        # SECURITY: this unpickles an arbitrary Python object and can execute
        # code from the file - only load checkpoints from trusted sources.
        # weights_only=False is explicit because newer PyTorch releases
        # default to weights-only loading, which rejects whole-model pickles.
        model = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )

    logger.info("Quantized model loaded from %s", checkpoint_path)
    return model
|
|
|
|
|
|
|
|
def compare_model_sizes(
    model_fp32: nn.Module,
    model_quantized: nn.Module,
) -> Dict[str, float]:
    """
    Compare model sizes between FP32 and quantized versions.

    A model's size is the byte count of its parameters and buffers plus any
    further tensors that only appear in its ``state_dict``. The latter matters
    for PyTorch's quantized modules, which keep their packed INT8 weights
    outside ``parameters()``/``buffers()``; counting parameters and buffers
    alone would report such models as almost empty.

    Args:
        model_fp32: Original FP32 model
        model_quantized: Quantized model

    Returns:
        Dict with ``fp32_size_mb``, ``quantized_size_mb`` and
        ``reduction_percent`` (0.0 when the FP32 model holds no tensors)
    """

    def tensor_bytes(value) -> int:
        # Packed parameters are exposed as tuples of tensors, so recurse.
        if isinstance(value, torch.Tensor):
            return value.numel() * value.element_size()
        if isinstance(value, (list, tuple)):
            return sum(tensor_bytes(item) for item in value)
        return 0

    def get_model_size(model) -> int:
        tracked = list(model.parameters()) + list(model.buffers())
        total = sum(tensor_bytes(tensor) for tensor in tracked)

        # keep_vars=True makes state_dict hand back the very same parameter
        # and buffer objects, so identity tells us which entries are already
        # counted (this also avoids double counting tied weights).
        tracked_ids = {id(tensor) for tensor in tracked}
        for value in model.state_dict(keep_vars=True).values():
            if id(value) not in tracked_ids:
                total += tensor_bytes(value)
        return total

    size_fp32 = get_model_size(model_fp32)
    size_quantized = get_model_size(model_quantized)

    # A model without tensors has nothing to reduce; avoid dividing by zero.
    if size_fp32:
        reduction = (1 - size_quantized / size_fp32) * 100
    else:
        reduction = 0.0

    return {
        "fp32_size_mb": size_fp32 / 1024 / 1024,
        "quantized_size_mb": size_quantized / 1024 / 1024,
        "reduction_percent": reduction,
    }
|
|
|
|
|
|
|
|
def benchmark_quantized_model(
    model: nn.Module,
    sample_input,
    num_runs: int = 100,
    device: str = "cuda",
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """
    Benchmark quantized model inference speed.

    The model is put in eval mode and moved to ``device``, warmed up, and then
    timed over ``num_runs`` forward passes under ``torch.no_grad()``.

    Args:
        model: Model to benchmark
        sample_input: Sample input tensor, or a list of tensors. A list is
            passed to ``model.inference`` (the model must provide that
            method); anything else is passed to ``model(...)``.
        num_runs: Number of timed inference runs (must be positive)
        device: Device to run on
        warmup_runs: Untimed runs executed before measuring (default: 10)

    Returns:
        Dict with ``avg_inference_time_ms``, ``fps`` and ``total_time_s``

    Raises:
        ValueError: If ``num_runs`` is not positive
    """
    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    model.eval()
    model = model.to(device)

    if isinstance(sample_input, list):
        sample_input = [x.to(device) for x in sample_input]
        forward = model.inference
    else:
        sample_input = sample_input.to(device)
        forward = model

    target_device = torch.device(device)

    def synchronize() -> None:
        # CUDA kernels run asynchronously and must be drained before reading
        # the clock. On any other device there is nothing to wait for, and
        # calling torch.cuda.synchronize() without CUDA raises.
        if target_device.type == "cuda":
            torch.cuda.synchronize(target_device)

    with torch.no_grad():
        for _ in range(warmup_runs):
            forward(sample_input)
        synchronize()

        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(num_runs):
            forward(sample_input)
        synchronize()
        total_time = time.perf_counter() - start_time

    avg_time = total_time / num_runs
    # Guard against a zero reading on coarse clocks with trivially fast models.
    fps = 1.0 / avg_time if avg_time > 0 else float("inf")

    return {
        "avg_inference_time_ms": avg_time * 1000,
        "fps": fps,
        "total_time_s": total_time,
    }
|
|
|