Custom MLX-LM Conversion, Quantization, and Inference
Overview
- These scripts convert the Hugging Face (HF) safetensors checkpoint to MLX format, optionally apply mixed-precision dynamic quantization, and run inference with prompt formatting consistent with inference.py.
- The quantization layout (per-layer bit-widths) is persisted in config.json, so the loader can re-materialize only the Linear layers as QuantizedLinear while keeping embeddings and norms in float.
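A minimal sketch of how per-layer bits could round-trip through config.json. The key names (quantization, group_size, per_layer_bits) and the module paths are hypothetical; the schema actually written by custom_convert_2.py is not documented here.

```python
import json
import tempfile
from pathlib import Path


def write_quant_metadata(config_path: Path, group_size: int, per_layer_bits: dict) -> None:
    """Merge quantization metadata into config.json, keeping existing keys."""
    config = json.loads(config_path.read_text()) if config_path.exists() else {}
    # Hypothetical schema: module path -> bit-width (4 or 8) for each quantized Linear.
    config["quantization"] = {"group_size": group_size, "per_layer_bits": per_layer_bits}
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True))


def read_quant_metadata(config_path: Path):
    """Return (group_size, per_layer_bits); an empty map means an unquantized model."""
    quant = json.loads(config_path.read_text()).get("quantization", {})
    # Group size falls back to 64, mirroring the documented default.
    return quant.get("group_size", 64), quant.get("per_layer_bits", {})


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        cfg = Path(tmp) / "config.json"
        cfg.write_text(json.dumps({"existing_key": "kept"}))  # placeholder config
        write_quant_metadata(cfg, 64, {
            "model.layers.0.self_attn.q_proj": 4,  # illustrative module paths
            "model.layers.0.mlp.down_proj": 8,
        })
        print(read_quant_metadata(cfg))
```

Because a loader built this way reads an explicit map rather than quantizing by module type, embeddings and norms simply never appear in it and stay float.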
Key scripts
- custom_convert_2.py
  - Converts the checkpoint and optionally quantizes it.
  - Mixed precision uses calibration data and a sensitivity-driven split of Linear layers between 4-bit and 8-bit.
  - Saves weights to weights.npz and writes quantization metadata to config.json.
- custom_loader.py
  - Builds the model with the correct module type for each layer (QuantizedLinear or float Linear) from the config metadata, then loads the saved weights.
  - Leaves embeddings and layer norms in float.
- inference_mlx_lm.py (CLI: mobilellm-infer)
  - Runs generation. Uses chat_template.jinja when present; otherwise prepends the BOS token, matching inference.py.
- quant_summary.py
  - Prints a summary of per-layer bit-widths and checks that the quantized tensors exist in weights.npz.
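The check that quant_summary.py performs can be sketched with the standard library alone, since an .npz file is a zip archive with one "<key>.npy" entry per array. This assumes each quantized layer stores the parameter names of MLX's QuantizedLinear (weight, scales, biases) under its module path; the function names and the per_layer_bits input are hypothetical, not the script's actual code.

```python
import zipfile
from collections import Counter

# Parameter names assumed to be saved for every quantized Linear (MLX affine quantization).
QUANT_SUFFIXES = ("weight", "scales", "biases")


def npz_keys(npz_path) -> set:
    """List array names in an .npz archive without loading any array data."""
    with zipfile.ZipFile(npz_path) as zf:
        return {name[: -len(".npy")] for name in zf.namelist() if name.endswith(".npy")}


def summarize(per_layer_bits: dict, keys: set, show: int = 8) -> list:
    """Print bit-width counts and the first `show` layers; return layers missing tensors."""
    for bits, n in sorted(Counter(per_layer_bits.values()).items()):
        print(f"{bits}-bit: {n} layers")
    for path in sorted(per_layer_bits)[:show]:
        print(f"  {path}: {per_layer_bits[path]}-bit")
    # A layer marked as quantized must have all of its quantized tensors on disk.
    return [
        path for path in sorted(per_layer_bits)
        if any(f"{path}.{suffix}" not in keys for suffix in QUANT_SUFFIXES)
    ]
```

Listing zip entries avoids reading gigabytes of weights just to confirm that the expected tensors were written.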
Quickstart
- Mixed-precision dynamic quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-mixed-4bit-mlx --dynamic-quant --target-bpw 4.5 --report-ppl
  - Group size defaults to 64 when not provided.
- Uniform quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-4bit-mlx --quantize --bits 4 --report-ppl
- Summarize quant layout
  - uv run python custom_mlx_lm/quant_summary.py --model-path MobileLLM-R1-950M-mixed-4bit-mlx --show 8
- Inference
  - mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?"
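The --target-bpw value is an average over weights, and how it relates to --bits depends on whether per-group metadata is counted. Assuming MLX-style affine quantization, where each group of group_size weights also stores one 16-bit scale and one 16-bit bias, the overhead at the default group size of 64 is 32/64 = 0.5 bits per weight. Whether custom_convert_2.py includes this overhead in its target is not stated here, so the sketch below (hypothetical helper names) exposes it as a flag.

```python
def effective_bpw(bits: int, group_size: int = 64,
                  scale_bits: int = 16, bias_bits: int = 16) -> float:
    """Bits per weight including one scale and one bias stored per group (assumed layout)."""
    return bits + (scale_bits + bias_bits) / group_size


def mixed_bpw(layer_sizes: dict, layer_bits: dict,
              group_size: int = 64, count_overhead: bool = True) -> float:
    """Parameter-weighted average bits per weight over the quantized Linear layers."""
    total_params = sum(layer_sizes.values())
    total_bits = 0.0
    for name, n_params in layer_sizes.items():
        bits = layer_bits[name]
        per_weight = effective_bpw(bits, group_size) if count_overhead else bits
        total_bits += n_params * per_weight
    return total_bits / total_params
```

Under this accounting, uniform 4-bit already sits at 4.5 bits per weight and 8-bit at 8.5, while two equal-sized layers split 4/8 average 6.0 without overhead and 6.5 with it.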
Notes and defaults
- Calibration: load_data uses WikiText-like data. Dynamic quantization computes sensitivities once and assigns 4 or 8 bits to each Linear layer to meet the requested bits-per-weight. The reported perplexity (PPL) is computed on the same data, so treat it as an in-sample estimate rather than a held-out measurement.
- Group size: defaults to 64 when quantizing if not provided.
- Prompt formatting: by default, uses chat_template.jinja if present; otherwise prepends BOS, so float and quantized models receive identically formatted prompts.
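One simple way to turn per-layer sensitivities into a 4/8-bit split under a bits-per-weight budget is a greedy pass: start every Linear layer at 4 bits, then upgrade the most sensitive layers to 8 bits while the parameter-weighted average stays within the target. This is a hypothetical sketch of the idea, not necessarily the policy custom_convert_2.py implements; it ignores per-group overhead and assumes sensitivities and parameter counts are already available.

```python
def allocate_bits(sensitivities: dict, sizes: dict, target_bpw: float,
                  low: int = 4, high: int = 8) -> dict:
    """Greedy low/high split: upgrade the most sensitive layers while the
    parameter-weighted average bit-width stays within target_bpw."""
    if not sensitivities:
        # Mirrors the "empty sensitivities" failure mode: fail loudly instead of min([]).
        raise ValueError("no sensitivities computed; did the Linear weights receive gradients?")
    total = sum(sizes[name] for name in sensitivities)
    bits = {name: low for name in sensitivities}
    budget = (target_bpw - low) * total  # extra bits available above an all-low model
    # Most sensitive first; a layer that does not fit is skipped and smaller ones may still fit.
    for name in sorted(sensitivities, key=sensitivities.get, reverse=True):
        cost = (high - low) * sizes[name]
        if cost <= budget:
            bits[name] = high
            budget -= cost
    return bits
```

A greedy pass is cheap and deterministic given fixed sensitivities, which fits the "compute sensitivities once" behavior described above.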
Troubleshooting
- Empty sensitivities (ValueError: min() arg is empty)
  - Fixed: Linear weights are no longer frozen during sensitivity estimation. Sensitivities are derived from gradients, and frozen parameters receive none, which left the sensitivity list empty.
- Unable to quantize model of type QuantizedLinear
  - Fixed: the second quantization pass now targets only Linear layers that are still in float, skipping layers already converted to QuantizedLinear.
- [dequantize] The matrix should be given as a uint32
  - Fixed: the loader no longer blanket-quantizes the model. It re-materializes only the Linear layers listed in the per-layer bits map before loading weights, leaving embeddings in float.
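Both the second-pass fix and the loader fix come down to the same filter: quantize a module only if it is still a float Linear and the per-layer bits map names it. In MLX such a filter would typically be supplied as the class_predicate callback of mlx.nn.quantize; this framework-free sketch works on class names instead, and plan_quantization and the module paths are hypothetical.

```python
def plan_quantization(module_types: dict, per_layer_bits: dict) -> dict:
    """Return {module path: bits} for the modules that should become QuantizedLinear.

    module_types maps each module path to its class name, e.g. "Linear",
    "QuantizedLinear", "Embedding", "RMSNorm".
    """
    plan = {}
    for path, cls in module_types.items():
        if cls != "Linear":
            # Skips already-quantized layers (no double quantization) and keeps
            # embeddings and norms in float.
            continue
        if path in per_layer_bits:
            plan[path] = per_layer_bits[path]
        # Linear layers absent from the map stay float.
    return plan
```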
Rationale and behavior
- Persist per-layer bits: enables deterministic, loader-driven reconstruction of quant modules and prevents accidental quantization of unsupported modules.
- Keep embeddings float: avoids dtype mismatch and preserves quality.
- Match inference.py formatting: improves output consistency between float and quant variants.
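The formatting rule shared with inference.py can be sketched as a small function: render chat_template.jinja when it exists in the model directory, otherwise prepend BOS. The render_template hook is a hypothetical stand-in (for example, a wrapper around a Hugging Face tokenizer's apply_chat_template or a Jinja2 render), and "<s>" in the usage is a placeholder, not necessarily this model's BOS token.

```python
from pathlib import Path
from typing import Callable, Optional


def format_prompt(prompt: str, model_dir, bos_token: str,
                  render_template: Optional[Callable[[str, str], str]] = None) -> str:
    """Use chat_template.jinja when present (and a renderer is given); else prepend BOS."""
    template_path = Path(model_dir) / "chat_template.jinja"
    if template_path.exists() and render_template is not None:
        # The renderer receives the raw template text and the user prompt.
        return render_template(template_path.read_text(), prompt)
    return bos_token + prompt


# Usage (placeholder BOS and a toy renderer):
# format_prompt("What is the nearest prime to 9^2?", "MobileLLM-R1-950M-mixed-4bit-mlx", "<s>")
```

Routing float and quantized models through the same function is what keeps their outputs comparable.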