"""
MiniMind Quantization Toolkit

INT4/INT8 quantization for efficient inference on edge devices.
"""

import math
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizationType(Enum):
    """Supported quantization types."""

    INT8_DYNAMIC = "int8_dynamic"
    INT8_STATIC = "int8_static"
    INT4_AWQ = "int4_awq"
    INT4_GPTQ = "int4_gptq"
    FP8 = "fp8"


@dataclass
class QuantizationConfig:
    """Configuration for quantization."""

    quant_type: QuantizationType = QuantizationType.INT4_AWQ
    bits: int = 4
    group_size: int = 128
    use_double_quant: bool = False
    compute_dtype: torch.dtype = torch.float16
    calibration_samples: int = 128
    calibration_seq_len: int = 512
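

def _int4_metadata_bytes(out_features: int, in_features: int, group_size: int = 128) -> int:
    """Illustrative sketch (not used elsewhere in this module): bytes of
    scale/zero-point metadata an ``Int4Linear`` layer stores beside its packed
    weights. Each of the ``out_features * ceil(in_features / group_size)``
    groups keeps one fp16 scale and one fp16 zero-point, 2 bytes apiece.
    """
    num_groups = out_features * math.ceil(in_features / group_size)
    return num_groups * 2 * 2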


class Int4Linear(nn.Module):
    """INT4 quantized linear layer with group-wise quantization."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        group_size: int = 128,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.num_groups = math.ceil(in_features / group_size)

        # Two 4-bit codes per byte, stored as a flat uint8 buffer.
        packed_size = out_features * math.ceil(in_features / 2)
        self.register_buffer("qweight", torch.zeros(packed_size, dtype=torch.uint8))

        # One fp16 scale and zero-point per (output row, input group).
        self.register_buffer("scales", torch.zeros(out_features, self.num_groups, dtype=torch.float16))
        self.register_buffer("zeros", torch.zeros(out_features, self.num_groups, dtype=torch.float16))

        if bias:
            self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))
        else:
            self.bias = None

    @staticmethod
    def pack_int4(values: torch.Tensor) -> torch.Tensor:
        """Pack pairs of 4-bit codes (0..15) into single uint8 bytes."""
        assert values.shape[-1] % 2 == 0, "need an even count of 4-bit codes"
        # Work in uint8 so the high-nibble shift cannot overflow int8.
        values = values.to(torch.uint8)
        low = values[..., 0::2] & 0xF
        high = values[..., 1::2] & 0xF
        return (high << 4) | low

    @staticmethod
    def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
        """Unpack each uint8 byte back into two 4-bit codes."""
        low = packed & 0xF
        high = (packed >> 4) & 0xF
        return torch.stack([low, high], dim=-1).flatten(-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weights and apply the linear transformation."""
        input_dtype = x.dtype

        # Recover the per-element 4-bit codes in (out_features, in_features) layout.
        unpacked = self.unpack_int4(self.qweight)
        unpacked = unpacked.view(self.out_features, self.in_features)

        # Dequantize group by group: w = (q - zero) * scale.
        weight = torch.zeros(self.out_features, self.in_features, dtype=self.scales.dtype, device=x.device)
        for g in range(self.num_groups):
            start = g * self.group_size
            end = min((g + 1) * self.group_size, self.in_features)
            weight[:, start:end] = (unpacked[:, start:end].float() - self.zeros[:, g:g + 1]) * self.scales[:, g:g + 1]

        weight = weight.to(input_dtype)
        return F.linear(x, weight, self.bias)
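
    # Dequantization example: with scale s = 0.01 and zero-point z = 7.5, the
    # 4-bit code q maps back to (q - z) * s, so q = 0 → -0.075 and q = 15 → 0.075.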

    @classmethod
    def from_float(cls, module: nn.Linear, group_size: int = 128) -> "Int4Linear":
        """Convert a float linear layer to INT4."""
        int4_layer = cls(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            group_size=group_size,
        )

        weight = module.weight.data.float()
        out_features, in_features = weight.shape
        num_groups = math.ceil(in_features / group_size)
        qweight = torch.zeros_like(weight, dtype=torch.uint8)

        for g in range(num_groups):
            start = g * group_size
            end = min((g + 1) * group_size, in_features)
            group_weight = weight[:, start:end]

            # Asymmetric per-group range: map [min, max] onto the 16 codes 0..15.
            min_val = group_weight.min(dim=1, keepdim=True)[0]
            max_val = group_weight.max(dim=1, keepdim=True)[0]

            scale = (max_val - min_val) / 15.0
            scale = scale.clamp(min=1e-8)
            zero = -min_val / scale

            # squeeze(-1) rather than squeeze() so out_features == 1 keeps its dim.
            int4_layer.scales[:, g] = scale.squeeze(-1).to(torch.float16)
            int4_layer.zeros[:, g] = zero.squeeze(-1).to(torch.float16)

            qweight[:, start:end] = (group_weight / scale + zero).round().clamp(0, 15).to(torch.uint8)

        int4_layer.qweight.copy_(cls.pack_int4(qweight.flatten()))

        if module.bias is not None:
            int4_layer.bias.copy_(module.bias.data.to(torch.float16))

        return int4_layer
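

def _int4_roundtrip_demo() -> None:
    """Sanity-check sketch for Int4Linear (illustrative; CPU, float32 inputs)."""
    # Nibble packing: codes 3 (low) and 10 (high) pack into the byte 0xA3.
    codes = torch.tensor([3, 10], dtype=torch.uint8)
    packed = Int4Linear.pack_int4(codes)
    assert packed.item() == 0xA3
    assert torch.equal(Int4Linear.unpack_int4(packed), codes)

    # End to end: a quantized layer should closely track its float source.
    torch.manual_seed(0)
    fp = nn.Linear(128, 64, bias=False)
    q = Int4Linear.from_float(fp, group_size=64)
    x = torch.randn(2, 128)
    err = (q(x) - fp(x)).abs().max().item()
    assert err < 0.5, f"unexpectedly large INT4 error: {err:.4f}"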


class Mind2Quantizer:
    """Quantizer for MiniMind models."""

    def __init__(self, config: Optional[QuantizationConfig] = None):
        self.config = config or QuantizationConfig()

    def quantize(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """
        Quantize the model.

        Args:
            model: Model to quantize
            calibration_data: Calibration data for static quantization

        Returns:
            Quantized model
        """
        # INT8_STATIC and FP8 are declared in QuantizationType but have no
        # implementations yet, so they fall through to the error below.
        if self.config.quant_type == QuantizationType.INT8_DYNAMIC:
            return self._quantize_int8_dynamic(model)
        elif self.config.quant_type == QuantizationType.INT4_AWQ:
            return self._quantize_int4_awq(model, calibration_data)
        elif self.config.quant_type == QuantizationType.INT4_GPTQ:
            return self._quantize_int4_gptq(model, calibration_data)
        else:
            raise ValueError(f"Unsupported quantization type: {self.config.quant_type}")

    def _quantize_int8_dynamic(self, model: nn.Module) -> nn.Module:
        """Apply INT8 dynamic quantization to all Linear layers."""
        # torch.quantization is the legacy namespace; newer PyTorch exposes the
        # same entry point as torch.ao.quantization.quantize_dynamic.
        return torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8,
        )

    @staticmethod
    def _replace_linear_layers(model: nn.Module, group_size: int) -> None:
        """Swap every sufficiently large nn.Linear for an Int4Linear in place."""
        # Snapshot the targets first: mutating the module tree while iterating
        # named_modules() can skip or revisit entries.
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear) and module.weight.shape[0] >= 64
        ]
        for name, module in targets:
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, Int4Linear.from_float(module, group_size))

    def _quantize_int4_awq(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """Apply group-wise INT4 quantization (AWQ slot).

        Currently plain round-to-nearest: `calibration_data` is accepted for
        API symmetry, but activation-aware scaling is not implemented.
        """
        model = model.cpu().float()
        self._replace_linear_layers(model, self.config.group_size)
        return model

    def _quantize_int4_gptq(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """Apply GPTQ-style INT4 quantization with calibration."""
        if calibration_data is None:
            print("Warning: GPTQ without calibration, falling back to AWQ")
            return self._quantize_int4_awq(model)

        model = model.cpu().float()

        # Placeholder calibration pass: the data is run through the model, but
        # no Hessian statistics are collected yet, so the quantization below is
        # the same round-to-nearest scheme used for AWQ.
        model.eval()
        with torch.no_grad():
            model(calibration_data)

        self._replace_linear_layers(model, self.config.group_size)
        return model

    def estimate_model_size(self, model: nn.Module) -> Dict[str, float]:
        """Estimate weight storage in common formats (weights only; the INT4
        figure ignores the per-group scale/zero metadata Int4Linear keeps)."""
        total_params = sum(p.numel() for p in model.parameters())

        return {
            "params": total_params,
            "fp32_gb": (total_params * 4) / (1024**3),
            "fp16_gb": (total_params * 2) / (1024**3),
            "int8_gb": (total_params * 1) / (1024**3),
            "int4_gb": (total_params * 0.5) / (1024**3),
        }
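

# Worked example for estimate_model_size (hypothetical 25M-parameter model):
#   fp16: 25e6 * 2 bytes / 2**30 ≈ 0.047 GB
#   int4: 25e6 * 0.5 bytes / 2**30 ≈ 0.012 GB
# The per-group fp16 scales and zero-points add roughly 4 / group_size bytes
# per parameter on top of the INT4 figure (~6% extra at group_size=128).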


def quantize_model(
    model: nn.Module,
    quant_type: str = "int4_awq",
    group_size: int = 128,
    calibration_data: Optional[torch.Tensor] = None,
) -> nn.Module:
    """
    Convenience function to quantize a model.

    Args:
        model: Model to quantize
        quant_type: Quantization type (int4_awq, int4_gptq, int8_dynamic)
        group_size: Group size for INT4 quantization
        calibration_data: Calibration data for GPTQ

    Returns:
        Quantized model
    """
    config = QuantizationConfig(
        quant_type=QuantizationType(quant_type),
        group_size=group_size,
    )
    quantizer = Mind2Quantizer(config)
    return quantizer.quantize(model, calibration_data)
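

# Illustrative calls (variable names are placeholders, not part of this API):
#
#   model_q8 = quantize_model(model, quant_type="int8_dynamic")
#   model_q4 = quantize_model(model, quant_type="int4_gptq", calibration_data=calib_batch)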


if __name__ == "__main__":
    import sys

    sys.path.insert(0, str(Path(__file__).parent.parent))
    from model import create_model

    print("Testing quantization...")

    model = create_model("mind2-nano", device="cpu", dtype=torch.float32)
    quantizer = Mind2Quantizer()

    sizes = quantizer.estimate_model_size(model)
    print("Model sizes:")
    print(f"  params: {int(sizes.pop('params')):,}")
    for fmt, size in sizes.items():
        print(f"  {fmt}: {size:.3f}")

    print("\nQuantizing to INT4...")
    quantized_model = quantizer.quantize(model)

    # Smoke test: the quantized model should still produce logits.
    input_ids = torch.randint(0, 1000, (1, 32))
    with torch.no_grad():
        _, logits, _, _ = quantized_model(input_ids)
    print(f"Output shape: {logits.shape}")
    print("✓ Quantization successful!")