import torch
import torch.nn as nn

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig

from .model import BitTransformerLM


def quantize_dynamic(
    model: BitTransformerLM, dtype: torch.dtype = torch.qint8
) -> BitTransformerLM:
    """Return a dynamically quantized copy of the model for inference.

    Only ``nn.Linear`` modules are replaced; their weights are stored in
    ``dtype`` and activations are quantized on the fly at run time.
    """
    return torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=dtype)
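

# Example usage (a sketch; BitTransformerLM's constructor arguments are
# project-specific and elided here):
#
#   model = BitTransformerLM(...)
#   model.eval()
#   int8_model = quantize_dynamic(model)  # Linear weights stored as qint8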


class FourBitObserver(MinMaxObserver):
    """Min-max observer configured for 4-bit quantization.

    There is no dedicated 4-bit dtype here, so values live in ``quint8``
    storage restricted to the 4-bit range ``[0, 15]``.
    """

    def __init__(self, **kwargs):
        super().__init__(
            quant_min=0,
            quant_max=15,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            **kwargs,
        )


FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)

four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)
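
# Round-trip intuition (a sketch; exact values depend on the observed
# min/max of the input):
#
#   fq = FourBitFakeQuantize()
#   x = torch.randn(4, 8)
#   y = fq(x)  # x snapped to one of 16 levels spanning [x.min(), x.max()]
#
# The qconfig can also be attached to a module for eager-mode workflows:
#
#   model.qconfig = four_bit_qconfig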


class QATLinear(nn.Linear):
    """Linear layer with fake quantization for QAT."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        self.weight_fake_quant = FourBitFakeQuantize()
        self.activation_post_process = FourBitFakeQuantize()

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "QATLinear":
        """Build a QAT layer that shares parameters with ``mod``."""
        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize inputs and weights so training sees 4-bit rounding
        # error while gradients flow via the straight-through estimator.
        x = self.activation_post_process(x)
        w = self.weight_fake_quant(self.weight)
        return nn.functional.linear(x, w, self.bias)
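

# QATLinear can also be used stand-alone (a sketch):
#
#   fp_layer = nn.Linear(16, 8)
#   qat_layer = QATLinear.from_float(fp_layer)
#   out = qat_layer(torch.randn(2, 16))  # forward pass with 4-bit fake quant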


def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Prepare BitTransformerLM for quantization-aware training.

    Recursively swaps every ``nn.Linear`` for a :class:`QATLinear` in
    place (a simple eager-mode swap, despite the ``_fx``-style name).
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QATLinear.from_float(module))
        else:
            prepare_qat_fx(module)
    return model


def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Convert a QAT-prepared model to a quantized version.

    Each :class:`QATLinear` is replaced by a plain ``nn.Linear`` whose
    weights are rounded to the 4-bit grid and dequantized back to float
    (simulated quantization; weights are not packed into int4 storage).
    """
    for name, module in model.named_children():
        if isinstance(module, QATLinear):
            w = module.weight.data
            qmin, qmax = 0, 15
            min_w = w.min()
            max_w = w.max()
            # Affine quantization parameters; clamp the scale so constant
            # weight tensors (max_w == min_w) do not divide by zero.
            scale = torch.clamp((max_w - min_w) / (qmax - qmin), min=1e-8)
            zero_point = qmin - torch.round(min_w / scale)
            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
            new_mod = nn.Linear(
                module.in_features, module.out_features, module.bias is not None
            )
            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
            if module.bias is not None:
                new_mod.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_mod)
        else:
            convert_qat_fx(module)
    return model
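

# End-to-end QAT workflow (a sketch; the constructor arguments and training
# loop are placeholders, not the project's real API):
#
#   model = BitTransformerLM(...)
#   model = prepare_qat_fx(model)   # swap nn.Linear -> QATLinear
#   # ...train as usual; fake quantization shapes the forward pass...
#   model = convert_qat_fx(model)   # bake weights onto the 4-bit grid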