"""Quantization module for loading compressed DeepSeek-OCR models."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizedLinear(nn.Module):
    """Quantized linear layer supporting 2/4/8-bit weight quantization.

    Weights are stored as integer buffers (bit-packed for 2- and 4-bit)
    together with affine quantization parameters ``scale`` and
    ``zero_point``; they are dequantized on the fly in :meth:`forward`
    via ``weight = (q - zero_point) * scale``.
    """

    def __init__(self, in_features, out_features, bits, weight_data, scale,
                 zero_point, bias=None):
        """Build the layer from already-quantized weight data.

        Args:
            in_features: Input feature dimension.
            out_features: Output feature dimension.
            bits: Quantization bit width (2, 4, or 8; anything else is
                treated as 8-bit).
            weight_data: Integer-valued weight tensor of logical shape
                ``(out_features, in_features)``; signed values in the
                bit width's range (e.g. [-8, 7] for 4-bit).
            scale: Dequantization scale (scalar tensor, or per-channel
                of shape ``(out_features, 1)``).
            zero_point: Dequantization zero point (same shapes as
                ``scale``).
            bias: Optional bias tensor; registered as a buffer (not a
                trainable parameter).
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        # Logical (unpacked) weight shape; needed to strip pack padding
        # and restore the 2-D layout after dequantization.
        self.original_shape = (out_features, in_features)

        if bits == 4:
            self.register_buffer('weight_quantized',
                                 self._pack_4bit(weight_data))
        elif bits == 2:
            self.register_buffer('weight_quantized',
                                 self._pack_2bit(weight_data))
        else:
            # 8-bit, and any unrecognized bit width, is stored unpacked.
            self.register_buffer('weight_quantized',
                                 weight_data.to(torch.int8))

        self.register_buffer('scale', scale)
        self.register_buffer('zero_point', zero_point)
        if bias is not None:
            self.register_buffer('bias', bias)
        else:
            self.bias = None

    def _pack_4bit(self, data):
        """Pack two signed 4-bit values per byte (even in low nibble).

        Pads the flattened tensor to an even length if necessary; the
        padding is stripped in :meth:`_unpack_4bit` via
        ``original_shape``.
        """
        data = data.view(-1)
        if data.numel() % 2 != 0:
            data = F.pad(data, (0, 1))
        data = data.to(torch.int8)
        # Masking with 0x0F keeps only the low nibble; negative values
        # become their two's-complement 4-bit pattern (e.g. -8 -> 8).
        even = data[0::2] & 0x0F
        odd = (data[1::2] & 0x0F) << 4
        return (even | odd).to(torch.uint8)

    def _unpack_4bit(self, packed):
        """Inverse of :meth:`_pack_4bit`: restore signed ints, drop padding."""
        even = (packed & 0x0F).to(torch.int8)
        odd = ((packed >> 4) & 0x0F).to(torch.int8)
        # Sign-extend the 4-bit two's-complement patterns (range [-8, 7]).
        even = torch.where(even > 7, even - 16, even)
        odd = torch.where(odd > 7, odd - 16, odd)
        # stack(dim=1) + view interleaves back to e0, o0, e1, o1, ...
        unpacked = torch.stack([even, odd], dim=1).view(-1)
        return unpacked[:self.original_shape[0] * self.original_shape[1]]

    def _pack_2bit(self, data):
        """Pack four signed 2-bit values per byte (index 0 in low bits)."""
        data = data.view(-1)
        pad_size = (4 - data.numel() % 4) % 4
        if pad_size > 0:
            data = F.pad(data, (0, pad_size))
        data = data.to(torch.int8) & 0x03
        packed = (data[0::4] | (data[1::4] << 2)
                  | (data[2::4] << 4) | (data[3::4] << 6))
        return packed.to(torch.uint8)

    def _unpack_2bit(self, packed):
        """Inverse of :meth:`_pack_2bit`: restore signed ints, drop padding."""
        b0 = (packed & 0x03).to(torch.int8)
        b1 = ((packed >> 2) & 0x03).to(torch.int8)
        b2 = ((packed >> 4) & 0x03).to(torch.int8)
        b3 = ((packed >> 6) & 0x03).to(torch.int8)
        # Sign-extend the 2-bit two's-complement patterns (range [-2, 1]).
        b0 = torch.where(b0 > 1, b0 - 4, b0)
        b1 = torch.where(b1 > 1, b1 - 4, b1)
        b2 = torch.where(b2 > 1, b2 - 4, b2)
        b3 = torch.where(b3 > 1, b3 - 4, b3)
        unpacked = torch.stack([b0, b1, b2, b3], dim=1).view(-1)
        return unpacked[:self.original_shape[0] * self.original_shape[1]]

    def dequantize(self):
        """Return the full-precision weight of shape ``original_shape``.

        The integer weight is restored to its 2-D layout *before* the
        affine transform so that per-channel ``scale`` / ``zero_point``
        (shape ``(out_features, 1)``) broadcast correctly; the previous
        order (transform first, reshape last) only worked for scalars.
        """
        if self.bits == 4:
            weight_int = self._unpack_4bit(self.weight_quantized).float()
        elif self.bits == 2:
            weight_int = self._unpack_2bit(self.weight_quantized).float()
        else:
            weight_int = self.weight_quantized.float()
        weight_int = weight_int.view(self.original_shape)
        return (weight_int - self.zero_point) * self.scale

    def forward(self, x):
        """Apply the linear transform with weights dequantized on the fly."""
        weight = self.dequantize().to(x.dtype)
        return F.linear(x, weight, self.bias)