"""4-bit symmetric weight quantization and quantized linear layers."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules import sparse as sp


def _pack_4bit_tensor(q: torch.Tensor) -> torch.Tensor:
    """Pack signed 4-bit values into a flat uint8 tensor, two values per byte.

    Values are saturated to ``[-8, 7]``, shifted by +8 into ``[0, 15]`` and
    stored low nibble first. If the number of values is odd, one zero nibble
    of padding is appended.

    Args:
        q: Tensor of integer-valued entries, any shape.

    Returns:
        1-D ``uint8`` tensor with ``ceil(q.numel() / 2)`` elements.
    """
    # Saturate BEFORE narrowing to int8. Narrowing first wraps out-of-range
    # values (e.g. 200 -> -56), which would then clamp to the wrong bound.
    # Non-int8 integer inputs are widened so that -8 and 7 are representable
    # in the dtype the clamp runs in (e.g. uint8 cannot hold -8).
    if not q.is_floating_point() and q.dtype != torch.int8:
        q = q.to(torch.int64)
    q = torch.clamp(q, -8, 7).to(torch.int8)

    nibbles = ((q + 8) & 0xF).to(torch.uint8).reshape(-1)
    if nibbles.numel() % 2 == 1:
        pad = torch.zeros(1, dtype=torch.uint8, device=nibbles.device)
        nibbles = torch.cat([nibbles, pad])
    pairs = nibbles.view(-1, 2)
    return pairs[:, 0] | (pairs[:, 1] << 4)


def _unpack_4bit_tensor(packed: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Invert :func:`_pack_4bit_tensor`.

    Args:
        packed: Tensor of packed bytes (low nibble = first value).
        shape: Shape of the tensor that was originally packed.

    Returns:
        ``int8`` tensor of the requested shape with values in ``[-8, 7]``.
    """
    packed = packed.reshape(-1)
    # Element count straight from the shape; no temporary tensor or .item().
    numel = torch.Size(shape).numel()

    q = torch.empty((packed.numel() * 2,), dtype=torch.int8, device=packed.device)
    q[0::2] = (packed & 0xF).to(torch.int8)
    q[1::2] = (packed >> 4).to(torch.int8)

    # Drop the padding nibble (present when numel is odd), then undo the +8 shift.
    return q[:numel].reshape(shape) - 8


