Custom MLX-LM Conversion, Quantization, and Inference
Overview
- These scripts convert the Hugging Face safetensors model to MLX format, optionally apply uniform or mixed-precision dynamic quantization, and run inference with prompt formatting consistent with inference.py.
- The quantization layout is persisted in config.json, so the loader can re-materialize only the Linear layers as QuantizedLinear while keeping embeddings and norms in float.
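To illustrate how a persisted layout drives the loader, the sketch below assumes a hypothetical `quantization` entry in config.json with a `group_size` and a `layer_bits` map keyed by module path. The key names, the module paths, and the `plan_modules` helper are all assumptions, not the format custom_convert_2.py actually writes. Given the float class name of each module, the plan marks only listed Linear layers as quantized and refuses to quantize anything else.

```python
import json

# Hypothetical config.json excerpt: real key names and paths may differ.
CONFIG_TEXT = """
{
  "quantization": {
    "group_size": 64,
    "layer_bits": {
      "model.layers.0.self_attn.q_proj": 8,
      "model.layers.0.mlp.down_proj": 4
    }
  }
}
"""


def plan_modules(config, module_kinds):
    """Decide, per module path, whether to build a quantized or a float module.

    module_kinds maps a module path to its float class name ("Linear",
    "Embedding", "RMSNorm", ...). Linear modules listed in the bits map become
    ("quantized", bits, group_size); unlisted modules stay ("float",).
    """
    quant = config.get("quantization", {})
    group_size = quant.get("group_size", 64)  # README default
    layer_bits = quant.get("layer_bits", {})
    plan = {}
    for path, kind in module_kinds.items():
        bits = layer_bits.get(path)
        if bits is None:
            plan[path] = ("float",)
        elif kind == "Linear":
            plan[path] = ("quantized", bits, group_size)
        else:
            # A bits entry on a non-Linear module is treated as an error.
            raise ValueError(f"{path} is {kind}, not Linear; refusing to quantize")
    return plan


config = json.loads(CONFIG_TEXT)
module_kinds = {
    "model.embed_tokens": "Embedding",
    "model.layers.0.self_attn.q_proj": "Linear",
    "model.layers.0.mlp.down_proj": "Linear",
    "model.norm": "RMSNorm",
}
plan = plan_modules(config, module_kinds)
for path, entry in plan.items():
    print(path, entry)
```

Raising on a non-Linear entry, rather than silently skipping it, is one way to get the "prevents accidental quantization of unsupported modules" behavior described under Rationale.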
Key scripts
- custom_convert_2.py
  - Converts the model and optionally quantizes it.
  - Mixed precision uses calibration data and a sensitivity-driven split of Linear layers between 4-bit and 8-bit.
  - Saves weights to weights.npz and writes quantization metadata to config.json.
- custom_loader.py
  - Builds the model with the correct module type for each layer (QuantizedLinear or float) from the config metadata, then loads the saved weights.
  - Leaves embeddings and layer norms in float.
- inference_mlx_lm.py (CLI: mobilellm-infer)
  - Runs generation. Formats the prompt with chat_template.jinja when present and otherwise prepends the BOS token, matching inference.py.
- quant_summary.py
  - Prints a summary of per-layer bit-widths and checks that the quantized tensors exist in weights.npz.
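The prompt-formatting rule used by inference_mlx_lm.py can be sketched as a small function. This is an illustration, not the script's code: `format_prompt`, the `"<s>"` BOS string, and the caller-supplied `render_template` function are hypothetical. In practice the renderer would be built on jinja2 or a tokenizer's `apply_chat_template`; here a trivial string replacement stands in for it.

```python
import tempfile
from pathlib import Path
from typing import Callable


def format_prompt(
    prompt: str,
    model_dir: Path,
    bos_token: str,
    render_template: Callable[[str, str], str],
) -> str:
    """Return the text to tokenize for generation.

    If model_dir contains chat_template.jinja, the template source and the user
    prompt are handed to render_template; otherwise bos_token is prepended.
    """
    template_path = model_dir / "chat_template.jinja"
    if template_path.exists():
        return render_template(template_path.read_text(), prompt)
    return bos_token + prompt


def demo_render(template: str, user_text: str) -> str:
    # Stand-in for real Jinja rendering, used only for this demo.
    return template.replace("{{ prompt }}", user_text)


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    without_template = format_prompt("Hi", model_dir, "<s>", demo_render)
    (model_dir / "chat_template.jinja").write_text("<user>{{ prompt }}</user>")
    with_template = format_prompt("Hi", model_dir, "<s>", demo_render)

print(without_template)  # <s>Hi
print(with_template)     # <user>Hi</user>
```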
Quickstart
- Mixed-precision dynamic quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-mixed-4bit-mlx --dynamic-quant --target-bpw 4.5 --report-ppl
  - Group size defaults to 64 when not provided.
- Uniform quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-4bit-mlx --quantize --bits 4 --report-ppl
- Summarize quant layout
  - uv run python custom_mlx_lm/quant_summary.py --model-path MobileLLM-R1-950M-mixed-4bit-mlx --show 8
- Inference
  - mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?"
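When choosing a --target-bpw, it helps to count per-group storage overhead. The sketch assumes MLX-style affine quantization with one 16-bit scale and one 16-bit bias per group of `group_size` weights. Whether custom_convert_2.py includes this overhead when it measures bits per weight is not stated here, so treat that as an assumption. Under this assumption, each quantized weight costs `bits + 32 / group_size` bits.

```python
def effective_bits(bits: int, group_size: int = 64,
                   scale_bits: int = 16, bias_bits: int = 16) -> float:
    """Bits stored per weight, including the per-group scale and bias."""
    return bits + (scale_bits + bias_bits) / group_size


def average_bpw(layers):
    """Weight-averaged bits per weight.

    layers: iterable of (num_weights, bits, group_size) for quantized layers.
    """
    total_bits = 0.0
    total_weights = 0
    for n, bits, group_size in layers:
        total_bits += n * effective_bits(bits, group_size)
        total_weights += n
    return total_bits / total_weights


print(effective_bits(4))  # 4.5
print(effective_bits(8))  # 8.5
print(average_bpw([(1_000_000, 4, 64), (1_000_000, 8, 64)]))  # 6.5
```

Under this accounting, 4-bit layers at group size 64 already cost 4.5 bits per weight, so a --target-bpw of 4.5 would leave no budget for 8-bit layers. If the target counts only the integer bits, the same target lets (4.5 - 4) / (8 - 4) = 12.5% of the weights use 8 bits.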
Notes and defaults
- Calibration: load_data uses WikiText-like data. Dynamic quantization computes sensitivities once and assigns 4 or 8 bits to each Linear layer to meet the requested bits per weight. The reported PPL is measured on the same calibration set, so it is not a held-out evaluation.
- Group size: defaults to 64 when quantizing, if not specified.
- Prompt formatting: uses chat_template.jinja by default if present; otherwise prepends the BOS token, which keeps behavior consistent across float and quantized models.
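The per-layer 4/8-bit choice can be illustrated with a simple greedy rule. Every Linear layer starts at 4 bits, and the most sensitive layers are promoted to 8 bits while the weight-averaged bit-width stays within the target. This is a hypothetical sketch, not the algorithm in custom_convert_2.py. It takes precomputed sensitivity scores, ranks layers by raw sensitivity, and ignores scale and bias overhead. All three choices are assumptions.

```python
def allocate_bits(layers, target_bpw, low_bits=4, high_bits=8):
    """Greedily assign low_bits or high_bits per layer within a bits-per-weight budget.

    layers: {path: (num_weights, sensitivity)}, where a higher sensitivity means
    more quality loss at low_bits. Scale/bias overhead is ignored.
    """
    total = sum(n for n, _ in layers.values())
    budget = target_bpw * total   # total bits allowed across all layers
    used = low_bits * total       # every layer starts at low_bits
    if used > budget:
        raise ValueError(f"target_bpw={target_bpw} is below low_bits={low_bits}")
    bits = {path: low_bits for path in layers}
    # Most sensitive first; promote whenever the promotion still fits the budget.
    for path, (n, _) in sorted(layers.items(), key=lambda kv: kv[1][1], reverse=True):
        extra = (high_bits - low_bits) * n
        if used + extra <= budget:
            bits[path] = high_bits
            used += extra
    return bits


# Toy layer sizes and sensitivities, for illustration only.
layers = {
    "layers.0.self_attn.q_proj": (1000, 0.9),
    "layers.0.mlp.up_proj": (4000, 0.5),
    "layers.0.mlp.down_proj": (4000, 0.7),
    "layers.1.self_attn.q_proj": (1000, 0.1),
}
bits = allocate_bits(layers, target_bpw=5.0)
achieved = sum(layers[p][0] * b for p, b in bits.items()) / sum(n for n, _ in layers.values())
print(bits)
print(achieved)  # 4.8
```

Once the large sensitive layers no longer fit, the greedy pass can promote a small, less sensitive layer, as it does with layers.1.self_attn.q_proj here. Ranking by sensitivity per weight is a common alternative.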
Troubleshooting
- Empty sensitivities (ValueError: min() arg is empty)
  - Fixed: Linear weights are no longer frozen during sensitivity estimation, because the sensitivities require their gradients.
- Unable to quantize model of type QuantizedLinear
  - Fixed: the second quantization pass now targets only the Linear layers that are still in float.
- [dequantize] The matrix should be given as a uint32
  - Fixed: the loader no longer blanket-quantizes the model. Before loading weights, it re-materializes only the Linear layers listed in the per-layer bits map, leaving embeddings in float.
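The second-pass fix can be shown with a framework-free toy. The classes below are hypothetical stand-ins, not MLX's modules. `quantize_pass` is a hypothetical helper that replaces only the modules a predicate selects, much as the `class_predicate` argument of `mlx.nn.quantize` does. A blanket predicate reaches a module that is already quantized, or an embedding, and fails with an error analogous to the one above.

```python
from dataclasses import dataclass


@dataclass
class Linear:
    in_dims: int
    out_dims: int


@dataclass
class QuantizedLinear:
    in_dims: int
    out_dims: int
    bits: int
    group_size: int


@dataclass
class Embedding:
    vocab: int
    dims: int


def quantize_pass(modules, bits, group_size, predicate):
    """Replace every module for which predicate(path, module) is True."""
    for path, module in list(modules.items()):
        if not predicate(path, module):
            continue
        if not isinstance(module, Linear):
            # Toy analogue of "Unable to quantize model of type QuantizedLinear".
            raise ValueError(f"Unable to quantize model of type {type(module).__name__}")
        modules[path] = QuantizedLinear(module.in_dims, module.out_dims, bits, group_size)


# Toy sizes; layer_bits plays the role of the allocation result.
modules = {
    "embed_tokens": Embedding(32000, 512),
    "layers.0.q_proj": Linear(512, 512),
    "layers.0.down_proj": Linear(2048, 512),
}
layer_bits = {"layers.0.q_proj": 8, "layers.0.down_proj": 4}

# Pass 1: 8-bit for the layers marked as sensitive.
quantize_pass(modules, 8, 64, lambda p, m: layer_bits.get(p) == 8)
# Pass 2: only layers that are still float Linear and marked for 4 bits.
quantize_pass(modules, 4, 64, lambda p, m: isinstance(m, Linear) and layer_bits.get(p) == 4)

for path, module in modules.items():
    print(path, module)
```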
Rationale and behavior
- Persist per-layer bits: enables deterministic, loader-driven reconstruction of quantized modules and prevents accidental quantization of unsupported modules.
- Keep embeddings in float: avoids dtype mismatches (such as the uint32 dequantize error) and preserves quality.
- Match inference.py formatting: improves output consistency between float and quantized variants.
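Because the per-layer bits are persisted, the saved tensors can be cross-checked against them, in the spirit of quant_summary.py. The sketch assumes MLX's QuantizedLinear naming, where a quantized layer stores a packed uint32 `weight` plus `scales` and `biases`. It works on a plain `{tensor_name: dtype_name}` mapping and leaves out how that mapping is read from weights.npz. The `summarize` helper and the tensor names are hypothetical.

```python
from collections import Counter


def summarize(layer_bits, tensors):
    """Return (bit-width histogram, list of problems).

    layer_bits: {module_path: bits} from the persisted metadata.
    tensors: {tensor_name: dtype_name} for the saved weights.
    """
    problems = []
    for path in layer_bits:
        w_dtype = tensors.get(f"{path}.weight")
        if w_dtype != "uint32":
            problems.append(f"{path}.weight: expected packed uint32, found {w_dtype}")
        for suffix in ("scales", "biases"):
            if f"{path}.{suffix}" not in tensors:
                problems.append(f"{path}.{suffix}: missing")
    # Packed weights outside the bits map (e.g. an embedding) indicate over-quantization.
    for name, dtype in tensors.items():
        if name.endswith(".weight") and dtype == "uint32":
            if name[: -len(".weight")] not in layer_bits:
                problems.append(f"{name}: packed uint32 but not in the bits map")
    return Counter(layer_bits.values()), problems


layer_bits = {"model.layers.0.self_attn.q_proj": 8, "model.layers.0.mlp.down_proj": 4}
tensors = {
    "model.embed_tokens.weight": "float16",
    "model.layers.0.self_attn.q_proj.weight": "uint32",
    "model.layers.0.self_attn.q_proj.scales": "float16",
    "model.layers.0.self_attn.q_proj.biases": "float16",
    "model.layers.0.mlp.down_proj.weight": "uint32",
    "model.layers.0.mlp.down_proj.scales": "float16",
    # down_proj.biases is deliberately missing
}
hist, problems = summarize(layer_bits, tensors)
print(dict(hist))  # {8: 1, 4: 1}
print(problems)    # ['model.layers.0.mlp.down_proj.biases: missing']
```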