File size: 1,337 Bytes
42a1727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
4-bit GPTQ Model Loader
This script unpacks the 4-bit packed weights for inference
"""
import torch
from safetensors.torch import load_file

def unpack_4bit_weights(packed_weights):
    """Split each packed element into its two 4-bit codes.

    Every element of ``packed_weights`` holds two 4-bit values. The high
    nibble (bits 4-7) is written to the even output column and the low
    nibble (bits 0-3) to the following odd column, so the last dimension
    doubles in length. The result holds raw codes in ``[0, 15]`` as
    ``torch.uint8``; no scaling or dequantization is applied here.

    Args:
        packed_weights: Integer tensor with at least one dimension. The last
            dimension is the packed one, e.g. ``(rows, packed_cols)``.

    Returns:
        ``torch.uint8`` tensor of shape ``(*leading_dims, packed_cols * 2)``.

    Raises:
        ValueError: if ``packed_weights`` is 0-dimensional.
    """
    if packed_weights.dim() == 0:
        raise ValueError("packed_weights must have at least one dimension")

    # Masking after the shift keeps signed (int8) input correct: ``>>`` is an
    # arithmetic shift there and would otherwise smear the sign bit.
    high = (packed_weights >> 4) & 0x0F
    low = packed_weights & 0x0F

    # Interleave as [high0, low0, high1, low1, ...] along the last axis.
    interleaved = torch.stack((high, low), dim=-1).to(torch.uint8)
    *leading_dims, packed_cols = packed_weights.shape
    return interleaved.reshape(*leading_dims, packed_cols * 2)

def load_quantized_model(model_path):
    """Load a packed 4-bit safetensors checkpoint and dequantize it.

    Reads ``<model_path>/model.safetensors``. Each tensor named
    ``<name>.weight_packed`` is handled in four steps:

    - unpack it with ``unpack_4bit_weights``;
    - convert to float and multiply by ``<name>.scale``;
    - reshape to the shape stored in ``<name>.shape``;
    - store the result as ``<name>.weight``.

    Every other tensor in the file is copied through unchanged. That covers
    biases, norms and layers that were never quantized. The only exceptions
    are the ``.scale`` / ``.shape`` companions consumed above.

    Args:
        model_path: Directory containing ``model.safetensors``.

    Returns:
        dict mapping parameter names to tensors, intended for
        ``model.load_state_dict``.

    Raises:
        KeyError: if a ``.weight_packed`` tensor has no matching ``.scale``
            or ``.shape`` tensor.
    """
    tensors = load_file(f"{model_path}/model.safetensors")

    packed_suffix = '.weight_packed'
    # Slice off only the trailing suffix. str.replace would also strip an
    # occurrence in the middle of the key.
    base_names = [
        key[:-len(packed_suffix)] for key in tensors if key.endswith(packed_suffix)
    ]

    # Tensors consumed by dequantization. Everything else is passed through
    # so a strict load_state_dict does not fail on missing biases/norms.
    consumed = set()
    for base_name in base_names:
        consumed.update((
            f"{base_name}{packed_suffix}",
            f"{base_name}.scale",
            f"{base_name}.shape",
        ))
    state_dict = {
        key: value for key, value in tensors.items() if key not in consumed
    }

    for base_name in base_names:
        missing = [
            f"{base_name}{extra}"
            for extra in ('.scale', '.shape')
            if f"{base_name}{extra}" not in tensors
        ]
        if missing:
            raise KeyError(
                f"Quantized weight '{base_name}' is missing tensors: {missing}"
            )

        packed = tensors[f"{base_name}{packed_suffix}"]
        scale = tensors[f"{base_name}.scale"]
        shape = tensors[f"{base_name}.shape"]

        unpacked = unpack_4bit_weights(packed)
        # NOTE(review): codes lie in [0, 15] and no zero-point/offset is
        # subtracted, so every weight takes the sign of ``scale``. Confirm
        # the packing format has no zero-point.
        # NOTE(review): ``scale`` must broadcast against the unpacked
        # (rows, cols) tensor. A per-row scale would need shape (rows, 1);
        # verify against the exporter.
        weights = unpacked.float() * scale

        # NOTE(review): if the exporter padded an odd element count before
        # packing, numel differs from prod(shape) and this reshape raises.
        state_dict[f"{base_name}.weight"] = weights.reshape(shape.tolist())

    return state_dict

# Usage:
# state_dict = load_quantized_model("./deepseek-ocr-gptq-4bit")
# model.load_state_dict(state_dict)