#!/usr/bin/env python3
"""
Model Quantization Utilities
This module provides utilities for model quantization to reduce memory usage
and improve inference speed while maintaining reasonable accuracy.
Author: Louis Chua Bean Chong
License: GPLv3
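
Example (a sketch, assuming a model built from nn.Linear layers):

    >>> import torch.nn as nn
    >>> model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
    >>> qm = quantize_model_dynamic(model)
    >>> qm.get_memory_usage()["compression_ratio"]  # typically > 1.0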
"""
import copy
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.quantization as quantization  # aliased to torch.ao.quantization in newer PyTorch
class QuantizedModel:
"""
Wrapper for quantized models with easy conversion and inference.
This class provides utilities for converting models to quantized versions
and performing efficient inference with reduced memory usage.
"""
def __init__(self, model: nn.Module, quantized_model: Optional[nn.Module] = None):
"""
Initialize quantized model wrapper.
Args:
model: Original model
quantized_model: Pre-quantized model (optional)
"""
self.original_model = model
self.quantized_model = quantized_model
self.is_quantized = quantized_model is not None
def quantize_dynamic(self,
qconfig_spec: Optional[Dict] = None,
dtype: torch.dtype = torch.qint8) -> 'QuantizedModel':
"""
Perform dynamic quantization on the model.
Args:
            qconfig_spec: Set of module types to quantize, or a mapping from
                module type to QConfig; defaults to the standard dynamically
                quantizable layers (Linear and recurrent modules)
            dtype: Quantization dtype (torch.qint8 or torch.float16)
Returns:
QuantizedModel: Self with quantized model
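
        Example (a sketch; Linear layers support dynamic quantization):
            >>> qm = QuantizedModel(nn.Linear(10, 10)).quantize_dynamic()
            >>> qm.is_quantized
            True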
"""
        if qconfig_spec is None:
            # A set of module types lets quantize_dynamic choose the right
            # default qconfig for the requested dtype; passing explicit
            # qconfigs here would silently override `dtype`.
            qconfig_spec = {
                nn.Linear,
                nn.LSTM,
                nn.LSTMCell,
                nn.RNNCell,
                nn.GRUCell,
            }
        # Quantize a deep copy so the original model stays available
        # for comparison and fallback inference.
        model_copy = copy.deepcopy(self.original_model)
        model_copy.eval()
        # quantize_dynamic replaces the listed module types with their
        # dynamically quantized counterparts in one prepare-and-convert step.
        self.quantized_model = quantization.quantize_dynamic(
            model_copy, qconfig_spec, dtype=dtype
        )
        self.is_quantized = True
        print(f"Dynamic quantization completed with dtype: {dtype}")
        return self
def quantize_static(self,
calibration_data: torch.utils.data.DataLoader,
qconfig: Optional[quantization.QConfig] = None) -> 'QuantizedModel':
"""
Perform static quantization on the model.
Args:
calibration_data: DataLoader for calibration
qconfig: Quantization configuration
Returns:
QuantizedModel: Self with quantized model
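
        Example (a sketch; assumes `model` wraps its forward pass in
        QuantStub/DeQuantStub so observers can be inserted):
            >>> ds = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.zeros(64))
            >>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
            >>> qm = QuantizedModel(model).quantize_static(loader)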
"""
        if qconfig is None:
            qconfig = quantization.get_default_qconfig('fbgemm')
        # Quantize a deep copy so the original model stays available.
        model_copy = copy.deepcopy(self.original_model)
        model_copy.eval()
        # prepare() reads the qconfig from the model itself, so attach it
        # before inserting observers.
        model_copy.qconfig = qconfig
        model_prepared = quantization.prepare(model_copy)
        # Run representative data through the model so the observers can
        # record activation ranges.
        print("Calibrating model...")
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(calibration_data):
                if batch_idx >= 100:  # Limit calibration samples
                    break
                model_prepared(data)
        # Swap observed modules for their quantized counterparts.
        self.quantized_model = quantization.convert(model_prepared)
        self.is_quantized = True
        print("Static quantization completed")
        return self
def forward(self, *args, **kwargs):
"""Forward pass using quantized model if available."""
if self.is_quantized and self.quantized_model is not None:
return self.quantized_model(*args, **kwargs)
else:
return self.original_model(*args, **kwargs)
def get_memory_usage(self) -> Dict[str, float]:
"""
Get memory usage comparison between original and quantized models.
Returns:
dict: Memory usage in MB
"""
        def get_model_size(model):
            # Sum parameter and buffer storage. Some quantized modules keep
            # packed weights in opaque attributes that parameters() and
            # buffers() do not see, so the quantized figure is approximate.
            param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
            buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
            return (param_size + buffer_size) / (1024 * 1024)  # bytes -> MB
original_size = get_model_size(self.original_model)
quantized_size = get_model_size(self.quantized_model) if self.quantized_model else original_size
return {
"original_mb": original_size,
"quantized_mb": quantized_size,
"compression_ratio": original_size / quantized_size if quantized_size > 0 else 1.0
}
def save_quantized(self, path: str):
"""Save quantized model."""
if self.quantized_model is not None:
torch.save(self.quantized_model.state_dict(), path)
print(f"Quantized model saved to: {path}")
else:
raise ValueError("No quantized model available")
    def load_quantized(self, path: str):
        """Load quantized weights from `path` into an existing quantized model."""
        if self.quantized_model is None:
            raise ValueError(
                "No quantized model structure to load into; quantize first"
            )
        self.quantized_model.load_state_dict(torch.load(path))
        self.is_quantized = True
        print(f"Quantized model loaded from: {path}")
def quantize_model_dynamic(model: nn.Module,
dtype: torch.dtype = torch.qint8) -> QuantizedModel:
"""
Convenience function for dynamic quantization.
Args:
model: Model to quantize
dtype: Quantization dtype
Returns:
QuantizedModel: Quantized model wrapper
"""
quantized = QuantizedModel(model)
return quantized.quantize_dynamic(dtype=dtype)
def quantize_model_static(model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
qconfig: Optional[quantization.QConfig] = None) -> QuantizedModel:
"""
Convenience function for static quantization.
Args:
model: Model to quantize
calibration_data: Data for calibration
qconfig: Quantization configuration
Returns:
QuantizedModel: Quantized model wrapper
"""
quantized = QuantizedModel(model)
return quantized.quantize_static(calibration_data, qconfig)
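
# A minimal sketch of building a calibration loader for quantize_model_static.
# The random tensors and `my_model` are placeholders for real inference inputs
# and your own nn.Module:
#
#   ds = torch.utils.data.TensorDataset(torch.randn(256, 10), torch.zeros(256))
#   loader = torch.utils.data.DataLoader(ds, batch_size=32)
#   quantized = quantize_model_static(my_model, loader)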
def create_quantization_config(backend: str = 'fbgemm',
                               dtype: torch.dtype = torch.qint8) -> quantization.QConfig:
    """
    Create quantization configuration.
    Args:
        backend: Quantization backend ('fbgemm' for x86, 'qnnpack' for ARM)
        dtype: Quantization dtype (currently unused; the default observers
            chosen below determine the actual dtype)
    Returns:
        QConfig: Quantization configuration
    """
if backend == 'fbgemm':
return quantization.QConfig(
activation=quantization.default_observer,
weight=quantization.default_per_channel_weight_observer
)
elif backend == 'qnnpack':
return quantization.QConfig(
activation=quantization.default_observer,
weight=quantization.default_weight_observer
)
else:
raise ValueError(f"Unsupported backend: {backend}")
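
# A minimal sketch of passing a custom backend config to static quantization;
# `my_model` and `calibration_loader` are placeholders:
#
#   qconfig = create_quantization_config(backend='qnnpack')
#   quantized = quantize_model_static(my_model, calibration_loader, qconfig=qconfig)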
def benchmark_quantization(original_model: nn.Module,
quantized_model: QuantizedModel,
test_data: torch.Tensor,
num_runs: int = 100) -> Dict[str, float]:
"""
Benchmark original vs quantized model performance.
Args:
original_model: Original model
quantized_model: Quantized model
test_data: Test data for benchmarking
num_runs: Number of runs for averaging
Returns:
dict: Performance metrics
"""
    original_model.eval()
    if quantized_model.quantized_model is None:
        raise ValueError("No quantized model available to benchmark")
    quantized_model.quantized_model.eval()

    def _avg_latency_ms(model: nn.Module) -> float:
        # Eager-mode quantized kernels execute on the CPU, so a wall-clock
        # timer measures both models fairly; CUDA events would only time
        # GPU-stream work and mis-report CPU inference.
        with torch.no_grad():
            model(test_data)  # warm-up so one-time costs don't skew the mean
            start = time.perf_counter()
            for _ in range(num_runs):
                model(test_data)
        return (time.perf_counter() - start) * 1000 / num_runs  # ms per run

    original_time = _avg_latency_ms(original_model)
    quantized_time = _avg_latency_ms(quantized_model.quantized_model)
return {
"original_time_ms": original_time,
"quantized_time_ms": quantized_time,
"speedup": original_time / quantized_time if quantized_time > 0 else 1.0
}
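

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API): dynamically
    # quantize a toy MLP, then report memory savings and average latency.
    # The architecture and shapes are illustrative only.
    toy_model = nn.Sequential(
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )
    quantized = quantize_model_dynamic(toy_model)
    print(quantized.get_memory_usage())
    print(benchmark_quantization(toy_model, quantized, torch.randn(32, 128)))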