File size: 4,100 Bytes
8d8ae43 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import sparse as sp
def _pack_4bit_tensor(q: torch.Tensor) -> torch.Tensor:
    """Pack signed 4-bit integer values two per byte.

    Values are clamped to [-8, 7] *before* narrowing to int8, so out-of-range
    inputs saturate instead of wrapping through the int8 cast. Float inputs are
    truncated toward zero by the cast. Each value is offset by +8 into an
    unsigned nibble in [0, 15]. Flattened element 2k goes into the low nibble
    and element 2k+1 into the high nibble of byte k. An odd element count is
    padded with one zero nibble, which `_unpack_4bit_tensor` trims off again.

    Args:
        q: Tensor of any shape holding the values to pack.

    Returns:
        1-D uint8 tensor of length ceil(q.numel() / 2) on q's device.
    """
    if not q.is_floating_point():
        # Widen integer/bool inputs so clamping happens before any narrowing cast.
        q = q.to(torch.int64)
    q = torch.clamp(q, -8, 7).to(torch.int8)
    nibbles = ((q + 8) & 0xF).to(torch.uint8).reshape(-1)
    if nibbles.numel() % 2 == 1:
        nibbles = torch.cat([nibbles, nibbles.new_zeros(1)])
    pairs = nibbles.view(-1, 2)
    return pairs[:, 0] | (pairs[:, 1] << 4)
