# Provenance: commit e39ff3a by robbiemu - "add mlx and mlx-lm support".
import argparse
import copy
import json
import os
import sys
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_lm.quant.dynamic_quant import eval_ppl
from mlx_lm.quant.utils import load_data
from safetensors import safe_open
from tqdm import tqdm
from transformers import AutoTokenizer
# The project-local `model` module is imported right after this block. Put the
# parent of this script's directory on sys.path (once) so that import resolves.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_script_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from model import Model, ModelArgs
def estimate_sensitivities(
    model, data, low_bits, low_group_size, high_bits, high_group_size, batch_size=4
):
    """Score each ``nn.Linear`` layer by how much low-bit quantization hurts it.

    A deep copy of ``model`` has every ``nn.Linear`` weight replaced by its
    quantize -> dequantize round trip at ``low_bits`` / ``low_group_size``.
    Gradients of the mean cross-entropy between the copy's output on
    ``batch[:, :-1]`` and the shifted tokens ``batch[:, 1:]`` are summed over
    all batches. A layer's score is ``sum(grad * (w_low - w_high))``, where
    ``w_high`` is the round trip of the original weight at ``high_bits`` /
    ``high_group_size``.

    Args:
        model: Module to analyse. It is deep-copied; the original is only read.
        data: 2-D indexable batch source, sliced as ``data[s : s + batch_size]``
            (presumably token ids - confirm against the caller).
        low_bits: Bit width of the low-precision round trip.
        low_group_size: Quantization group size for the low-precision round trip.
        high_bits: Bit width of the high-precision reference round trip.
        high_group_size: Quantization group size for the high-precision reference.
        batch_size: Number of rows of ``data`` processed per gradient step.

    Returns:
        dict mapping the path of each ``nn.Linear`` layer to its score as a
        Python float. Layers without an accumulated gradient are omitted.
    """

    def qdq(w, bits, group_size):
        # Quantize then immediately dequantize: same shape, quantization error baked in.
        w_q, scales, biases = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(
            w_q, scales=scales, biases=biases, bits=bits, group_size=group_size
        )

    q_model = copy.deepcopy(model)
    linear_layers = {
        k: layer
        for k, layer in tree_flatten(
            q_model.leaf_modules(), is_leaf=nn.Module.is_module
        )
        if isinstance(layer, nn.Linear)
    }

    # Bake the low-precision error into the copy. The weights stay ordinary
    # (trainable) arrays so gradients are produced for them below.
    for layer in linear_layers.values():
        layer.weight = qdq(layer.weight, low_bits, low_group_size)

    def loss_fn(batch, targets):
        logits = q_model(batch)
        return nn.losses.cross_entropy(logits, targets, reduction="mean")

    # Build the differentiated function once instead of once per batch.
    loss_and_grad = nn.value_and_grad(q_model, loss_fn)

    grad_accum = tree_map(lambda x: mx.zeros(x.shape), q_model.trainable_parameters())
    for s in tqdm(range(0, len(data), batch_size), desc="Estimating sensitivities"):
        batch = data[s : s + batch_size]
        # The loss targets are the shifted tokens themselves, so the
        # full-precision model does not need to be run here.
        _, grads = loss_and_grad(batch[:, :-1], batch[:, 1:])
        grad_accum = tree_map(lambda x, y: x + y, grad_accum, grads)
        # Force evaluation each step so the lazy graph does not keep growing.
        mx.eval(grad_accum)

    def compute_sensitivity(grad, lq_w, orig_w):
        hq_w = qdq(orig_w, high_bits, high_group_size)
        return (grad * (lq_w - hq_w)).sum()

    # Flat "path.to.param" -> array lookups for gradients and both weight sets.
    grad_dict = dict(tree_flatten(grad_accum))
    q_params_dict = dict(tree_flatten(q_model.parameters()))
    orig_params_dict = dict(tree_flatten(model.parameters()))

    sensitivities = {}
    for path in linear_layers:
        weight_key = f"{path}.weight"
        grad = grad_dict.get(weight_key)
        if grad is None:
            # No gradient was collected for this weight; leave the layer unscored.
            continue
        score = compute_sensitivity(
            grad, q_params_dict[weight_key], orig_params_dict[weight_key]
        )
        sensitivities[path] = score.item()
    return sensitivities