def _quantize_weight(weight: torch.Tensor, bits: int = 4) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``weight`` to packed 4-bit values with one scale per row.

    Each slice along dimension 0 gets the scale ``max(|w|) / 7`` (or 1 when
    the slice is all zeros, to avoid dividing by zero).

    Args:
        weight: Tensor whose first dimension indexes the quantization groups.
        bits: Bit width; only 4 is supported.

    Returns:
        ``(packed, scales)`` where ``packed`` is the flat uint8 tensor
        produced by :func:`_pack_4bit_tensor` and ``scales`` has shape
        ``(weight.size(0),)`` and the dtype of ``weight``.

    Raises:
        ValueError: If ``bits`` is not 4.
    """
    if bits != 4:
        raise ValueError('Only 4-bit quantization is supported.')

    # reshape (not view) so non-contiguous weights are accepted as well.
    flat_weight = weight.reshape(weight.size(0), -1)
    max_vals = flat_weight.abs().amax(dim=1)
    scales = torch.where(max_vals > 0, max_vals / 7.0, torch.ones_like(max_vals))
    scales = scales.to(weight.dtype)

    q_weight = torch.clamp((flat_weight / scales[:, None]).round(), -8, 7).to(torch.int8)
    packed = _pack_4bit_tensor(q_weight.reshape(weight.shape))
    return packed, scales


class QuantizedLinear(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit integers.

    The weight is dequantized on every forward pass and applied with
    ``F.linear``.

    Buffers:
        weight_packed: flat uint8 tensor, two 4-bit values per byte.
        weight_scales: one scale per output feature.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = torch.float16,
        bits: int = 4,
    ):
        """Create a layer with uninitialized quantized weights.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether to create a (zero-initialized) bias parameter.
            dtype: Floating dtype for the scales and the bias.
            bits: Bit width; only 4 is supported.

        Raises:
            ValueError: If ``bits`` is not 4.
        """
        super().__init__()
        # The pack/unpack helpers are hard-wired to 4 bits, so reject anything
        # else up front instead of silently mis-decoding later.
        if bits != 4:
            raise ValueError('Only 4-bit quantization is supported.')
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.dtype = dtype
        # Allocate the final packed size (two values per byte, rounded up) so a
        # freshly constructed module can load a quantized state_dict without a
        # size mismatch.
        packed_numel = (in_features * out_features + 1) // 2
        self.register_buffer('weight_packed', torch.empty(packed_numel, dtype=torch.uint8))
        self.register_buffer('weight_scales', torch.empty((out_features,), dtype=self.dtype))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=self.dtype))
        else:
            self.bias = None

    @classmethod
    def from_float(cls, module: nn.Linear, bits: int = 4, dtype: torch.dtype = torch.float16):
        """Build a quantized layer from an existing linear module.

        Args:
            module: Source module providing ``in_features``, ``out_features``,
                ``weight`` and ``bias``.
            bits: Bit width; only 4 is supported.
            dtype: Floating dtype for the scales and the bias.

        Returns:
            A new instance of ``cls`` holding the quantized weight and a copy
            of the bias cast to ``dtype``.

        Raises:
            ValueError: If ``bits`` is not 4.
        """
        quantized = cls(
            module.in_features,
            module.out_features,
            module.bias is not None,
            dtype=dtype,
            bits=bits,
        )
        packed, scales = _quantize_weight(module.weight.data, bits=bits)
        quantized.weight_packed = packed.to(dtype=torch.uint8)
        quantized.weight_scales = scales.to(dtype=dtype)
        if module.bias is not None:
            quantized.bias.data = module.bias.data.to(dtype=dtype)
        return quantized

    def dequantize_weight(self) -> torch.Tensor:
        """Return the dense ``(out_features, in_features)`` weight.

        The result has the dtype of the ``weight_scales`` buffer rather than
        the construct-time ``self.dtype``, so it stays consistent after the
        module has been cast with ``.to(dtype)`` / ``.float()`` / ``.half()``.
        """
        q = _unpack_4bit_tensor(
            self.weight_packed, torch.Size((self.out_features, self.in_features))
        )
        scales = self.weight_scales.view(self.out_features, *([1] * (q.ndim - 1)))
        return q.to(scales.dtype) * scales

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform using the dequantized weight."""
        weight = self.dequantize_weight()
        # Cast the activations to the weight's dtype so F.linear never sees
        # mixed dtypes.
        return F.linear(input.to(weight.dtype), weight, self.bias)


class QuantizedSparseLinear(QuantizedLinear):
    """Quantized linear layer for inputs that wrap their features.

    The input object must expose ``.feats`` (the feature tensor) and
    ``.replace(new_feats)``; the result of ``replace`` is returned.
    NOTE(review): presumably the project's sparse tensor type -- confirm.
    """

    def forward(self, input):
        """Apply the linear transform to ``input.feats`` and re-wrap the result."""
        weight = self.dequantize_weight()
        feats = input.feats.to(weight.dtype)
        return input.replace(F.linear(feats, weight, self.bias))


def quantize_model(model: nn.Module, bits: int = 4, dtype: torch.dtype = torch.float16) -> None:
    """Replace linear modules in a model with 4-bit quantized equivalents.

    Walks the module tree and swaps, in place, every ``sp.SparseLinear`` child
    for a :class:`QuantizedSparseLinear` and every ``nn.Linear`` child for a
    :class:`QuantizedLinear`. Only children are replaced; ``model`` itself is
    never swapped.

    Args:
        model: Root module to modify in place.
        bits: Bit width; only 4 is supported.
        dtype: Floating dtype for the scales and biases of the new layers.
    """
    # Snapshot the children because the mapping is mutated while iterating.
    for name, child in list(model.named_children()):
        # SparseLinear is tested first, presumably because it may also be an
        # nn.Linear subclass -- confirm against .modules.sparse.
        if isinstance(child, sp.SparseLinear):
            model._modules[name] = QuantizedSparseLinear.from_float(child, bits=bits, dtype=dtype)
        elif isinstance(child, nn.Linear):
            model._modules[name] = QuantizedLinear.from_float(child, bits=bits, dtype=dtype)
        else:
            quantize_model(child, bits=bits, dtype=dtype)