import argparse
import copy
import json
import os
import sys
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_lm.quant.dynamic_quant import eval_ppl
from mlx_lm.quant.utils import load_data
from safetensors import safe_open
from tqdm import tqdm
from transformers import AutoTokenizer

# FIX: Correctly calculate the project root to find model.py
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from model import Model, ModelArgs


def estimate_sensitivities(
    model, data, low_bits, low_group_size, high_bits, high_group_size, batch_size=4
):
    """Estimate a per-layer quantization sensitivity for every ``nn.Linear``.

    A deep copy of ``model`` has each Linear weight replaced by its
    quantize-dequantize (QDQ) reconstruction at the *low* precision. The
    cross-entropy gradient of that copy is accumulated over ``data``, and a
    layer's sensitivity is the first-order loss change
    ``sum(grad * (w_low_qdq - w_high_qdq))`` — an estimate of how much the
    loss would improve if that layer were stored at the *high* precision.

    Args:
        model: Float model to analyze (left unmodified).
        data: Token array; rows are calibration samples.
        low_bits / low_group_size: Candidate low-precision setting.
        high_bits / high_group_size: Candidate high-precision setting.
        batch_size: Samples per gradient batch.

    Returns:
        dict mapping Linear module path -> float sensitivity (higher means
        the layer benefits more from high-bit quantization).
    """

    def qdq(w, bits, group_size):
        # Quantize then immediately dequantize: yields the float tensor the
        # model would actually use if `w` were stored at this precision.
        wq, s, b = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(wq, scales=s, biases=b, bits=bits, group_size=group_size)

    q_model = copy.deepcopy(model)
    linear_layers = {
        k: layer
        for k, layer in tree_flatten(
            q_model.leaf_modules(), is_leaf=nn.Module.is_module
        )
        if isinstance(layer, nn.Linear)
    }

    # Replace each Linear weight in the copy with its low-precision QDQ
    # reconstruction. The result is still a plain float array, so it remains
    # a trainable parameter and gradients are computed for it below.
    for layer in linear_layers.values():
        layer.weight = qdq(layer.weight, low_bits, low_group_size)

    def loss_fn(batch, targets):
        logits = q_model(batch)
        return nn.losses.cross_entropy(logits, targets, reduction="mean")

    grad_accum = tree_map(lambda x: mx.zeros(x.shape), q_model.trainable_parameters())
    for s in tqdm(range(0, len(data), batch_size), desc="Estimating sensitivities"):
        batch = data[s : s + batch_size]
        # Next-token prediction: inputs are batch[:, :-1], targets batch[:, 1:].
        # (A forward pass of the original model was previously evaluated here
        # per batch, but its output was never used, so it has been removed.
        # NOTE(review): if soft distillation targets were intended instead of
        # token ids, loss_fn must change as well — confirm with the author.)
        _, grads = nn.value_and_grad(q_model, loss_fn)(batch[:, :-1], batch[:, 1:])
        grad_accum = tree_map(lambda x, y: x + y, grad_accum, grads)
        mx.eval(grad_accum)  # materialize now to keep the lazy graph small

    def compute_sensitivity(grad, lq_w, orig_w):
        hq_w = qdq(orig_w, high_bits, high_group_size)
        return (grad * (lq_w - hq_w)).sum()

    # Use a direct loop instead of tree_map to be more robust.
    grad_dict = dict(tree_flatten(grad_accum))
    q_params_dict = dict(tree_flatten(q_model.parameters()))
    orig_params_dict = dict(tree_flatten(model.parameters()))
    sensitivities = {}
    for path, module in linear_layers.items():
        weight_key = f"{path}.weight"
        if weight_key in grad_dict:
            grad = grad_dict[weight_key]
            q_weight = q_params_dict[weight_key]
            orig_weight = orig_params_dict[weight_key]
            sensitivity = compute_sensitivity(grad, q_weight, orig_weight)
            sensitivities[path] = sensitivity.item()
    return sensitivities


