"""
Model Quantization Utilities

This module provides utilities for model quantization to reduce memory usage
and improve inference speed while maintaining reasonable accuracy.

Author: Louis Chua Bean Chong
License: GPLv3
"""

import copy
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.quantization as quantization


class QuantizedModel:
    """
    Wrapper for quantized models with easy conversion and inference.

    This class provides utilities for converting models to quantized versions
    and performing efficient inference with reduced memory usage.
    """

    def __init__(self, model: nn.Module, quantized_model: Optional[nn.Module] = None):
        """
        Initialize the quantized model wrapper.

        Args:
            model: Original model
            quantized_model: Pre-quantized model (optional)
        """
        self.original_model = model
        self.quantized_model = quantized_model
        self.is_quantized = quantized_model is not None

    def quantize_dynamic(self,
                         qconfig_spec: Optional[Dict] = None,
                         dtype: torch.dtype = torch.qint8) -> 'QuantizedModel':
        """
        Perform dynamic quantization on the model.

        Args:
            qconfig_spec: Quantization configuration (module types to quantize)
            dtype: Quantization dtype (torch.qint8 or torch.quint8)

        Returns:
            QuantizedModel: Self with quantized model
        """
        if qconfig_spec is None:
            qconfig_spec = {
                nn.Linear: quantization.default_dynamic_qconfig,
                nn.LSTM: quantization.default_dynamic_qconfig,
                nn.LSTMCell: quantization.default_dynamic_qconfig,
                nn.RNNCell: quantization.default_dynamic_qconfig,
                nn.GRUCell: quantization.default_dynamic_qconfig,
            }

        # Quantize a copy so the original model stays available for comparison.
        model_copy = copy.deepcopy(self.original_model)
        model_copy.eval()

        # Dynamic quantization packs weights ahead of time and quantizes
        # activations on the fly, so a single call prepares and converts.
        self.quantized_model = quantization.quantize_dynamic(
            model_copy, qconfig_spec, dtype=dtype
        )
        self.is_quantized = True

        print(f"Dynamic quantization completed with dtype: {dtype}")
        return self

    def quantize_static(self,
                        calibration_data: torch.utils.data.DataLoader,
                        qconfig: Optional[quantization.QConfig] = None) -> 'QuantizedModel':
        """
        Perform static quantization on the model.

        Args:
            calibration_data: DataLoader for calibration
            qconfig: Quantization configuration

        Returns:
            QuantizedModel: Self with quantized model
        """
        if qconfig is None:
            qconfig = quantization.get_default_qconfig('fbgemm')

        # Quantize a copy so the original model stays available for comparison.
        # Eager-mode static quantization expects the model to mark quantized
        # regions with QuantStub/DeQuantStub where needed.
        model_copy = copy.deepcopy(self.original_model)
        model_copy.eval()

        # Attach the qconfig, then insert observers.
        model_copy.qconfig = qconfig
        model_prepared = quantization.prepare(model_copy)

        # Run a bounded number of calibration batches to collect activation statistics.
        print("Calibrating model...")
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(calibration_data):
                if batch_idx >= 100:
                    break
                model_prepared(data)

        self.quantized_model = quantization.convert(model_prepared)
        self.is_quantized = True

        print("Static quantization completed")
        return self

    def forward(self, *args, **kwargs):
        """Forward pass using the quantized model if available."""
        if self.is_quantized and self.quantized_model is not None:
            return self.quantized_model(*args, **kwargs)
        return self.original_model(*args, **kwargs)

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Get memory usage comparison between original and quantized models.

        Returns:
            dict: Memory usage in MB
        """
        def get_model_size(model):
            param_size = 0
            buffer_size = 0

            for param in model.parameters():
                param_size += param.nelement() * param.element_size()

            for buffer in model.buffers():
                buffer_size += buffer.nelement() * buffer.element_size()

            return (param_size + buffer_size) / (1024 * 1024)

        original_size = get_model_size(self.original_model)
        quantized_size = get_model_size(self.quantized_model) if self.quantized_model else original_size

        return {
            "original_mb": original_size,
            "quantized_mb": quantized_size,
            "compression_ratio": original_size / quantized_size if quantized_size > 0 else 1.0,
        }

    def save_quantized(self, path: str):
        """Save the quantized model's state dict."""
        if self.quantized_model is not None:
            torch.save(self.quantized_model.state_dict(), path)
            print(f"Quantized model saved to: {path}")
        else:
            raise ValueError("No quantized model available")

    def load_quantized(self, path: str):
        """Load a quantized state dict into the quantized model."""
        if self.quantized_model is None:
            raise ValueError("Quantize the model first so there is a matching architecture to load into")
        self.quantized_model.load_state_dict(torch.load(path))
        self.is_quantized = True
        print(f"Quantized model loaded from: {path}")
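

# Example: a minimal sketch of the class-based workflow; `MyModel` and the
# input shape are placeholders, not part of this module:
#
#     model = MyModel()
#     wrapper = QuantizedModel(model).quantize_dynamic()
#     print(wrapper.get_memory_usage()["compression_ratio"])
#     output = wrapper.forward(torch.randn(1, 128))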


def quantize_model_dynamic(model: nn.Module,
                           dtype: torch.dtype = torch.qint8) -> QuantizedModel:
    """
    Convenience function for dynamic quantization.

    Args:
        model: Model to quantize
        dtype: Quantization dtype

    Returns:
        QuantizedModel: Quantized model wrapper
    """
    quantized = QuantizedModel(model)
    return quantized.quantize_dynamic(dtype=dtype)
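

# Example: dynamic quantization is the simplest entry point; a sketch assuming
# an existing `model` containing nn.Linear or recurrent layers (placeholder name):
#
#     wrapper = quantize_model_dynamic(model, dtype=torch.qint8)
#     wrapper.save_quantized("model_int8.pt")  # illustrative path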


def quantize_model_static(model: nn.Module,
                          calibration_data: torch.utils.data.DataLoader,
                          qconfig: Optional[quantization.QConfig] = None) -> QuantizedModel:
    """
    Convenience function for static quantization.

    Args:
        model: Model to quantize
        calibration_data: Data for calibration
        qconfig: Quantization configuration

    Returns:
        QuantizedModel: Quantized model wrapper
    """
    quantized = QuantizedModel(model)
    return quantized.quantize_static(calibration_data, qconfig)
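

# Example: a sketch of the static-quantization path; `model` and
# `calibration_loader` (a DataLoader yielding (inputs, labels) batches) are
# placeholders supplied by the caller:
#
#     wrapper = quantize_model_static(model, calibration_loader)
#     print(wrapper.get_memory_usage())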


def create_quantization_config(backend: str = 'fbgemm',
                               dtype: torch.dtype = torch.qint8) -> quantization.QConfig:
    """
    Create quantization configuration.

    Args:
        backend: Quantization backend ('fbgemm', 'qnnpack')
        dtype: Quantization dtype (kept for API symmetry; the observer
            defaults below determine the actual dtypes used)

    Returns:
        QConfig: Quantization configuration
    """
    if backend == 'fbgemm':
        # fbgemm (x86 server) supports per-channel weight quantization.
        return quantization.QConfig(
            activation=quantization.default_observer,
            weight=quantization.default_per_channel_weight_observer,
        )
    elif backend == 'qnnpack':
        # qnnpack (ARM/mobile) uses per-tensor weight quantization.
        return quantization.QConfig(
            activation=quantization.default_observer,
            weight=quantization.default_weight_observer,
        )
    else:
        raise ValueError(f"Unsupported backend: {backend}")
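

# Example: a sketch of pairing a custom QConfig with static quantization;
# `model` and `calibration_loader` are placeholders supplied by the caller:
#
#     torch.backends.quantized.engine = 'qnnpack'  # match the target backend
#     qconfig = create_quantization_config(backend='qnnpack')
#     wrapper = quantize_model_static(model, calibration_loader, qconfig=qconfig)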


def benchmark_quantization(original_model: nn.Module,
                           quantized_model: QuantizedModel,
                           test_data: torch.Tensor,
                           num_runs: int = 100) -> Dict[str, float]:
    """
    Benchmark original vs quantized model performance.

    Args:
        original_model: Original model
        quantized_model: Quantized model
        test_data: Test data for benchmarking
        num_runs: Number of runs for averaging

    Returns:
        dict: Performance metrics (average latency in ms and speedup)
    """
    original_model.eval()
    quantized_model.quantized_model.eval()

    def time_model(model: nn.Module) -> float:
        """Return the average latency per forward pass in milliseconds."""
        # Synchronize around the timed region so pending GPU work does not
        # leak into or out of the measurement when CUDA is in use.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            for _ in range(num_runs):
                _ = model(test_data)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return (time.perf_counter() - start) * 1000 / num_runs

    original_time = time_model(original_model)
    quantized_time = time_model(quantized_model.quantized_model)

    return {
        "original_time_ms": original_time,
        "quantized_time_ms": quantized_time,
        "speedup": original_time / quantized_time if quantized_time > 0 else 1.0,
    }
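

# A minimal end-to-end demo (a sketch, not part of the public API): builds a
# tiny throwaway model, quantizes it dynamically, and prints the memory and
# latency comparison. Layer sizes and batch shape are arbitrary examples.
if __name__ == "__main__":
    toy_model = nn.Sequential(
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )

    wrapper = quantize_model_dynamic(toy_model)
    sample_input = torch.randn(4, 128)

    print(wrapper.get_memory_usage())
    print(benchmark_quantization(toy_model, wrapper, sample_input, num_runs=20))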