""" 4-bit GPTQ Model Loader This script unpacks the 4-bit packed weights for inference """ import torch from safetensors.torch import load_file def unpack_4bit_weights(packed_weights): """Unpack 4-bit weights to full precision""" rows, packed_cols = packed_weights.shape cols = packed_cols * 2 unpacked = torch.zeros((rows, cols), dtype=torch.uint8) for i in range(packed_cols): unpacked[:, i*2] = (packed_weights[:, i] >> 4) & 0x0F if i*2 + 1 < cols: unpacked[:, i*2 + 1] = packed_weights[:, i] & 0x0F return unpacked def load_quantized_model(model_path): """Load and unpack quantized model""" tensors = load_file(f"{model_path}/model.safetensors") state_dict = {} for key in tensors: if key.endswith('.weight_packed'): base_name = key.replace('.weight_packed', '') packed = tensors[key] scale = tensors[f"{base_name}.scale"] shape = tensors[f"{base_name}.shape"] # Unpack weights unpacked = unpack_4bit_weights(packed) weights = unpacked.float() * scale state_dict[f"{base_name}.weight"] = weights.reshape(shape.tolist()) return state_dict # Usage: # state_dict = load_quantized_model("./deepseek-ocr-gptq-4bit") # model.load_state_dict(state_dict)