"""Loader for MBQ-quantized checkpoints: dequantizes weights on the fly."""
import torch
from safetensors.torch import load_file
from transformers import AutoConfig  # NOTE(review): unused here; kept for downstream callers — confirm

_SCALE_SUFFIX = ".scale"


def _split_weights_and_scales(state_dict):
    """Partition a raw state dict into quantized weights and their scales.

    Scale tensors are identified by a trailing '.scale' suffix on their key;
    the returned ``scales`` dict is keyed by the matching weight name.

    Args:
        state_dict: Mapping of parameter name -> tensor, as stored on disk.

    Returns:
        Tuple ``(weights, scales)`` of two dicts.
    """
    weights = {}
    scales = {}
    for name, tensor in state_dict.items():
        # Match only a trailing '.scale' suffix. A substring test
        # ("'.scale' in name") would misclassify weights whose names merely
        # contain '.scale' (e.g. 'layer.scale_proj.weight'), and
        # str.replace would mangle every occurrence, not just the suffix.
        if name.endswith(_SCALE_SUFFIX):
            scales[name[: -len(_SCALE_SUFFIX)]] = tensor
        else:
            weights[name] = tensor
    return weights, scales


def _dequantize(weights, scales):
    """Dequantize quantized weights using their per-weight scale tensors.

    Weights with a matching scale are dequantized as
    ``weight = q_weight * scale`` (computed in float32, stored as bfloat16).
    Weights without a scale (e.g. norms, biases) pass through unchanged.

    Args:
        weights: Dict of weight name -> (possibly quantized) tensor.
        scales: Dict of weight name -> scale tensor.

    Returns:
        Dict of weight name -> dequantized tensor.
    """
    dequantized_state_dict = {}
    for name, tensor in weights.items():
        if name in scales:
            # Multiply in float32 for accuracy, then downcast to bfloat16.
            dequantized_state_dict[name] = (tensor.float() * scales[name]).to(torch.bfloat16)
        else:
            dequantized_state_dict[name] = tensor
    return dequantized_state_dict


def load_mbq_model(model_path):
    """Load an MBQ quantized model checkpoint and dequantize it on the fly.

    Reads ``model.safetensors`` from *model_path*, pairs each quantized
    weight with its '.scale' tensor, and returns a fully dequantized
    state dict ready for ``model.load_state_dict``.

    Args:
        model_path: Directory containing ``model.safetensors``.

    Returns:
        Dict of parameter name -> dequantized tensor (bfloat16 for
        quantized entries, original dtype otherwise).
    """
    state_dict = load_file(f"{model_path}/model.safetensors")
    weights, scales = _split_weights_and_scales(state_dict)
    return _dequantize(weights, scales)


# Usage:
# state_dict = load_mbq_model("./")
# model.load_state_dict(state_dict)