Custom MLX-LM Conversion, Quantization, and Inference

Overview
- Scripts here convert the Hugging Face (HF) safetensors model to MLX format, optionally quantize it (uniformly or with mixed-precision dynamic quantization), and run inference with prompt formatting consistent with inference.py.
- The quantization layout (per-layer bit-widths) is persisted in config.json, so the loader can re-materialize only the Linear layers as QuantizedLinear while keeping embeddings and norms in float.
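As a rough illustration of loader-driven reconstruction, the sketch below assumes a hypothetical config.json layout: a top-level "quantization" object holding a default group_size and a "bits" map from module path to bit-width. The real key names written by custom_convert_2.py may differ, and the module paths and class names are illustrative only. Only modules that are Linear and appear in the map are planned as quantized; everything else stays float.

```python
import json

# Hypothetical metadata layout; the actual keys written by custom_convert_2.py may differ.
CONFIG_JSON = """
{
  "quantization": {
    "group_size": 64,
    "bits": {
      "model.layers.0.self_attn.q_proj": 4,
      "model.layers.0.mlp.down_proj": 8
    }
  }
}
"""


def plan_module(path, class_name, config):
    """Return (module_type, bits, group_size) that a loader would build.

    Only Linear modules listed in the per-layer bits map become quantized;
    embeddings, norms, and unlisted layers stay in float.
    """
    quant = config.get("quantization", {})
    bits_map = quant.get("bits", {})
    if class_name == "Linear" and path in bits_map:
        return ("QuantizedLinear", bits_map[path], quant.get("group_size", 64))
    return (class_name, None, None)


config = json.loads(CONFIG_JSON)
# Illustrative module paths and class names (not read from a real model).
for path, cls in [
    ("model.embed_tokens", "Embedding"),
    ("model.layers.0.self_attn.q_proj", "Linear"),
    ("model.layers.0.mlp.down_proj", "Linear"),
    ("model.norm", "RMSNorm"),
]:
    print(path, plan_module(path, cls, config))
```

Driving module construction from an explicit map, rather than quantizing every eligible module, is what keeps embeddings and norms untouched.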

Key scripts
- custom_convert_2.py
  - Convert and optionally quantize.
  - Mixed precision uses calibration data and a sensitivity-driven split between 4-bit and 8-bit Linear layers.
  - Saves weights to weights.npz and writes quantization metadata to config.json.
- custom_loader.py
  - Loads the model with the correct module types (QuantizedLinear or float Linear) based on the config metadata, then applies the saved weights.
  - Leaves embeddings and layer norms in float.
- inference_mlx_lm.py (CLI: mobilellm-infer)
  - Runs generation. Uses chat_template.jinja when present; otherwise prepends the BOS token, matching inference.py behavior.
- quant_summary.py
  - Prints a summary of per-layer bit-widths and checks that the quantized tensors exist in weights.npz.
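The mixed-precision split in custom_convert_2.py is described only at a high level. One simple way to realize a sensitivity-driven 4/8-bit split under a bits-per-weight budget is a greedy pass: start every Linear layer at 4 bits, then promote the most sensitive layers to 8 bits while the parameter-weighted average stays within the target. This is an illustrative scheme, not necessarily the script's algorithm; the layer names, sizes, and sensitivity scores are made up, and scale/bias overhead is ignored.

```python
def allocate_bits(layers, target_bpw, low_bits=4, high_bits=8):
    """Greedy 4/8-bit allocation (illustrative, not the script's exact method).

    layers: dict of name -> (num_params, sensitivity); a higher sensitivity
    means the layer is hurt more by low-bit quantization.
    Returns name -> bits with a parameter-weighted mean <= target_bpw,
    assuming target_bpw >= low_bits.
    """
    total = sum(n for n, _ in layers.values())
    bits = {name: low_bits for name in layers}
    budget = (target_bpw - low_bits) * total  # extra bits available to spend
    # Visit layers from most to least sensitive and promote while affordable.
    for name, (n, _) in sorted(layers.items(), key=lambda kv: kv[1][1], reverse=True):
        cost = (high_bits - low_bits) * n
        if cost <= budget:
            bits[name] = high_bits
            budget -= cost
    return bits


def mean_bpw(layers, bits):
    """Parameter-weighted average bit-width of an allocation."""
    total = sum(n for n, _ in layers.values())
    return sum(layers[name][0] * b for name, b in bits.items()) / total


# Hypothetical layers: (parameter count, sensitivity score).
layers = {
    "q_proj": (1_000, 0.9),
    "k_proj": (250, 0.2),
    "down_proj": (3_000, 0.7),
    "up_proj": (3_000, 0.1),
}
bits = allocate_bits(layers, target_bpw=5.0)
print(bits, round(mean_bpw(layers, bits), 3))
```

Greedy filling is a heuristic: a large sensitive layer that does not fit is skipped in favor of smaller ones, and some budget can remain unused.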

Quickstart
- Mixed-precision dynamic quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-mixed-4bit-mlx --dynamic-quant --target-bpw 4.5 --report-ppl
  - Group size defaults to 64 when not provided.
- Uniform quantization
  - uv run python custom_mlx_lm/custom_convert_2.py --hf-path . --mlx-path MobileLLM-R1-950M-4bit-mlx --quantize --bits 4 --report-ppl
- Summarize quant layout
  - uv run python custom_mlx_lm/quant_summary.py --model-path MobileLLM-R1-950M-mixed-4bit-mlx --show 8
- Inference
  - mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?"
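Group-wise affine quantization in MLX stores a scale and a bias for each group of weights, so the storage cost per weight is higher than the nominal bit-width. Assuming 16-bit scales and biases (the usual case for fp16/bf16 models), the calculation below gives the effective bits per weight at the default group size of 64. Whether --target-bpw in custom_convert_2.py counts this overhead is not stated here, so treat these numbers only as a reference point.

```python
def effective_bpw(bits, group_size=64, scale_bits=16, bias_bits=16):
    """Storage bits per weight for group-wise affine quantization.

    Each group of `group_size` weights carries one scale and one bias,
    assumed here to be stored as 16-bit floats.
    """
    return bits + (scale_bits + bias_bits) / group_size


for b in (4, 8):
    print(f"{b}-bit, group 64: {effective_bpw(b):.2f} bits/weight")
# 4-bit, group 64: 4.50 bits/weight
# 8-bit, group 64: 8.50 bits/weight
```

Halving the group size to 32 doubles the overhead to 1 bit per weight, which is one reason larger groups are a common default.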

Notes and defaults
- Calibration: load_data uses WikiText-like data. Dynamic quantization computes sensitivities once and assigns 4 or 8 bits to each Linear layer to meet the requested bits-per-weight. The reported perplexity (PPL) is measured on the same calibration set, so it is not a held-out estimate.
- Group size: defaults to 64 when quantizing if not provided.
- Prompt formatting: by default, uses chat_template.jinja if present; otherwise prepends BOS for consistent behavior across float and quantized models.
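Perplexity is the exponential of the mean per-token negative log-likelihood. The minimal stdlib sketch below assumes you already have the natural-log probabilities the model assigns to each reference token; the two lists are made-up values standing in for a float and a quantized run, not measured results.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood over all scored tokens."""
    if not token_logprobs:
        raise ValueError("need at least one token")
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)


# Hypothetical per-token natural-log probabilities of the reference tokens.
float_lp = [math.log(0.5), math.log(0.25), math.log(0.5), math.log(0.25)]
quant_lp = [math.log(0.4), math.log(0.2), math.log(0.5), math.log(0.25)]
print(round(perplexity(float_lp), 4), round(perplexity(quant_lp), 4))
```

Lower probabilities on the same tokens raise the perplexity, which is why a PPL increase after quantization indicates quality loss.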

Troubleshooting
- Empty sensitivities (ValueError: min() arg is empty)
  - Fixed: Linear weights must not be frozen during sensitivity estimation, since the sensitivities are computed from their gradients.
- Unable to quantize model of type QuantizedLinear
  - Fixed: second quantization pass now targets only remaining float Linear layers.
- [dequantize] The matrix should be given as a uint32
  - Fixed: the loader no longer blanket-quantizes the model; before loading weights, it re-materializes only the Linear layers listed in the per-layer bits map, leaving embeddings in float.
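The second-pass fix amounts to filtering by module type before quantizing. In MLX, such a filter is typically passed as the class_predicate argument of mlx.nn.quantize, which receives the module path and the module. The sketch below uses stand-in classes instead of MLX so it runs anywhere; the module paths are hypothetical.

```python
class Linear:  # stand-in for mlx.nn.Linear
    pass


class QuantizedLinear:  # stand-in for mlx.nn.QuantizedLinear
    pass


class Embedding:  # stand-in for mlx.nn.Embedding
    pass


def second_pass_predicate(path, module):
    """Select only Linear layers that are still in float.

    Non-Linear modules (e.g. embeddings) and already-quantized layers are
    skipped, so a second pass never tries to re-quantize a QuantizedLinear.
    The QuantizedLinear check is defensive in case a quantized class
    subclasses Linear. `path` is unused here but mirrors the predicate's
    (path, module) signature.
    """
    return isinstance(module, Linear) and not isinstance(module, QuantizedLinear)


# Hypothetical model state after a first, partial quantization pass.
modules = {
    "model.embed_tokens": Embedding(),
    "model.layers.0.self_attn.q_proj": QuantizedLinear(),
    "model.layers.0.mlp.up_proj": Linear(),
}
selected = [p for p, m in modules.items() if second_pass_predicate(p, m)]
print(selected)
```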

Rationale and behavior
- Persist per-layer bits: enables deterministic, loader-driven reconstruction of quant modules and prevents accidental quantization of unsupported modules.
- Keep embeddings float: avoids dtype mismatch and preserves quality.
- Match inference.py formatting: improves output consistency between float and quant variants.
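The formatting rule (chat template when available, otherwise a leading BOS) can be sketched without any tokenizer library. The helper below is hypothetical: render_template stands in for whatever actually renders chat_template.jinja (for example a Hugging Face tokenizer's apply_chat_template), and "<s>" is only a placeholder; the real BOS token comes from the model's tokenizer.

```python
import tempfile
from pathlib import Path
from typing import Callable, Optional


def format_prompt(
    prompt: str,
    model_dir: Path,
    bos_token: str,
    render_template: Optional[Callable[[str, str], str]] = None,
) -> str:
    """Use chat_template.jinja when present, otherwise prepend BOS.

    render_template(template_text, prompt) is a hypothetical hook for the
    actual Jinja rendering. If the template exists but no renderer is
    supplied, this falls back to the BOS path.
    """
    template_path = model_dir / "chat_template.jinja"
    if template_path.is_file() and render_template is not None:
        return render_template(template_path.read_text(), prompt)
    return bos_token + prompt


# Usage: an empty directory has no template, so BOS is prepended.
with tempfile.TemporaryDirectory() as d:
    print(format_prompt("What is the nearest prime to 9^2?", Path(d), "<s>"))
```

Applying the same rule to float and quantized checkpoints keeps their inputs identical, so output differences reflect the weights rather than the prompt.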