"""
Quantization module for loading compressed DeepSeek-OCR models.
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class QuantizedLinear(nn.Module): | |
| """Quantized linear layer supporting 2/4/8-bit quantization""" | |
| def __init__(self, in_features, out_features, bits, weight_data, scale, zero_point, bias=None): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| self.bits = bits | |
| if bits == 8: | |
| self.register_buffer('weight_quantized', weight_data.to(torch.int8)) | |
| elif bits == 4: | |
| packed = self._pack_4bit(weight_data) | |
| self.register_buffer('weight_quantized', packed) | |
| elif bits == 2: | |
| packed = self._pack_2bit(weight_data) | |
| self.register_buffer('weight_quantized', packed) | |
| else: | |
| self.register_buffer('weight_quantized', weight_data.to(torch.int8)) | |
| self.register_buffer('scale', scale) | |
| self.register_buffer('zero_point', zero_point) | |
| if bias is not None: | |
| self.register_buffer('bias', bias) | |
| else: | |
| self.bias = None | |
| self.original_shape = (out_features, in_features) | |
| def _pack_4bit(self, data): | |
| data = data.view(-1) | |
| if data.numel() % 2 != 0: | |
| data = F.pad(data, (0, 1)) | |
| data = data.to(torch.int8) | |
| even = data[0::2] & 0x0F | |
| odd = (data[1::2] & 0x0F) << 4 | |
| return (even | odd).to(torch.uint8) | |
| def _unpack_4bit(self, packed): | |
| even = (packed & 0x0F).to(torch.int8) | |
| odd = ((packed >> 4) & 0x0F).to(torch.int8) | |
| even = torch.where(even > 7, even - 16, even) | |
| odd = torch.where(odd > 7, odd - 16, odd) | |
| unpacked = torch.stack([even, odd], dim=1).view(-1) | |
| return unpacked[:self.original_shape[0] * self.original_shape[1]] | |
| def _pack_2bit(self, data): | |
| data = data.view(-1) | |
| pad_size = (4 - data.numel() % 4) % 4 | |
| if pad_size > 0: | |
| data = F.pad(data, (0, pad_size)) | |
| data = data.to(torch.int8) & 0x03 | |
| packed = data[0::4] | (data[1::4] << 2) | (data[2::4] << 4) | (data[3::4] << 6) | |
| return packed.to(torch.uint8) | |
| def _unpack_2bit(self, packed): | |
| b0 = (packed & 0x03).to(torch.int8) | |
| b1 = ((packed >> 2) & 0x03).to(torch.int8) | |
| b2 = ((packed >> 4) & 0x03).to(torch.int8) | |
| b3 = ((packed >> 6) & 0x03).to(torch.int8) | |
| b0 = torch.where(b0 > 1, b0 - 4, b0) | |
| b1 = torch.where(b1 > 1, b1 - 4, b1) | |
| b2 = torch.where(b2 > 1, b2 - 4, b2) | |
| b3 = torch.where(b3 > 1, b3 - 4, b3) | |
| unpacked = torch.stack([b0, b1, b2, b3], dim=1).view(-1) | |
| return unpacked[:self.original_shape[0] * self.original_shape[1]] | |
| def dequantize(self): | |
| if self.bits == 8: | |
| weight_int = self.weight_quantized.float() | |
| elif self.bits == 4: | |
| weight_int = self._unpack_4bit(self.weight_quantized).float() | |
| elif self.bits == 2: | |
| weight_int = self._unpack_2bit(self.weight_quantized).float() | |
| else: | |
| weight_int = self.weight_quantized.float() | |
| weight = (weight_int - self.zero_point) * self.scale | |
| return weight.view(self.original_shape) | |
| def forward(self, x): | |
| weight = self.dequantize().to(x.dtype) | |
| return F.linear(x, weight, self.bias) | |