def estimate_threshold(
    model,
    sensitivities,
    target_bpw,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    """Binary-search the sensitivity threshold that hits ``target_bpw``.

    Linear layers with sensitivity above the threshold are quantized at the
    high precision, the rest at the low precision. The threshold is tuned so
    the resulting model's bits-per-weight does not exceed ``target_bpw``.

    Returns:
        The midpoint of the final search interval.

    Raises:
        RuntimeError: if ``sensitivities`` is empty (no gradients collected).
    """

    def predicate(p, m, threshold):
        if not isinstance(m, nn.Linear):
            return False
        return sensitivities.get(p, 0) > threshold

    sens_vals = list(sensitivities.values())
    if len(sens_vals) == 0:
        raise RuntimeError(
            "No sensitivities were computed. This usually means gradients "
            "for Linear weights were not collected. Ensure layers are detected "
            "and weights are trainable during sensitivity estimation."
        )

    min_thr, max_thr = min(sens_vals), max(sens_vals)
    # Loop-invariant convergence scale (hoisted: it was recomputed from the
    # unchanging sens_vals on every iteration of the search).
    sens_range = max_thr - min_thr
    while (max_thr - min_thr) > 1e-3 * sens_range:
        mid = (max_thr + min_thr) / 2
        q_model = copy.deepcopy(model)

        def high_predicate(p, m):
            return predicate(p, m, mid)

        def low_predicate(p, m):
            # Only quantize remaining float nn.Linear layers; avoid re-quantizing
            # modules already quantized in the first pass.
            return isinstance(m, nn.Linear) and (not predicate(p, m, mid))

        nn.quantize(
            q_model,
            group_size=high_group_size,
            bits=high_bits,
            class_predicate=high_predicate,
        )
        nn.quantize(
            q_model,
            group_size=low_group_size,
            bits=low_bits,
            class_predicate=low_predicate,
        )
        # Achieved bits-per-weight (flatten once; it was flattened twice).
        params = tree_flatten(q_model.parameters())
        bpw = sum(p.nbytes for _, p in params) * 8 / sum(p.size for _, p in params)
        # A higher threshold means fewer high-bit layers, hence lower bpw.
        if bpw > target_bpw:
            min_thr = mid
        else:
            max_thr = mid
    return (max_thr + min_thr) / 2


# --- Main Conversion and Saving Logic ---
def main():
    """Convert a Hugging Face checkpoint to MLX format, optionally quantized."""
    parser = argparse.ArgumentParser(
        description="Convert and optionally quantize a model."
    )
    parser.add_argument(
        "--hf-path", type=str, default=".", help="Path to the Hugging Face model."
    )
    parser.add_argument(
        "--mlx-path", type=str, required=True, help="Path to save the MLX model."
    )
    parser.add_argument(
        "--quantize",
        "-q",
        action="store_true",
        help="Generate a simple uniformly quantized model.",
    )
    parser.add_argument(
        "--dynamic-quant",
        action="store_true",
        help="Use advanced mixed-precision quantization.",
    )
    parser.add_argument(
        "--report-ppl",
        action="store_true",
        help="Report perplexity before and after quantization.",
    )
    parser.add_argument(
        "--target-bpw",
        type=float,
        default=4.5,
        help="Target bits per weight for advanced quant.",
    )
    parser.add_argument(
        "--bits", "-b", type=int, default=4, help="Bits for uniform quantization."
    )
    parser.add_argument(
        "--group-size",
        "-g",
        type=int,
        default=None,
        help="Group size for quantization. If omitted, defaults to 64 when quantizing.",
    )
    args = parser.parse_args()

    print(f"Loading model from {args.hf_path}...")
    hf_path = Path(args.hf_path)
    tokenizer = AutoTokenizer.from_pretrained(args.hf_path)
    with open(hf_path / "config.json", "r") as f:
        config = json.load(f)

    # Detect the dual-MLP ("g_up") layout from the checkpoint's tensor names.
    with safe_open(hf_path / "model.safetensors", framework="mlx") as f:
        keys = list(f.keys())
    has_dual = any(
        (".feed_forward.g_up.weight" in k) or (".mlp.g_up.weight" in k) for k in keys
    )
    model_args = ModelArgs.from_dict(config)
    model_args.use_dual_mlp = bool(has_dual)
    model = Model(model_args)

    weights = {}
    with safe_open(hf_path / "model.safetensors", framework="mlx") as f:
        for k in f.keys():
            # Dual-MLP checkpoints carry g_up/down weights; skip the standard
            # split projections in that case.
            if has_dual and ("gate_proj" in k or "up_proj" in k or "down_proj" in k):
                continue
            v = f.get_tensor(k)
            # Map Hugging Face parameter names onto this project's Model names.
            k = k.replace("model.embed_tokens", "tok_embeddings")
            k = k.replace("model.layers", "layers")
            k = k.replace("self_attn", "attention")
            k = k.replace("input_layernorm", "attention_norm")
            k = k.replace("post_attention_layernorm", "ffn_norm")
            k = k.replace("mlp.", "feed_forward.")
            k = k.replace("model.norm", "norm")
            weights[k] = v
    if config.get("tie_word_embeddings", True):
        # Tied embeddings: the output head shares tok_embeddings' weight.
        weights.pop("output.weight", None)
    model.update(tree_unflatten(list(weights.items())))

    calibration_data = None
    if args.report_ppl or args.dynamic_quant:
        print("Loading calibration data...")
        calibration_data = load_data(tokenizer, num_samples=-1, sequence_length=512)

    if args.report_ppl:
        print("Calculating perplexity of original model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Original PPL: {ppl:.3f}")

    if args.dynamic_quant:
        # Choose a sensible default group size if not provided
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for dynamic quantization")
        print("Starting advanced mixed-precision quantization...")
        sensitivities = estimate_sensitivities(
            model, calibration_data, 4, args.group_size, 8, args.group_size
        )
        threshold = estimate_threshold(
            model,
            sensitivities,
            args.target_bpw,
            4,
            args.group_size,
            8,
            args.group_size,
        )
        # Compute per-layer bit widths BEFORE mutating the model
        per_layer_bits = {
            p: (8 if s > threshold else 4) for p, s in sensitivities.items()
        }

        def high_predicate(p, m):
            return isinstance(m, nn.Linear) and per_layer_bits.get(p, 4) == 8

        def low_predicate(p, m):
            return isinstance(m, nn.Linear) and per_layer_bits.get(p, 4) == 4

        nn.quantize(
            model, group_size=args.group_size, bits=8, class_predicate=high_predicate
        )
        nn.quantize(
            model, group_size=args.group_size, bits=4, class_predicate=low_predicate
        )
        # Persist per-layer bit-widths so the loader can re-materialize
        # the correct QuantizedLinear modules on load without touching
        # embeddings or other layers.
        config["quantization"] = {
            "group_size": args.group_size,
            "method": "mixed_precision_dynamic",
            "per_layer_bits": per_layer_bits,
        }
    elif args.quantize:
        # Choose a sensible default group size if not provided
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for uniform quantization")
        print("Starting simple uniform quantization...")
        nn.quantize(model, group_size=args.group_size, bits=args.bits)
        config["quantization"] = {
            "group_size": args.group_size,
            "bits": args.bits,
            "method": "uniform",
        }

    if args.report_ppl and (args.quantize or args.dynamic_quant):
        print("Calculating perplexity of quantized model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Quantized PPL: {ppl:.3f}")

    output_path = Path(args.mlx_path)
    output_path.mkdir(parents=True, exist_ok=True)
    mx.savez(
        str(output_path / "weights.npz"), **dict(tree_flatten(model.parameters()))
    )
    with open(output_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)
    tokenizer.save_pretrained(output_path)
    print(f"\n✅ Model saved to {args.mlx_path}")


if __name__ == "__main__":
    main()