def _unpack_4bit_tensor(packed: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Expand bytes produced by `_pack_4bit_tensor` back into signed int8 values.

    Each byte contributes its low nibble first, then its high nibble. The
    interleaved stream is truncated to the element count of ``shape``, which
    drops a padding nibble if there is one. It is reshaped to ``shape`` and
    shifted by -8, so that nibbles in [0, 15] map back to [-8, 7].

    Args:
        packed: Tensor of packed bytes (flattened internally).
        shape: Target shape of the unpacked tensor.

    Returns:
        int8 tensor of the requested shape on packed's device.
    """
    flat = packed.reshape(-1)
    # Column 0 = low nibbles, column 1 = high nibbles; row-major flattening interleaves them.
    nibbles = torch.stack((flat & 0xF, flat >> 4), dim=1).reshape(-1).to(torch.int8)
    count = torch.Size(shape).numel()
    return nibbles[:count].view(shape) - 8
def _quantize_weight(weight: torch.Tensor, bits: int = 4) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetrically quantize ``weight`` to signed 4-bit integers with one scale per dim-0 slice.

    Each slice along dim 0 is divided by ``max(|slice|) / 7``, so that its largest
    magnitude maps to 7. The result is rounded, clamped to [-8, 7] (a guard
    against rounding) and packed two values per byte with `_pack_4bit_tensor`.
    Slices that are entirely zero get scale 1, which avoids dividing by zero.

    Args:
        weight: Tensor with at least one dimension; may be non-contiguous.
        bits: Must be 4.

    Returns:
        ``(packed, scales)``:
        - ``packed`` is the uint8 output of `_pack_4bit_tensor` in row-major order.
        - ``scales`` has shape ``(weight.size(0),)`` and dtype ``weight.dtype``.

    Raises:
        ValueError: If ``bits`` is not 4.
    """
    if bits != 4:
        raise ValueError('Only 4-bit quantization is supported.')
    # reshape (not view) so non-contiguous weights, e.g. permuted tensors, are accepted.
    rows = weight.reshape(weight.size(0), -1)
    max_abs = rows.abs().amax(dim=1)
    scales = torch.where(max_abs > 0, max_abs / 7.0, torch.ones_like(max_abs))
    q_rows = torch.clamp((rows / scales[:, None]).round(), -8, 7).to(torch.int8)
    packed = _pack_4bit_tensor(q_rows.reshape(weight.shape))
    return packed, scales
class QuantizedLinear(nn.Module):
    """Linear layer whose weight is stored as packed signed 4-bit integers.

    Buffers:
        weight_packed: uint8, shape ``((in_features * out_features + 1) // 2,)``,
            holding two 4-bit values per byte (see `_pack_4bit_tensor`).
        weight_scales: shape ``(out_features,)``, one scale per output row.

    The weight is dequantized on every forward call and applied with
    ``F.linear``. Computation runs in ``weight_scales.dtype``, so
    ``module.half()`` / ``.float()`` / ``.to(dtype)`` keep input, weight and
    bias consistent. ``self.dtype`` records the dtype requested at
    construction.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = torch.float16,
        bits: int = 4,
    ):
        super().__init__()
        if bits != 4:
            # Packing and dequantization below are 4-bit only; fail early rather than
            # silently misinterpreting the buffers.
            raise ValueError('Only 4-bit quantization is supported.')
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.dtype = dtype
        # Allocate full-size buffers so load_state_dict works on a freshly built layer.
        # 0x88 packs two zero values (nibble 8 - 8 = 0), so an unloaded layer has a
        # zero weight instead of garbage.
        num_bytes = (in_features * out_features + 1) // 2
        self.register_buffer('weight_packed', torch.full((num_bytes,), 0x88, dtype=torch.uint8))
        self.register_buffer('weight_scales', torch.ones((out_features,), dtype=self.dtype))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=self.dtype))
        else:
            self.register_parameter('bias', None)

    @classmethod
    def from_float(cls, module: nn.Linear, bits: int = 4, dtype: torch.dtype = torch.float16):
        """Build a quantized copy of a float linear module.

        Args:
            module: Source layer; its ``in_features``, ``out_features``, ``weight``
                and ``bias`` are read.
            bits: Must be 4.
            dtype: Dtype for the scales and bias, and therefore the compute dtype.

        Returns:
            A new instance of ``cls``. Its buffers live on ``module``'s device, and
            it shares no storage with ``module``.

        Raises:
            ValueError: If ``bits`` is not 4.
        """
        quantized = cls(module.in_features, module.out_features, module.bias is not None, dtype=dtype, bits=bits)
        packed, scales = _quantize_weight(module.weight.detach(), bits=bits)
        quantized.weight_packed = packed.to(dtype=torch.uint8)
        quantized.weight_scales = scales.to(dtype=dtype)
        if module.bias is not None:
            # copy=True: never alias the float module's bias, even when dtype already matches.
            quantized.bias = nn.Parameter(module.bias.detach().to(dtype=dtype, copy=True))
        return quantized

    def dequantize_weight(self) -> torch.Tensor:
        """Return the weight, shape ``(out_features, in_features)``, in ``weight_scales.dtype``."""
        q = _unpack_4bit_tensor(self.weight_packed, torch.Size((self.out_features, self.in_features)))
        scales = self.weight_scales
        return q.to(scales.dtype) * scales.view(self.out_features, 1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the dequantized linear map; ``input`` is cast to the weight's dtype."""
        weight = self.dequantize_weight()
        return F.linear(input.to(weight.dtype), weight, self.bias)

    def extra_repr(self) -> str:
        return (
            f'in_features={self.in_features}, out_features={self.out_features}, '
            f'bias={self.bias is not None}, bits={self.bits}'
        )
class QuantizedSparseLinear(QuantizedLinear):
    """QuantizedLinear applied to the feature tensor of a sparse-tensor input.

    Assumes ``input`` exposes ``.feats`` and ``.replace(new_feats)``, as used
    below. It is presumably the project's sparse tensor type from
    ``.modules.sparse`` (TODO confirm).
    """

    def forward(self, input):
        # Compute in the scales' dtype rather than the construction-time self.dtype, so the
        # layer still works after module.half()/.float()/.to(dtype) converts buffers and bias.
        compute_dtype = self.weight_scales.dtype
        weight = self.dequantize_weight().to(compute_dtype)
        feats = input.feats.to(compute_dtype)
        return input.replace(F.linear(feats, weight, self.bias))
def quantize_model(model: nn.Module, bits: int = 4, dtype: torch.dtype = torch.float16) -> None:
    """Replace linear modules in a model with 4-bit quantized equivalents.

    Walks ``model`` recursively and rewrites it in place:
    - every ``sp.SparseLinear`` child becomes a ``QuantizedSparseLinear``;
    - every other ``nn.Linear`` child becomes a ``QuantizedLinear``.
    Both are built via ``from_float``. The SparseLinear check comes first, so it
    takes precedence for subclasses of nn.Linear. Other children are recursed
    into.

    A layer registered under several names or parents (shared weights) is
    quantized once, and every name is pointed at the same replacement. Each
    replacement inherits the train/eval mode of the layer it replaces.

    Args:
        model: Module to modify in place.
        bits: Must be 4 (``from_float`` raises ValueError otherwise).
        dtype: Compute dtype of the quantized layers.
    """
    replacements: dict[nn.Module, nn.Module] = {}
    visited: set[nn.Module] = {model}

    def _convert(parent: nn.Module) -> None:
        # Walk _modules directly: named_children() skips a child registered under a second
        # name, which would leave that alias unquantized. Snapshot because entries are
        # reassigned during the loop.
        for name, child in list(parent._modules.items()):
            if child is None:
                continue
            if child in replacements:
                parent._modules[name] = replacements[child]
                continue
            if isinstance(child, sp.SparseLinear):
                quantized = QuantizedSparseLinear.from_float(child, bits=bits, dtype=dtype)
            elif isinstance(child, nn.Linear):
                quantized = QuantizedLinear.from_float(child, bits=bits, dtype=dtype)
            else:
                if child not in visited:
                    visited.add(child)
                    _convert(child)
                continue
            quantized.train(child.training)
            replacements[child] = quantized
            parent._modules[name] = quantized

    _convert(model)
|