# Porting **MobileLLM-R1-950M** to MLX and mlx-lm: Architectural Challenges and Solutions
I spent some time pairing with Gemini 2.5 Pro, and later OpenAI Codex, to drag the brand-new facebook/MobileLLM-R1-950M weights onto Apple Silicon.
This write-up is the “why it wasn’t copy-paste” story, plus the gotchas that bit us until the model finally spoke clean English and quantized without drama.
### Goal
Enable **facebook/MobileLLM-R1-950M** to run natively on Apple Silicon using MLX, then create quantized versions compatible with the mlx-lm ecosystem.
---
## 1. Why a Direct "Llama-4 Drop-In" Failed
Although the Hugging Face repo presents MobileLLM-R1-950M as a Llama-4-style dense model, its **config and weights don't align cleanly** with a stock Llama block. The deviations aren't quirks of MLX—they reflect this model's specific architecture:
* **MLP ambiguity**
Config advertises both `intermediate_size` and `intermediate_size_mlp`, suggesting a dual-branch feed-forward.
Actual weights contain only a SwiGLU branch (`gate_proj`, `up_proj`, `down_proj`).
→ Solution: **auto-detect MLP variant from weight names** at load time.
* **Grouped-Query Attention (GQA)**
`num_attention_heads=24`, `num_key_value_heads=6`.
K/V tensors must be **repeated to full head count** for attention shapes to align correctly.
* **QK-norm and scaling**
Config includes `use_qk_norm=True` and `attn_scale=0.1`.
We add the **RMSNorm on Q/K** as specified, but drop the extra `0.1` multiplier—applying it in MLX's `scaled_dot_product_attention` collapses logits into gibberish.
* **RoPE gating**
Config lists all layers under `no_rope_layers`.
Disabling RoPE everywhere would eliminate positional encoding entirely.
→ Treat "all layers disabled" as a config artifact and **apply RoPE everywhere**.
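The four fixes above can be sketched in plain Python. Helper names here are mine, not from the repo; array work is shown on nested lists for clarity, where MLX would use broadcasting, and the `no_rope_layers` semantics (a list of layer indices) is an assumption.

```python
import math

def detect_mlp_variant(weight_keys):
    """Resolve the config's MLP ambiguity from checkpoint key names
    (HF-style naming assumed)."""
    if any(k.endswith(".mlp.gate_proj.weight") for k in weight_keys):
        return "swiglu"  # gate/up/down is all the checkpoint actually holds
    raise ValueError("unrecognized feed-forward layout")

def repeat_kv_heads(kv_heads, n_rep):
    """Expand grouped K/V heads to the full head count; with
    num_attention_heads=24 and num_key_value_heads=6, n_rep is 4.
    In MLX this is a broadcast + reshape on the head axis."""
    out = []
    for head in kv_heads:
        out.extend([head] * n_rep)
    return out

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one head vector, applied to Q and K before attention.
    Note: the config's attn_scale=0.1 is deliberately NOT applied; the
    default 1/sqrt(head_dim) scaling inside scaled_dot_product_attention
    is kept instead."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def rope_enabled(layer_idx, no_rope_layers, num_layers):
    """Per-layer RoPE gate. A config that disables RoPE for every layer is
    treated as an artifact: RoPE stays on everywhere."""
    if set(no_rope_layers) >= set(range(num_layers)):
        return True
    return layer_idx not in no_rope_layers
```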
---
## 2. Prompt-Level Deviations
Even after weights loaded correctly, default inference was disrupted by tokenizer settings:
* **Chat template**
Default system prompt: *"Please reason step-by-step and put your final answer within \boxed{}."*
Without overrides, the model produces verbose "reasoning" outputs.
→ Added CLI controls: `--system`, `--disable-chat-template`, `--final-only`.
* **Double BOS**
Both tokenizer and template inserted BOS tokens.
→ Fixed with `add_special_tokens=False`.
* **Premature EOS**
Template headers (`<|eot_id|>`) were treated as stop tokens.
→ Limited stopping criteria to true EOS token only.
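The two tokenizer guards can be sketched as small helpers (hypothetical names; the actual fix passes `add_special_tokens=False` at encode time, which prevents the doubled BOS up front):

```python
def dedupe_bos(ids, bos_id):
    """Drop a doubled leading BOS, which appears when both the tokenizer
    and the chat template insert one. Encoding with
    add_special_tokens=False is the cleaner upstream fix; this guard is
    equivalent after the fact."""
    while len(ids) >= 2 and ids[0] == bos_id and ids[1] == bos_id:
        ids = ids[1:]
    return ids

def should_stop(token_id, eos_id):
    """Stop generation only on the true EOS id, never on template header
    tokens such as <|eot_id|>."""
    return token_id == eos_id
```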
---
## 3. Sampling Stability
Sampling issues stemmed from API mismatches rather than model problems:
* **Top-p applied to probabilities**, with the truncated result fed to `mx.random.categorical` (which expects logits), produced repetition loops.
* **Solution:** Apply penalties → scale logits → top-p mask (with `float('-inf')`) → `categorical(logits)`.
* Added controls for **temperature, repetition penalty, frequency penalty**.
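The corrected order of operations can be sketched with a plain-Python nucleus filter (in the real pipeline the masked logits go straight to `mx.random.categorical`; penalties and temperature scaling happen on the raw logits before this step, matching the pipeline above):

```python
import math

def top_p_filter(logits, top_p=0.9):
    """Mask everything outside the nucleus with -inf, *in logit space*.
    Sampling from probabilities that had already been top-p-truncated was
    the bug; a logit-space sampler needs -inf for excluded tokens."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    m = logits[order[0]]
    exps = [math.exp(logits[i] - m) for i in order]
    total = sum(exps)
    filtered = [float("-inf")] * len(logits)
    cum = 0.0
    for rank, i in enumerate(order):
        filtered[i] = logits[i]  # the top-1 token is always kept
        cum += exps[rank] / total
        if cum >= top_p:
            break
    return filtered
```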
---
## 4. Quantization in mlx-lm: Why Custom Metadata Was Required
mlx-lm provides quantization hooks, but MobileLLM's architecture exposed several challenges:
1. **Frozen gradients during sensitivity analysis** → empty sensitivity lists.
→ Avoid freezing weights during gradient computation.
2. **Re-quantizing quantized layers** → type errors on second pass.
→ Skip `QuantizedLinear` layers if already quantized.
3. **Embedding/norm dtype crashes**
Standard quantization re-quantized everything, but embeddings must remain float.
→ Introduced **metadata-driven approach**: config.json records *per-layer bit-widths*. Only specified layers are instantiated as `QuantizedLinear`.
This metadata contract allows **4-bit mixed-precision MobileLLM** to be loaded cleanly by our **metadata-aware `custom_loader.py`**, making it compatible with the mlx-lm ecosystem.
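The metadata contract can be sketched as a per-layer decision function (names here are hypothetical, including the example config key; the shipped logic lives in `custom_loader.py`):

```python
class QuantizedLinear:
    """Stand-in for mlx-lm's QuantizedLinear, used here only for the
    isinstance check."""

def bits_for(name, module, per_layer_bits):
    """Decide whether (and at what width) to quantize one layer from
    config.json metadata, e.g.
    per_layer_bits = {"model.layers.0.self_attn.q_proj": 4}.
    Returns a bit-width, or None to leave the layer in float."""
    if isinstance(module, QuantizedLinear):
        return None  # never re-quantize: a second pass caused type errors
    # unlisted layers (embeddings, norms) stay float
    return per_layer_bits.get(name)
```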
---
## 5. End State
* **MLX path:**
Structural fixes (GQA, MLP detection), numerical fixes (QK-norm, RoPE, attn_scale), and prompt controls together yield fluent, stable inference.
* **mlx-lm path:**
Custom quantization pipeline produces FP16 and 4-bit models. These can be loaded with our **metadata-aware `custom_loader.py`** and used for inference with our provided scripts.
Performance: measurable speedup and reduced memory use on Apple Silicon's unified memory, with minimal quality degradation.
---
### Takeaway
The MobileLLM-R1-950M port required systematically addressing architectural mismatches (MLP variant detection, GQA handling, QK-norm implementation, RoPE configuration) and developing a metadata-driven quantization approach. Once these were resolved, the model became fully functional in MLX with both float and quantized inference paths.