deepseek-ocr-mbq-w4bit / load_mbq_model.py
'''Loader for the MBQ (W4A8) quantized DeepSeek-OCR checkpoint.

The quantized model.safetensors file (reported to be roughly 65% smaller than
the bf16 original) stores low-bit weight tensors alongside companion `.scale`
tensors. This helper dequantizes the weights back to bfloat16 so they can be
loaded into the full-precision model definition.
'''

import torch
from safetensors.torch import load_file
from transformers import AutoConfig


def load_mbq_model(model_path):
    '''Load the MBQ quantized state dict and dequantize it on the fly.'''
    # Load the quantized state dict from disk.
    state_dict = load_file(f"{model_path}/model.safetensors")

    # Separate the quantized weight tensors from their scale factors.
    weights = {}
    scales = {}
    for name, param in state_dict.items():
        if '.scale' in name:
            scales[name.replace('.scale', '')] = param
        else:
            weights[name] = param

    # Dequantize each weight: weight = q_weight * scale, then cast to bfloat16.
    dequantized_state_dict = {}
    for name, param in weights.items():
        if name in scales:
            scale = scales[name]
            dequantized = (param.float() * scale).to(torch.bfloat16)
            dequantized_state_dict[name] = dequantized
        else:
            # Tensors without a matching scale were not quantized; pass them through unchanged.
            dequantized_state_dict[name] = param

    return dequantized_state_dict
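

# This checkpoint is reported to be about 65% smaller than the bf16 original.
# Below is a minimal, hypothetical sketch (not part of the original loader) for
# estimating that reduction from the shipped safetensors file. It assumes the
# stored tensor dtypes reflect the on-disk footprint and that the full-precision
# baseline would keep every non-scale tensor in bfloat16 (2 bytes per element).
def estimate_size_reduction(model_path):
    '''Rough fractional size reduction of the quantized checkpoint vs. bf16.'''
    state_dict = load_file(f"{model_path}/model.safetensors")
    # Bytes actually stored (quantized weights plus their scales).
    quantized_bytes = sum(p.numel() * p.element_size() for p in state_dict.values())
    # Bytes the same weights would occupy if stored directly in bfloat16.
    bf16_bytes = sum(
        p.numel() * 2 for name, p in state_dict.items() if '.scale' not in name
    )
    return 1.0 - quantized_bytes / bf16_bytes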
# Usage:
# state_dict = load_mbq_model("./")
# model.load_state_dict(state_dict)
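#
# A fuller, hypothetical end-to-end sketch of where `model` above could come
# from. AutoModel, trust_remote_code=True, and the "./" checkpoint path are
# assumptions about this repository's layout, not part of the original snippet:
#
# from transformers import AutoModel
#
# config = AutoConfig.from_pretrained("./", trust_remote_code=True)
# model = AutoModel.from_config(config, trust_remote_code=True)
# model.load_state_dict(load_mbq_model("./"))
# model = model.to(torch.bfloat16).eval()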