"""
4-bit GPTQ Model Loader

This script unpacks the 4-bit packed weights for inference.
"""

import torch
from safetensors.torch import load_file


def unpack_4bit_weights(packed_weights):
    """Unpack 4-bit packed weights into one uint8 value per nibble."""
    rows, packed_cols = packed_weights.shape
    cols = packed_cols * 2

    unpacked = torch.zeros((rows, cols), dtype=torch.uint8)

    for i in range(packed_cols):
        # Each packed byte holds two 4-bit values: the high nibble maps to
        # the even output column, the low nibble to the odd one. Since cols
        # is exactly 2 * packed_cols, no bounds check is needed for the odd
        # column.
        unpacked[:, i * 2] = (packed_weights[:, i] >> 4) & 0x0F
        unpacked[:, i * 2 + 1] = packed_weights[:, i] & 0x0F

    return unpacked
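

# For large tensors the per-column loop above can be slow. The function
# below is a minimal loop-free sketch of the same unpacking, added here
# for illustration only (it is not part of the original loader); it relies
# on torch's bitwise ops plus stack/reshape.
def unpack_4bit_weights_vectorized(packed_weights):
    """Loop-free equivalent of unpack_4bit_weights.

    A byte 0xAB unpacks to the pair (0x0A, 0x0B), so stacking the high and
    low nibbles on a trailing axis and flattening reproduces the column
    order of the loop version: high, low, high, low, ...
    """
    high = (packed_weights >> 4) & 0x0F
    low = packed_weights & 0x0F
    return torch.stack((high, low), dim=-1).reshape(packed_weights.shape[0], -1)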


def load_quantized_model(model_path):
    """Load and unpack a quantized model into a full-precision state dict."""
    tensors = load_file(f"{model_path}/model.safetensors")

    state_dict = {}
    for key in tensors:
        if key.endswith('.weight_packed'):
            base_name = key.replace('.weight_packed', '')
            packed = tensors[key]
            scale = tensors[f"{base_name}.scale"]
            shape = tensors[f"{base_name}.shape"]

            # Unpack the nibbles, then dequantize. scale is broadcast
            # against the unpacked values (per-tensor or per-channel,
            # depending on how the checkpoint stores it).
            unpacked = unpack_4bit_weights(packed)
            weights = unpacked.float() * scale

            state_dict[f"{base_name}.weight"] = weights.reshape(shape.tolist())

    return state_dict
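

# Minimal usage sketch, assuming a checkpoint directory laid out as this
# loader expects (a model.safetensors containing .weight_packed, .scale,
# and .shape entries). The path below is a placeholder, not a real
# checkpoint.
if __name__ == "__main__":
    state_dict = load_quantized_model("path/to/quantized-model")
    for name, tensor in state_dict.items():
        print(name, tuple(tensor.shape), tensor.dtype)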