
Custom MLX-LM Conversion, Quantization, and Inference

Overview

  • The scripts in custom_mlx_lm/ convert the Hugging Face (HF) safetensors model to MLX format, optionally apply mixed-precision dynamic quantization, and run inference with prompt formatting consistent with inference.py.
  • The quantization layout (bit-width per layer) is persisted in config.json, so the loader can re-materialize only the Linear layers as QuantizedLinear while keeping embeddings and norms in float.
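The README does not spell out the schema of this metadata. As a hypothetical illustration only, assume config.json carries a `quantization` object with a default `group_size` and a per-layer bits map keyed by module path. The key names, module paths and helper names below are invented for the sketch, not read from custom_convert_2.py.

```python
import json
from typing import Dict, Optional

# Hypothetical schema: "quantization", "group_size" and "per_layer_bits" are
# illustrative key names, not the real config.json layout.
def write_quant_metadata(config: dict, per_layer_bits: Dict[str, int], group_size: int = 64) -> str:
    """Return the config as JSON with the per-layer bit map attached."""
    out = dict(config)
    out["quantization"] = {"group_size": group_size, "per_layer_bits": dict(per_layer_bits)}
    return json.dumps(out, indent=2, sort_keys=True)


def bits_for(config_json: str, module_path: str) -> Optional[int]:
    """Bits for a module, or None if it should stay in float (embeddings, norms)."""
    cfg = json.loads(config_json)
    quant = cfg.get("quantization")
    if quant is None:
        return None  # float checkpoint: nothing to re-materialize
    return quant["per_layer_bits"].get(module_path)


if __name__ == "__main__":
    text = write_quant_metadata(
        {"model_type": "example"},
        {"model.layers.0.self_attn.q_proj": 4, "model.layers.0.mlp.down_proj": 8},
    )
    print(bits_for(text, "model.layers.0.mlp.down_proj"))  # 8
    print(bits_for(text, "model.embed_tokens"))            # None
```

Keying by module path lets the loader decide each module's type before any weights are read, which is what makes the reconstruction deterministic.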

Key scripts

  • custom_convert_2.py
    • Convert and optionally quantize.
    • Mixed precision uses calibration data and a sensitivity-driven split between 4-bit and 8-bit Linear layers.
    • Saves weights to weights.npz and writes quantization metadata to config.json.
  • custom_loader.py
    • Loads the model with the correct module types (QuantizedLinear vs float) based on config metadata, then applies saved weights.
    • Leaves embeddings and layernorms in float.
  • inference_mlx_lm.py (CLI: mobilellm-infer)
    • Runs generation. Uses chat_template.jinja when present; otherwise prepends the BOS token, matching inference.py behavior.
  • quant_summary.py
    • Prints a summary of per-layer bit-widths and checks that the quantized tensors exist in weights.npz.
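The two checks attributed to quant_summary.py can be sketched in plain Python. This assumes a path-to-bits map like the one persisted in config.json and the set of array names stored in weights.npz; the `weight`/`scales`/`biases` suffixes follow the parameter names of MLX's affine QuantizedLinear, but the exact key format in weights.npz is an assumption. `summarize` is a hypothetical helper, and its `show` parameter merely mirrors the `--show` flag used in the Quickstart.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def summarize(per_layer_bits: Dict[str, int], npz_keys: Iterable[str],
              show: int = 8) -> Tuple[Counter, List[str]]:
    """Count layers per bit-width and list tensors missing for quantized layers."""
    keys = set(npz_keys)
    counts = Counter(per_layer_bits.values())
    missing = []
    for path in sorted(per_layer_bits):
        # An affine-quantized layer should store packed weights plus per-group scales and biases.
        for suffix in ("weight", "scales", "biases"):
            if f"{path}.{suffix}" not in keys:
                missing.append(f"{path}.{suffix}")
    # Print the first `show` layers and a bit-width histogram.
    for path in sorted(per_layer_bits)[:show]:
        print(f"{path}: {per_layer_bits[path]}-bit")
    print("bit-width histogram:", dict(sorted(counts.items())))
    return counts, missing
```

A non-empty `missing` list points at a layer the metadata declares quantized but whose packed tensors were never saved, which would surface later as a load-time error.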

Quickstart

  • Mixed-precision dynamic quantization
    • uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-mixed-4bit-mlx --dynamic-quant --target-bpw 4.5 --report-ppl
    • Group size defaults to 64 when not provided.
  • Uniform quantization
    • uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-4bit-mlx --quantize --bits 4 --report-ppl
  • Summarize quant layout
    • uv run python custom_mlx_lm/quant_summary.py --model-path MobileLLM-R1-950M-mixed-4bit-mlx --show 8
  • Inference
    • mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?"
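Bits-per-weight is a parameter-weighted average over the quantized layers. Assuming MLX-style affine group quantization, where each group of `group_size` weights also stores a 16-bit scale and a 16-bit bias, a layer at b bits costs b + 32/group_size bits per weight: 4.5 for 4-bit and 8.5 for 8-bit layers at the default group size of 64. Whether `--target-bpw` counts this overhead is not stated here, so the hypothetical helper below exposes it as a flag.

```python
from typing import Dict


def effective_bpw(layer_params: Dict[str, int], layer_bits: Dict[str, int],
                  group_size: int = 64, count_overhead: bool = True) -> float:
    """Parameter-weighted bits per weight over the quantized layers."""
    # Assumes one 16-bit scale and one 16-bit bias per group of `group_size` weights.
    overhead = 32 / group_size if count_overhead else 0.0
    total_bits = sum(n * (layer_bits[name] + overhead) for name, n in layer_params.items())
    total_params = sum(layer_params.values())
    return total_bits / total_params


# 700 params at 4 bits and 100 at 8 bits:
print(effective_bpw({"a": 700, "b": 100}, {"a": 4, "b": 8}, count_overhead=False))  # 4.5
print(effective_bpw({"a": 700, "b": 100}, {"a": 4, "b": 8}))                        # 5.0
```

Without overhead, a 4.5 target means one eighth of the quantized parameters go to 8 bits (4 + 4f = 4.5 gives f = 0.125). If the overhead were counted at group size 64, a 4.5 target would leave no room for any 8-bit layer.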

Notes and defaults

  • Calibration: load_data uses WikiText-like data. Dynamic quantization computes per-layer sensitivities once, then assigns 4 or 8 bits to each Linear layer to meet the requested bits-per-weight. The reported perplexity (PPL) is measured on this same calibration set, so it is not a held-out estimate.
  • Group size: defaults to 64 when quantizing if not provided.
  • Prompt formatting: by default uses chat_template.jinja if present; otherwise prepends BOS for stable behavior across float and quant models.
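The notes give the inputs (one sensitivity per Linear layer, a bits-per-weight target) but not the selection rule. One plausible rule, shown here as an assumption rather than the script's actual method, is greedy: visit layers from most to least sensitive and promote each to 8 bits while the parameter-weighted average stays within the target. `allocate_bits` is a hypothetical name, and this sketch ignores scale/bias overhead. The empty-input guard mirrors the `min() arg is empty` failure listed under Troubleshooting.

```python
from typing import Dict


def allocate_bits(sensitivity: Dict[str, float], n_params: Dict[str, int],
                  target_bpw: float, low: int = 4, high: int = 8) -> Dict[str, int]:
    """Greedily promote the most sensitive layers from `low` to `high` bits."""
    if not sensitivity:
        # No gradients means no sensitivities; fail with a clear message.
        raise ValueError("no sensitivities: were the Linear weights frozen?")
    bits = {name: low for name in sensitivity}
    total = sum(n_params[name] for name in sensitivity)
    used = low * total          # bit budget consumed with every layer at `low`
    budget = target_bpw * total
    for name in sorted(sensitivity, key=sensitivity.get, reverse=True):
        extra = (high - low) * n_params[name]
        if used + extra <= budget:  # promote only if it still fits; otherwise keep scanning
            bits[name] = high
            used += extra
    return bits


print(allocate_bits({"q": 0.1, "down": 0.9, "up": 0.5},
                    {"q": 100, "down": 100, "up": 600}, target_bpw=4.5))
```

Because a layer that does not fit is skipped rather than ending the loop, a smaller, less sensitive layer can still be promoted later; stopping at the first misfit would be the stricter alternative.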

Troubleshooting

  • Empty sensitivities (ValueError: min() arg is empty)
    • Fixed: Linear weights must not be frozen during sensitivity estimation, because the sensitivities are computed from their gradients.
  • Unable to quantize model of type QuantizedLinear
    • Fixed: second quantization pass now targets only remaining float Linear layers.
  • [dequantize] The matrix should be given as a uint32
    • Fixed: the loader no longer blanket-quantizes the model; before loading weights it re-materializes only the Linear layers listed in the per-layer bits map, leaving embeddings in float.
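The last two fixes come down to which modules a quantization pass may touch. MLX's `nn.quantize` takes a `class_predicate` that is called with each module's path and the module itself. The sketch below is a plain-Python stand-in for such a predicate, with placeholder `Linear`, `QuantizedLinear` and `Embedding` classes so it runs without MLX. Returning a dict of per-layer parameters instead of `True` is accepted by recent MLX versions (check the docs for yours). `make_predicate` is a hypothetical name.

```python
from typing import Dict, Union


# Placeholders standing in for mlx.nn.Linear, mlx.nn.QuantizedLinear and mlx.nn.Embedding.
class Linear: ...
class QuantizedLinear: ...
class Embedding: ...


def make_predicate(per_layer_bits: Dict[str, int], group_size: int = 64):
    """Build a (path, module) -> False | params predicate from a per-layer bits map."""
    def predicate(path: str, module) -> Union[bool, dict]:
        if isinstance(module, QuantizedLinear):
            return False  # already quantized: a second pass must skip it
        if not isinstance(module, Linear):
            return False  # embeddings and norms stay in float
        bits = per_layer_bits.get(path)
        if bits is None:
            return False  # not listed in the saved layout
        return {"bits": bits, "group_size": group_size}
    return predicate


pred = make_predicate({"layers.0.q_proj": 4})
print(pred("layers.0.q_proj", Linear()))   # {'bits': 4, 'group_size': 64}
print(pred("embed_tokens", Embedding()))   # False
```

Deriving the predicate from the saved map rather than from module type alone is what keeps embeddings out of the quantized set.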

Rationale and behavior

  • Persist per-layer bits: enables deterministic, loader-driven reconstruction of quant modules and prevents accidental quantization of unsupported modules.
  • Keep embeddings float: avoids dtype mismatch and preserves quality.
  • Match inference.py formatting: improves output consistency between float and quant variants.
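The formatting rule (chat template when present, otherwise a leading BOS) can be sketched as follows. This assumes chat_template.jinja is a Hugging Face-style Jinja chat template rendered with `messages`, `bos_token` and `add_generation_prompt`, and that jinja2 is installed whenever a template exists. `format_prompt` is a hypothetical name, not the function in inference_mlx_lm.py, and `"<s>"` is only a placeholder; the real BOS string should come from the tokenizer.

```python
from pathlib import Path


def format_prompt(prompt: str, model_dir: str, bos_token: str = "<s>") -> str:
    """Use chat_template.jinja when present, otherwise prepend BOS."""
    template_path = Path(model_dir) / "chat_template.jinja"
    if template_path.is_file():
        from jinja2 import Environment  # only needed when a template exists
        # trim_blocks/lstrip_blocks approximate how HF renders chat templates.
        env = Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(template_path.read_text())
        return template.render(
            messages=[{"role": "user", "content": prompt}],
            bos_token=bos_token,
            add_generation_prompt=True,
        )
    # No template: a single leading BOS gives float and quantized runs the same input.
    if prompt.startswith(bos_token):
        return prompt
    return bos_token + prompt
```

If the tokenizer also inserts BOS during encoding, only one of the two should be kept; otherwise the float and quantized variants may still see identical but doubled prefixes.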