Remove nested directory: BitTransformerLM/bit_transformer/quantization.py
BitTransformerLM/bit_transformer/quantization.py
DELETED
@@ -1,89 +0,0 @@
-import torch
-import torch.nn as nn
-
-from torch.ao.quantization.fake_quantize import FakeQuantize
-from torch.ao.quantization.observer import MinMaxObserver
-from torch.ao.quantization.qconfig import QConfig
-from torch.ao.quantization import convert
-
-from .model import BitTransformerLM
-
-
-def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
-    """Return a dynamically quantized copy of the model for inference."""
-    quantized = torch.quantization.quantize_dynamic(
-        model, {nn.Linear}, dtype=dtype
-    )
-    return quantized
-
-
-class FourBitObserver(MinMaxObserver):
-    """Min-max observer configured for 4-bit quantization."""
-
-    def __init__(self, **kwargs):
-        super().__init__(
-            quant_min=0,
-            quant_max=15,
-            dtype=torch.quint8,
-            qscheme=torch.per_tensor_affine,
-            **kwargs,
-        )
-
-
-FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)
-
-four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)
-
-
-class QATLinear(nn.Linear):
-    """Linear layer with fake quantization for QAT."""
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
-        super().__init__(in_features, out_features, bias)
-        self.weight_fake_quant = FourBitFakeQuantize()
-        self.activation_post_process = FourBitFakeQuantize()
-
-    @classmethod
-    def from_float(cls, mod: nn.Linear) -> "QATLinear":
-        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
-        qat.weight = mod.weight
-        qat.bias = mod.bias
-        return qat
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.activation_post_process(x)
-        w = self.weight_fake_quant(self.weight)
-        return nn.functional.linear(x, w, self.bias)
-
-
-def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
-    """Prepare BitTransformerLM for quantization-aware training."""
-
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear):
-            setattr(model, name, QATLinear.from_float(module))
-        else:
-            prepare_qat_fx(module)
-    return model
-
-
-def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
-    """Convert a QAT-prepared model to a quantized version."""
-
-    for name, module in model.named_children():
-        if isinstance(module, QATLinear):
-            w = module.weight.data
-            qmin, qmax = 0, 15
-            min_w = w.min()
-            max_w = w.max()
-            scale = (max_w - min_w) / (qmax - qmin + 1e-8)
-            zero_point = qmin - torch.round(min_w / scale)
-            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
-            new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
-            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
-            if module.bias is not None:
-                new_mod.bias = nn.Parameter(module.bias.data)
-            setattr(model, name, new_mod)
-        else:
-            convert_qat_fx(module)
-    return model
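
For context, here is a minimal sketch of how the removed helpers fit together. A hypothetical toy two-layer module stands in for BitTransformerLM, whose constructor arguments are not shown in this diff; the dynamic-quantization call is the standard torch API the deleted wrapper delegated to, and the QAT round trip is shown as comments because the helpers no longer exist in the tree.

    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):  # hypothetical stand-in for BitTransformerLM
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(16, 16)
            self.out = nn.Linear(16, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.out(torch.relu(self.proj(x)))

    model = ToyModel()

    # Post-training dynamic quantization: int8 weights for every nn.Linear.
    quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

    # 4-bit QAT round trip with the deleted helpers:
    # model = prepare_qat_fx(model)   # swap nn.Linear -> QATLinear (fake-quantized forward)
    # ... fine-tune so the weights adapt to the 4-bit grid ...
    # model = convert_qat_fx(model)   # snap weights to the grid, restore plain nn.Linear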
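
The conversion step implements per-tensor affine quantization by hand. Below is a worked example with illustrative values. Note that the 1e-8 epsilon guards the constant denominator (qmax - qmin), not the degenerate case where min equals max, so a constant weight tensor would still produce a zero scale.

    import torch

    w = torch.tensor([-1.0, 0.5, 2.0])   # toy weight tensor
    qmin, qmax = 0, 15                   # 4-bit unsigned range
    scale = (w.max() - w.min()) / (qmax - qmin + 1e-8)                # ~0.2
    zero_point = qmin - torch.round(w.min() / scale)                  # 0 - (-5) = 5
    q = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)  # [0., 8., 15.]
    deq = (q - zero_point) * scale       # [-1.0, 0.6, 2.0]: the endpoints are exact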