def estimate_threshold(
    model,
    sensitivities,
    target_bpw,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    """Bisect for the sensitivity cut-off that meets a bits-per-weight budget.

    For each candidate threshold a deep copy of ``model`` is quantized:
    ``nn.Linear`` layers whose sensitivity is strictly above the threshold get
    ``high_bits``, the remaining ``nn.Linear`` layers get ``low_bits``. The
    bits per weight of that copy is compared with ``target_bpw`` and the search
    interval is halved until it is narrower than 0.1% of the sensitivity range.

    Args:
        model: Module to quantize on trial copies; it is never mutated here.
            Assumed not to be quantized yet, since its parameter element count
            is used as the number of logical weights.
        sensitivities: Mapping of layer path to sensitivity score. Paths that
            are missing are treated as sensitivity 0.
        target_bpw: Desired average bits per weight.
        low_bits: Bit width for layers at or below the threshold.
        low_group_size: Quantization group size for those layers.
        high_bits: Bit width for layers above the threshold.
        high_group_size: Quantization group size for those layers.

    Returns:
        The midpoint of the final search interval.

    Raises:
        RuntimeError: If ``sensitivities`` is empty.
    """

    def predicate(p, m, threshold):
        if not isinstance(m, nn.Linear):
            return False
        return sensitivities.get(p, 0) > threshold

    sens_vals = list(sensitivities.values())
    if not sens_vals:
        raise RuntimeError(
            "No sensitivities were computed. This usually means gradients "
            "for Linear weights were not collected. Ensure layers are detected "
            "and weights are trainable during sensitivity estimation."
        )

    # Quantized weights are stored packed (several low-bit values per array
    # element), so counting elements of the quantized copy under-counts the
    # weights and inflates bits-per-weight far above any realistic target.
    # The number of logical weights does not change with quantization, so
    # take it once from the unquantized input model.
    n_weights = sum(p.size for _, p in tree_flatten(model.parameters()))

    min_thr, max_thr = min(sens_vals), max(sens_vals)
    # Stop once the interval is below 0.1% of the full sensitivity range.
    tolerance = 1e-3 * (max_thr - min_thr)

    while (max_thr - min_thr) > tolerance:
        mid = (max_thr + min_thr) / 2
        q_model = copy.deepcopy(model)

        def high_predicate(p, m):
            return predicate(p, m, mid)

        def low_predicate(p, m):
            # Only quantize remaining float nn.Linear layers; avoid re-quantizing
            # modules already quantized in the first pass.
            return isinstance(m, nn.Linear) and (not predicate(p, m, mid))

        nn.quantize(
            q_model,
            group_size=high_group_size,
            bits=high_bits,
            class_predicate=high_predicate,
        )
        nn.quantize(
            q_model,
            group_size=low_group_size,
            bits=low_bits,
            class_predicate=low_predicate,
        )

        stored_bits = 8 * sum(p.nbytes for _, p in tree_flatten(q_model.parameters()))
        bpw = stored_bits / n_weights

        if bpw > target_bpw:
            # Over budget: raise the threshold so fewer layers get high_bits.
            min_thr = mid
        else:
            max_thr = mid

    return (max_thr + min_thr) / 2
# --- Main Conversion and Saving Logic ---
def main():
    """Command-line entry point: load a checkpoint, optionally quantize, save.

    Steps, in order:
      1. Parse the CLI flags.
      2. Read ``config.json`` and ``model.safetensors`` from ``--hf-path``,
         rename the tensor keys to this project's naming, and load them into
         a freshly built ``Model``.
      3. Load calibration data when perplexity reporting or dynamic
         quantization is requested.
      4. Apply mixed 4/8-bit quantization (``--dynamic-quant``) or uniform
         quantization (``--quantize``); the dynamic path wins if both are set.
      5. Write ``weights.npz``, ``config.json`` and the tokenizer files to
         ``--mlx-path``.
    """
    cli = argparse.ArgumentParser(
        description="Convert and optionally quantize a model."
    )
    cli.add_argument(
        "--hf-path", type=str, default=".", help="Path to the Hugging Face model."
    )
    cli.add_argument(
        "--mlx-path", type=str, required=True, help="Path to save the MLX model."
    )
    cli.add_argument(
        "--quantize",
        "-q",
        action="store_true",
        help="Generate a simple uniformly quantized model.",
    )
    cli.add_argument(
        "--dynamic-quant",
        action="store_true",
        help="Use advanced mixed-precision quantization.",
    )
    cli.add_argument(
        "--report-ppl",
        action="store_true",
        help="Report perplexity before and after quantization.",
    )
    cli.add_argument(
        "--target-bpw",
        type=float,
        default=4.5,
        help="Target bits per weight for advanced quant.",
    )
    cli.add_argument(
        "--bits", "-b", type=int, default=4, help="Bits for uniform quantization."
    )
    cli.add_argument(
        "--group-size",
        "-g",
        type=int,
        default=None,
        help="Group size for quantization. If omitted, defaults to 64 when quantizing.",
    )
    args = cli.parse_args()

    # ---- Load tokenizer, config and checkpoint -------------------------------
    print(f"Loading model from {args.hf_path}...")
    source_dir = Path(args.hf_path)
    tokenizer = AutoTokenizer.from_pretrained(args.hf_path)

    with open(source_dir / "config.json", "r") as cfg_file:
        config = json.load(cfg_file)

    checkpoint = source_dir / "model.safetensors"
    with safe_open(checkpoint, framework="mlx") as st:
        tensor_names = list(st.keys())

    # The presence of a fused `g_up` weight switches the model to its dual-MLP layout.
    has_dual = any(
        ".feed_forward.g_up.weight" in name or ".mlp.g_up.weight" in name
        for name in tensor_names
    )

    model_args = ModelArgs.from_dict(config)
    model_args.use_dual_mlp = bool(has_dual)
    model = Model(model_args)

    # Checkpoint key fragments -> names used by this project's Model. The
    # substitutions are applied in this order.
    renames = (
        ("model.embed_tokens", "tok_embeddings"),
        ("model.layers", "layers"),
        ("self_attn", "attention"),
        ("input_layernorm", "attention_norm"),
        ("post_attention_layernorm", "ffn_norm"),
        ("mlp.", "feed_forward."),
        ("model.norm", "norm"),
    )
    # Tensors skipped entirely when the checkpoint uses the dual-MLP layout.
    skipped_when_dual = ("gate_proj", "up_proj", "down_proj")

    weights = {}
    with safe_open(checkpoint, framework="mlx") as st:
        for name in st.keys():
            if has_dual and any(part in name for part in skipped_when_dual):
                continue
            tensor = st.get_tensor(name)
            for old, new in renames:
                name = name.replace(old, new)
            weights[name] = tensor

    if config.get("tie_word_embeddings", True):
        weights.pop("output.weight", None)
    model.update(tree_unflatten(list(weights.items())))

    # ---- Calibration data and baseline perplexity ----------------------------
    calibration_data = None
    if args.report_ppl or args.dynamic_quant:
        print("Loading calibration data...")
        calibration_data = load_data(tokenizer, num_samples=-1, sequence_length=512)

    if args.report_ppl:
        print("Calculating perplexity of original model...")
        baseline_ppl = eval_ppl(model, data=calibration_data)
        print(f"Original PPL: {baseline_ppl:.3f}")

    # ---- Quantization --------------------------------------------------------
    if args.dynamic_quant:
        # Fall back to a default group size when none was given.
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for dynamic quantization")
        print("Starting advanced mixed-precision quantization...")

        low_bits, high_bits = 4, 8
        sensitivities = estimate_sensitivities(
            model,
            calibration_data,
            low_bits,
            args.group_size,
            high_bits,
            args.group_size,
        )
        threshold = estimate_threshold(
            model,
            sensitivities,
            args.target_bpw,
            low_bits,
            args.group_size,
            high_bits,
            args.group_size,
        )

        # Decide every layer's bit width before the model is mutated.
        per_layer_bits = {
            path: (8 if score > threshold else 4)
            for path, score in sensitivities.items()
        }

        def selects(bits):
            # Predicate matching float Linear layers assigned `bits`;
            # layers without an entry default to 4.
            def class_predicate(p, m):
                return isinstance(m, nn.Linear) and per_layer_bits.get(p, 4) == bits

            return class_predicate

        # High-precision pass first, then the low-precision pass.
        for bits in (8, 4):
            nn.quantize(
                model,
                group_size=args.group_size,
                bits=bits,
                class_predicate=selects(bits),
            )

        # Record the per-layer widths in the saved config alongside the weights.
        config["quantization"] = {
            "group_size": args.group_size,
            "method": "mixed_precision_dynamic",
            "per_layer_bits": per_layer_bits,
        }
    elif args.quantize:
        # Fall back to a default group size when none was given.
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for uniform quantization")
        print("Starting simple uniform quantization...")
        nn.quantize(model, group_size=args.group_size, bits=args.bits)
        config["quantization"] = {
            "group_size": args.group_size,
            "bits": args.bits,
            "method": "uniform",
        }

    if args.report_ppl and (args.quantize or args.dynamic_quant):
        print("Calculating perplexity of quantized model...")
        quantized_ppl = eval_ppl(model, data=calibration_data)
        print(f"Quantized PPL: {quantized_ppl:.3f}")

    # ---- Save ----------------------------------------------------------------
    output_path = Path(args.mlx_path)
    output_path.mkdir(parents=True, exist_ok=True)

    flat_params = dict(tree_flatten(model.parameters()))
    mx.savez(str(output_path / "weights.npz"), **flat_params)

    with open(output_path / "config.json", "w") as out_file:
        json.dump(config, out_file, indent=4)
    tokenizer.save_pretrained(output_path)
    print(f"\n✅ Model saved to {args.mlx_path}")


if __name__ == "__main__":
    main()