| import torch | |
| from safetensors.torch import load_file | |
| from transformers import AutoConfig | |
def load_mbq_model(model_path):
    """Load an MBQ-quantized checkpoint and return a dequantized state dict.

    Reads ``{model_path}/model.safetensors``, pairs each quantized weight
    tensor with its companion ``<name>.scale`` tensor, and dequantizes
    on the fly as ``weight = q_weight * scale`` (cast to bfloat16).
    Tensors with no matching scale are passed through unchanged.

    Args:
        model_path: Directory containing ``model.safetensors``.

    Returns:
        dict[str, torch.Tensor]: parameter name -> dequantized tensor,
        suitable for ``model.load_state_dict``.
    """
    state_dict = load_file(f"{model_path}/model.safetensors")

    # Split into quantized weights and their scales.
    # Use endswith + suffix-strip rather than a substring test so a
    # parameter whose name merely *contains* ".scale" (e.g.
    # "foo.scale_proj") is not misclassified, and so only the trailing
    # suffix is removed (str.replace would strip every occurrence).
    suffix = ".scale"
    weights = {}
    scales = {}
    for name, param in state_dict.items():
        if name.endswith(suffix):
            scales[name[: -len(suffix)]] = param
        else:
            weights[name] = param

    dequantized_state_dict = {}
    for name, param in weights.items():
        scale = scales.get(name)
        if scale is None:
            # No scale recorded: tensor was stored unquantized
            # (e.g. norms/embeddings); keep it as-is.
            dequantized_state_dict[name] = param
        else:
            # Dequantize in float32 for precision (cast scale too, in case
            # it was stored in a half-precision dtype), then down-cast.
            # NOTE(review): assumes `scale` broadcasts against `param`
            # (per-tensor or per-channel) — confirm against the quantizer.
            dequantized = (param.float() * scale.float()).to(torch.bfloat16)
            dequantized_state_dict[name] = dequantized
    return dequantized_state_dict
# Usage:
#   state_dict = load_mbq_model("./")
#   model.load_state_dict(state_dict)