"""
4-bit GPTQ Model Loader
This script unpacks the 4-bit packed weights for inference
"""
import torch
from safetensors.torch import load_file
def unpack_4bit_weights(packed_weights):
    """Unpack 4-bit packed weights into one uint8 value per nibble.

    Each input byte holds two 4-bit values: the high nibble comes first,
    the low nibble second, so column ``i`` of the input expands to columns
    ``2*i`` and ``2*i + 1`` of the output.

    Args:
        packed_weights: integer tensor of shape (rows, packed_cols), each
            element packing two 4-bit values.

    Returns:
        uint8 tensor of shape (rows, packed_cols * 2) with values in [0, 15].
    """
    rows, packed_cols = packed_weights.shape
    # Extract both nibbles in one vectorized pass instead of a Python loop
    # over columns (the loop issued one kernel per packed column). The old
    # `if i*2 + 1 < cols` guard was dead code: cols was always exactly
    # 2 * packed_cols, so the low-nibble index could never overflow.
    high = (packed_weights >> 4) & 0x0F
    low = packed_weights & 0x0F
    # Interleave high/low along a new trailing axis, then flatten so the
    # column order matches the original layout: [h0, l0, h1, l1, ...].
    unpacked = torch.stack((high, low), dim=-1).reshape(rows, packed_cols * 2)
    # Original wrote into a pre-allocated uint8 buffer; keep the same dtype.
    return unpacked.to(torch.uint8)
def load_quantized_model(model_path):
    """Load a GPTQ-quantized checkpoint and dequantize it to float weights.

    For every entry named ``<base>.weight_packed`` the checkpoint must also
    contain ``<base>.scale`` and ``<base>.shape``; the three are combined
    into a full-precision ``<base>.weight`` tensor. All other tensors
    (biases, layer norms, embeddings, ...) are passed through unchanged —
    the previous version silently dropped them, leaving the state dict
    incomplete for any model with non-quantized parameters.

    Args:
        model_path: directory containing ``model.safetensors``.

    Returns:
        dict mapping parameter names to tensors ready for
        ``model.load_state_dict``.

    Raises:
        KeyError: if a ``.weight_packed`` entry lacks its matching
            ``.scale`` or ``.shape`` companion tensor.
    """
    tensors = load_file(f"{model_path}/model.safetensors")
    packed_suffix = '.weight_packed'
    packed_bases = {k[:-len(packed_suffix)] for k in tensors if k.endswith(packed_suffix)}
    state_dict = {}
    for key, tensor in tensors.items():
        if key.endswith(packed_suffix):
            base_name = key[:-len(packed_suffix)]
            scale = tensors[f"{base_name}.scale"]
            shape = tensors[f"{base_name}.shape"]
            unpacked = unpack_4bit_weights(tensor)
            # NOTE(review): per-tensor scale with no zero-point — assumes the
            # packer stored unsigned nibbles; confirm against the export code.
            weights = unpacked.float() * scale
            state_dict[f"{base_name}.weight"] = weights.reshape(shape.tolist())
        elif key.endswith(('.scale', '.shape')) and key.rsplit('.', 1)[0] in packed_bases:
            # Quantization metadata consumed above — not a model parameter.
            continue
        else:
            # Non-quantized tensor: keep it as-is (previously dropped).
            state_dict[key] = tensor
    return state_dict
# Usage:
# state_dict = load_quantized_model("./deepseek-ocr-gptq-4bit")
# model.load_state_dict(state_dict)