File size: 3,399 Bytes
cffc684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Quantization module for loading compressed DeepSeek-OCR models
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizedLinear(nn.Module):
    """Quantized linear layer supporting 2/4/8-bit quantization.

    Weights are stored as signed integers and dequantized on every forward
    pass as ``(q - zero_point) * scale``:

    * 8-bit (and any other ``bits`` value): one ``int8`` per weight, unpacked.
    * 4-bit: two two's-complement nibbles per ``uint8`` byte (range -8..7).
    * 2-bit: four two's-complement 2-bit fields per ``uint8`` byte (range -2..1).

    ``scale`` and ``zero_point`` may be Python numbers or tensors that
    broadcast against the ``(out_features, in_features)`` weight matrix,
    e.g. shape ``(out_features, 1)`` for per-output-channel quantization.
    """

    def __init__(self, in_features, out_features, bits, weight_data, scale, zero_point, bias=None):
        """
        Args:
            in_features: Number of input features.
            out_features: Number of output features.
            bits: Quantization bit width. 2 and 4 are bit-packed; every other
                value (including 8) is stored as unpacked int8.
            weight_data: Integer-valued quantized weights containing
                ``out_features * in_features`` elements; need not be contiguous.
            scale: Dequantization scale (tensor or Python number).
            zero_point: Dequantization zero point (tensor or Python number).
            bias: Optional bias tensor passed to ``F.linear``.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.original_shape = (out_features, in_features)

        if bits == 4:
            stored = self._pack_4bit(weight_data)
        elif bits == 2:
            stored = self._pack_2bit(weight_data)
        else:
            # 8-bit, and any width without a packed format, keeps one int8 per weight.
            stored = weight_data.to(torch.int8)
        self.register_buffer('weight_quantized', stored)

        # register_buffer only accepts tensors (or None); wrap plain numbers.
        # as_tensor returns an existing tensor unchanged.
        self.register_buffer('scale', torch.as_tensor(scale))
        self.register_buffer('zero_point', torch.as_tensor(zero_point))
        # A None buffer is still readable as `self.bias` and is omitted from state_dict.
        self.register_buffer('bias', bias)

    def _pack_4bit(self, data):
        """Pack signed 4-bit values two per uint8 byte.

        Element ``2i`` goes in the low nibble and ``2i + 1`` in the high
        nibble. An odd element count is zero-padded by one element.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed weights, work.
        data = data.reshape(-1)
        if data.numel() % 2 != 0:
            data = F.pad(data, (0, 1))
        data = data.to(torch.int8)
        even = data[0::2] & 0x0F
        odd = (data[1::2] & 0x0F) << 4
        return (even | odd).to(torch.uint8)

    def _unpack_4bit(self, packed):
        """Inverse of ``_pack_4bit``: return a flat int8 vector with padding trimmed."""
        even = (packed & 0x0F).to(torch.int8)
        odd = ((packed >> 4) & 0x0F).to(torch.int8)
        # Sign-extend the 4-bit two's-complement values.
        even = torch.where(even > 7, even - 16, even)
        odd = torch.where(odd > 7, odd - 16, odd)
        unpacked = torch.stack([even, odd], dim=1).view(-1)
        return unpacked[:self.original_shape[0] * self.original_shape[1]]

    def _pack_2bit(self, data):
        """Pack signed 2-bit values four per uint8 byte, lowest bits first.

        The element count is zero-padded up to a multiple of four.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed weights, work.
        data = data.reshape(-1)
        pad_size = (4 - data.numel() % 4) % 4
        if pad_size > 0:
            data = F.pad(data, (0, pad_size))
        data = data.to(torch.int8) & 0x03
        packed = data[0::4] | (data[1::4] << 2) | (data[2::4] << 4) | (data[3::4] << 6)
        return packed.to(torch.uint8)

    def _unpack_2bit(self, packed):
        """Inverse of ``_pack_2bit``: return a flat int8 vector with padding trimmed."""
        b0 = (packed & 0x03).to(torch.int8)
        b1 = ((packed >> 2) & 0x03).to(torch.int8)
        b2 = ((packed >> 4) & 0x03).to(torch.int8)
        b3 = ((packed >> 6) & 0x03).to(torch.int8)
        # Sign-extend the 2-bit two's-complement values (2 -> -2, 3 -> -1).
        b0 = torch.where(b0 > 1, b0 - 4, b0)
        b1 = torch.where(b1 > 1, b1 - 4, b1)
        b2 = torch.where(b2 > 1, b2 - 4, b2)
        b3 = torch.where(b3 > 1, b3 - 4, b3)
        unpacked = torch.stack([b0, b1, b2, b3], dim=1).view(-1)
        return unpacked[:self.original_shape[0] * self.original_shape[1]]

    def dequantize(self):
        """Return the dequantized weight matrix of shape ``(out_features, in_features)``.

        Computed as ``(q - zero_point) * scale`` in float32, or wider if
        ``scale``/``zero_point`` are wider.
        """
        if self.bits == 4:
            weight_int = self._unpack_4bit(self.weight_quantized)
        elif self.bits == 2:
            weight_int = self._unpack_2bit(self.weight_quantized)
        else:
            weight_int = self.weight_quantized
        # Reshape BEFORE applying scale/zero_point. Per-channel parameters
        # (e.g. shape (out_features, 1)) must broadcast over the matrix,
        # not over the flat unpacked vector.
        weight_int = weight_int.float().reshape(self.original_shape)
        return (weight_int - self.zero_point) * self.scale

    def forward(self, x):
        """Apply ``F.linear`` with the dequantized weight, cast to ``x.dtype``."""
        weight = self.dequantize().to(x.dtype)
        # F.linear rejects mixed dtypes (e.g. fp16 input with an fp32 bias),
        # so cast the bias the same way as the weight.
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)