# deepseek-ocr-int8-uniform / quantization.py
# Uploaded by SamMikaelson — INT8 quantized (safetensors format), rev cffc684 (verified)
"""
Quantization module for loading compressed DeepSeek-OCR models
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class QuantizedLinear(nn.Module):
    """Linear layer whose weight is stored quantized at 2, 4, or 8 bits.

    Integer weight codes are kept in a packed buffer and dequantized on the
    fly in ``forward``:  weight = (code - zero_point) * scale.

    Args:
        in_features: input dimension of the layer.
        out_features: output dimension of the layer.
        bits: quantization width; 2- and 4-bit codes are bit-packed into
            uint8 bytes, 8-bit (or any unrecognized width) is stored as int8.
        weight_data: tensor of integer weight codes with shape
            (out_features, in_features).
        scale: dequantization scale, registered as a buffer (broadcast
            against the codes — presumably a scalar per-tensor scale;
            TODO confirm against the exporter).
        zero_point: dequantization zero point, registered as a buffer.
        bias: optional bias tensor; registered as a buffer when given.
    """

    def __init__(self, in_features, out_features, bits, weight_data, scale,
                 zero_point, bias=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        # Dense weight shape; the unpack helpers use it to drop pack padding.
        self.original_shape = (out_features, in_features)
        if bits == 4:
            self.register_buffer('weight_quantized', self._pack_4bit(weight_data))
        elif bits == 2:
            self.register_buffer('weight_quantized', self._pack_2bit(weight_data))
        else:
            # 8-bit (and any unrecognized width) is stored unpacked as int8.
            self.register_buffer('weight_quantized', weight_data.to(torch.int8))
        self.register_buffer('scale', scale)
        self.register_buffer('zero_point', zero_point)
        if bias is not None:
            self.register_buffer('bias', bias)
        else:
            self.bias = None

    def _pack_4bit(self, data):
        """Pack signed 4-bit codes, two per uint8 byte (even index -> low nibble)."""
        data = data.reshape(-1)
        if data.numel() % 2 != 0:
            data = F.pad(data, (0, 1))  # pad to an even count; dropped on unpack
        # Shift in int16 so `nibble << 4` cannot overflow a signed 8-bit value
        # (the old int8 shift relied on implementation-defined wraparound).
        nibbles = data.to(torch.int16) & 0x0F
        return (nibbles[0::2] | (nibbles[1::2] << 4)).to(torch.uint8)

    def _unpack_4bit(self, packed):
        """Inverse of ``_pack_4bit``: uint8 bytes -> flat int8 codes in [-8, 7]."""
        low = (packed & 0x0F).to(torch.int8)
        high = ((packed >> 4) & 0x0F).to(torch.int8)
        # Sign-extend the 4-bit two's-complement fields.
        low = torch.where(low > 7, low - 16, low)
        high = torch.where(high > 7, high - 16, high)
        flat = torch.stack([low, high], dim=1).view(-1)
        return flat[:self.original_shape[0] * self.original_shape[1]]

    def _pack_2bit(self, data):
        """Pack signed 2-bit codes, four per uint8 byte (index 0 -> lowest bits)."""
        data = data.reshape(-1)
        pad_size = (4 - data.numel() % 4) % 4
        if pad_size > 0:
            data = F.pad(data, (0, pad_size))  # pad to a multiple of four codes
        # Shift in int16 so `crumb << 6` cannot overflow a signed 8-bit value.
        crumbs = data.to(torch.int16) & 0x03
        packed = (crumbs[0::4] | (crumbs[1::4] << 2)
                  | (crumbs[2::4] << 4) | (crumbs[3::4] << 6))
        return packed.to(torch.uint8)

    def _unpack_2bit(self, packed):
        """Inverse of ``_pack_2bit``: uint8 bytes -> flat int8 codes in [-2, 1]."""
        fields = [((packed >> shift) & 0x03).to(torch.int8) for shift in (0, 2, 4, 6)]
        # Sign-extend the 2-bit two's-complement fields.
        fields = [torch.where(f > 1, f - 4, f) for f in fields]
        flat = torch.stack(fields, dim=1).view(-1)
        return flat[:self.original_shape[0] * self.original_shape[1]]

    def dequantize(self):
        """Return the dense float32 weight, shape (out_features, in_features)."""
        if self.bits == 4:
            weight_int = self._unpack_4bit(self.weight_quantized).float()
        elif self.bits == 2:
            weight_int = self._unpack_2bit(self.weight_quantized).float()
        else:
            weight_int = self.weight_quantized.float()
        weight = (weight_int - self.zero_point) * self.scale
        return weight.view(self.original_shape)

    def forward(self, x):
        """Dequantize the weight and apply an ordinary linear transform to ``x``."""
        weight = self.dequantize().to(x.dtype)
        # Cast the bias like the weight so mixed-precision inputs do not fail.
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)