# Porting **MobileLLM-R1-950M** to MLX and mlx-lm: Architectural Challenges and Solutions

I spent some time pairing with Gemini 2.5 Pro and later OpenAI Codex to drag the brand-new facebook/MobileLLM-R1-950M weights onto Apple Silicon.
This write-up is the “why it wasn’t copy-paste” story, plus the gotchas that bit us until the model finally spoke clean English and quantized without drama.

### Goal

Enable **facebook/MobileLLM-R1-950M** to run natively on Apple Silicon using MLX, then create quantized versions compatible with the mlx-lm ecosystem.

---

## 1. Why a Direct "Llama-4 Drop-In" Failed

Although the Hugging Face repo presents MobileLLM-R1-950M as a Llama-4-style dense model, its **config and weights don't align cleanly** with a stock Llama block. The deviations aren't MLX quirks; they reflect this model's specific architecture:

* **MLP ambiguity**  
  Config advertises both `intermediate_size` and `intermediate_size_mlp`, suggesting a dual-branch feed-forward.  
  Actual weights contain only a SwiGLU branch (`gate_proj`, `up_proj`, `down_proj`).  
  → Solution: **auto-detect MLP variant from weight names** at load time.

* **Grouped-Query Attention (GQA)**  
  `num_attention_heads=24`, `num_key_value_heads=6`.  
  K/V heads are **repeated to the full head count** (each of the 6 K/V heads serves 4 query heads) so that attention shapes align.

* **QK-norm and scaling**  
  Config includes `use_qk_norm=True` and `attn_scale=0.1`.  
  We add the **RMSNorm on Q/K** as specified, but drop the extra `0.1` multiplier: applying it inside MLX's `scaled_dot_product_attention` collapses the output into gibberish.

* **RoPE gating**  
  Config lists all layers under `no_rope_layers`.  
  Disabling RoPE everywhere would eliminate positional encoding entirely.  
  → Treat "all layers disabled" as a config artifact and **apply RoPE everywhere**.
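The MLP auto-detection step can be sketched as a small function over checkpoint tensor names and shapes. It is a minimal sketch with several assumptions. Keys are assumed to follow the Hugging Face `...layers.<i>.<module>.<proj>.weight` pattern, so only the suffix is matched. The function name `detect_mlp_variant` is hypothetical. In MLX, `mx.load("model.safetensors")` returns a dict of arrays, so the input could be built as `{k: v.shape for k, v in weights.items()}`. Reading the FFN width from `up_proj` also settles which of `intermediate_size` / `intermediate_size_mlp` the weights actually use.

```python
# Assumption: checkpoint keys follow the Hugging Face pattern
# "...layers.<i>.<ffn module>.<proj>.weight"; only the suffix is matched.
SWIGLU_PROJS = ("gate_proj", "up_proj", "down_proj")


def detect_mlp_variant(weight_shapes, layer_idx=0):
    """Return (variant, ffn_width) for one decoder layer (hypothetical helper).

    weight_shapes maps tensor names to shapes, e.g. built from a loaded
    safetensors dict as {name: array.shape for name, array in weights.items()}.
    """
    tag = f"layers.{layer_idx}."
    found = {}
    for name, shape in weight_shapes.items():
        if tag not in name or not name.endswith(".weight"):
            continue
        proj = name[: -len(".weight")].rsplit(".", 1)[-1]
        if proj in SWIGLU_PROJS:
            found[proj] = tuple(shape)

    if set(found) != set(SWIGLU_PROJS):
        # Refuse to guess: building a branch the checkpoint can't fill would
        # leave randomly initialized weights in the model.
        raise ValueError(f"layer {layer_idx}: unrecognized MLP tensors {sorted(found)}")

    # Linear weights are stored as (out_features, in_features), so the first
    # dimension of up_proj is the FFN width the checkpoint actually uses.
    return "swiglu", found["up_proj"][0]
```

Raising instead of falling back is deliberate. A silently constructed dual-branch MLP would load without errors and only show up later as garbage output.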
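The attention fixes above can be combined in one NumPy sketch of a single attention call:

- RoPE on every layer.
- A weightless RMS norm on Q and K over `head_dim`.
- GQA by repeating each K/V head `n_heads // n_kv_heads` times (4 for 24/6).
- Plain `1/sqrt(head_dim)` scaling with no extra `attn_scale`.

NumPy stands in for MLX so the snippet runs anywhere. `np.repeat(..., axis=0)` has the same semantics as `mx.repeat` and orders heads like Hugging Face's `repeat_kv`.

Several details are assumptions, not facts about this checkpoint:

- the rotate-half RoPE layout (MLX's `nn.RoPE(..., traditional=True)` is the interleaved alternative);
- the `rope_base` value;
- the RoPE-then-norm ordering;
- the absence of learned norm gains.

In MLX the final lines would collapse into `mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim ** -0.5, mask=...)`. That kernel documents native GQA support and expects K/V not to be pre-tiled. The explicit repeat therefore matters mainly for a hand-written path like this one.

```python
import numpy as np


def rms_norm(x, weight=1.0, eps=1e-6):
    """RMSNorm over the last axis (head_dim when used as QK-norm)."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def apply_rope(x, base=10000.0):
    """Rotate-half RoPE; x has shape (heads, seq, head_dim) with even head_dim."""
    _, seq, dim = x.shape
    half = dim // 2
    inv_freq = base ** (-np.arange(half) / half)             # (half,)
    angles = np.arange(seq)[:, None] * inv_freq[None, :]     # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def gqa_attention(q, k, v, rope_base=10000.0):
    """Causal GQA self-attention. q: (n_heads, seq, d); k, v: (n_kv_heads, seq, d)."""
    n_heads, seq, d = q.shape
    n_kv = k.shape[0]
    if n_heads % n_kv:
        raise ValueError("n_heads must be a multiple of n_kv_heads")

    # RoPE on every layer: the all-layers `no_rope_layers` entry is ignored.
    q, k = apply_rope(q, rope_base), apply_rope(k, rope_base)
    # QK-norm: per-head RMS normalization (no learned gain in this sketch).
    q, k = rms_norm(q), rms_norm(k)

    # GQA: K/V head j serves query heads j*reps .. j*reps + reps - 1.
    reps = n_heads // n_kv
    k, v = np.repeat(k, reps, axis=0), np.repeat(v, reps, axis=0)

    # Standard 1/sqrt(d) scaling only; attn_scale (0.1) is NOT multiplied in.
    scores = (q @ k.transpose(0, 2, 1)) / np.sqrt(d)
    causal = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(causal, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v
```

For example, `gqa_attention(q, k, v)` with `q` of shape `(24, seq, 64)` and `k`, `v` of shape `(6, seq, 64)` returns a `(24, seq, 64)` array.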

---

## 2. Prompt-Level Deviations

Even after the weights loaded correctly, default inference was derailed by tokenizer and chat-template settings:

* **Chat template**  
  Default system prompt: *"Please reason step-by-step and put your final answer within \boxed{}."*  
  Without overrides, the model produces verbose "reasoning" outputs.  
  → Added CLI controls: `--system`, `--disable-chat-template`, `--final-only`.

* **Double BOS**  
  Both the tokenizer and the chat template inserted a BOS token.  
  → Fixed by tokenizing the rendered template with `add_special_tokens=False`.

* **Premature EOS**  
  Template special tokens (e.g. `<|eot_id|>`) were treated as stop tokens.  
  → Limited stopping criteria to true EOS token only.
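The two tokenizer fixes can be sketched in plain Python. `encode_chat` assumes a Hugging Face `transformers` tokenizer and uses its real `apply_chat_template(..., tokenize=False, add_generation_prompt=True)` and `tokenizer(text, add_special_tokens=False)` calls. `has_single_bos`, `generate`, and the `step(ids) -> next_id` decoder callback are hypothetical names for illustration, not the project's actual scripts.

```python
def encode_chat(tokenizer, messages):
    """Render the chat template, then tokenize without adding special tokens.

    `tokenizer` is assumed to be a Hugging Face transformers tokenizer.
    """
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already begins with BOS; a default tokenizer
    # call would prepend a second one.
    return tokenizer(text, add_special_tokens=False)["input_ids"]


def has_single_bos(ids, bos_id):
    """Sanity check: exactly one BOS, and it is the first token."""
    return bool(ids) and ids[0] == bos_id and ids.count(bos_id) == 1


def generate(step, prompt_ids, eos_id, max_new_tokens=256):
    """Decode loop that stops only on the true EOS token.

    `step(ids) -> next_id` is a hypothetical single-token decoder. Template
    markers such as <|eot_id|> are deliberately left out of the stop set.
    """
    ids = list(prompt_ids)
    new_tokens = []
    for _ in range(max_new_tokens):
        token = step(ids)
        if token == eos_id:
            break
        ids.append(token)
        new_tokens.append(token)
    return new_tokens
```

Running `has_single_bos` on the encoded prompt before generation is a cheap way to catch a template or tokenizer update that reintroduces the double BOS.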

---

## 3. Sampling Stability

Sampling issues stemmed from API mismatches rather than model problems:

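* **Top-p on probabilities**, then feeding those probabilities to `mx.random.categorical` (which expects logits, i.e. unnormalized log-probabilities), produced repetition loops.  
* **Solution:** apply penalties → divide logits by temperature → top-p mask (masked entries set to `float('-inf')`) → `categorical(logits)`.  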
* Added controls for **temperature, repetition penalty, frequency penalty**.
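A minimal NumPy sketch of that pipeline follows. It applies penalties, then temperature, then a nucleus mask on the logits, then a draw from the masked logits. NumPy is used only so it runs without Apple hardware. In MLX the final draw would be `mx.random.categorical(masked_logits)`; the Gumbel-max trick used here samples the same softmax distribution.

The following are assumptions, not necessarily what the project's scripts do:

- the repetition-penalty formula (divide positive logits, multiply negative ones, as in CTRL and Hugging Face `transformers`);
- the frequency-penalty form;
- the function and parameter names.

```python
import numpy as np


def sample_next_token(logits, history, *, temperature=0.7, top_p=0.9,
                      repetition_penalty=1.1, frequency_penalty=0.0, rng=None):
    """Draw one token id from a 1-D array of raw next-token logits."""
    rng = np.random.default_rng() if rng is None else rng
    logits = np.array(logits, dtype=np.float64)  # copy; never mutate the caller's array

    # 1) Penalties on tokens that were already generated.
    if len(history):
        ids, counts = np.unique(np.asarray(history), return_counts=True)
        seen = logits[ids]
        # Repetition penalty (CTRL / Hugging Face style): shrink positive
        # logits, push negative ones further down.
        logits[ids] = np.where(seen > 0, seen / repetition_penalty,
                               seen * repetition_penalty)
        # Frequency penalty: subtract in proportion to how often each token appeared.
        logits[ids] -= frequency_penalty * counts

    # 2) Temperature scaling; temperature <= 0 falls back to greedy decoding.
    if temperature <= 0:
        return int(np.argmax(logits))
    logits = logits / temperature

    # 3) Nucleus (top-p): rank by probability, but mask the *logits* with -inf.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order])
    # Smallest prefix whose mass reaches top_p; always keeps at least one token.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    masked = np.full_like(logits, -np.inf)
    keep = order[:cutoff]
    masked[keep] = logits[keep]

    # 4) Sample from logits: Gumbel-max draws from softmax(masked), the same
    #    distribution mx.random.categorical(masked) samples in MLX.
    u = rng.uniform(np.finfo(np.float64).tiny, 1.0, size=masked.shape)
    return int(np.argmax(masked - np.log(-np.log(u))))
```

The key design point is that masked tokens carry `-inf` in logit space, so they get exactly zero probability. Masking a probability vector with zeros and then treating it as logits would leave those tokens with a weight of `exp(0) = 1`.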

---

## 4. Quantization in mlx-lm: Why Custom Metadata Was Required

mlx-lm provides quantization hooks, but MobileLLM's architecture exposed several challenges:

1. **Frozen gradients during sensitivity analysis** → empty sensitivity lists.  
   → Avoid freezing weights during gradient computation.

2. **Re-quantizing already-quantized layers** → type errors on the second pass.  
   → Skip layers that are already `QuantizedLinear`.

3. **Embedding/norm dtype crashes**  
   Standard quantization re-quantized everything, but embeddings and norms must stay in floating point.  
   → Introduced a **metadata-driven approach**: config.json records *per-layer bit-widths*, and only the listed layers are instantiated as `QuantizedLinear`.

This metadata contract allows **4-bit mixed-precision MobileLLM** to be loaded cleanly by our **metadata-aware `custom_loader.py`**, making it compatible with the mlx-lm ecosystem.
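The per-layer decision behind that metadata contract can be sketched as one pure function shared by the quantize pass and the loader. The config keys (`"quantization"`, `"group_size"`, `"layers"`) and the name `quant_spec` are hypothetical. They are not the actual schema used by `custom_loader.py` or by mlx-lm.

In MLX, such a function could back the `class_predicate` callback of `mlx.nn.quantize`, which receives each module's path and the module itself; `type(module).__name__` supplies `module_kind`. A loader could then build an `nn.QuantizedLinear` with the returned `group_size` and `bits` for each listed layer before loading weights.

```python
import json

# Hypothetical config.json fragment; key names are illustrative only.
CONFIG = json.loads("""
{
  "quantization": {
    "group_size": 64,
    "layers": {
      "model.layers.0.self_attn.q_proj": 4,
      "model.layers.0.mlp.down_proj": 8
    }
  }
}
""")


def quant_spec(path, module_kind, config):
    """Decide whether the module at `path` should become a QuantizedLinear.

    Returns (bits, group_size), or None to keep the module in floating point.
    """
    q = config.get("quantization", {})
    if module_kind == "QuantizedLinear":
        return None  # already quantized: never quantize twice
    if module_kind != "Linear":
        return None  # embeddings, norms, etc. stay in floating point
    bits = q.get("layers", {}).get(path)
    if bits is None:
        return None  # not listed in the metadata: leave it as float
    return bits, q["group_size"]
```

Keeping this decision in one place means the quantizer and the loader cannot disagree about which layers hold packed weights. That kind of mismatch is what surfaces as shape or dtype errors at load time.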

---

## 5. End State

* **MLX path:**  
  Structural fixes (GQA, MLP detection), numerical fixes (QK-norm, RoPE, attn_scale), and prompt controls together yield fluent, stable inference.

* **mlx-lm path:**  
  Custom quantization pipeline produces FP16 and 4-bit models. These can be loaded with our **metadata-aware `custom_loader.py`** and used for inference with our provided scripts.  
  Performance: a measurable speedup and lower memory use on Apple Silicon (which has unified memory rather than dedicated VRAM), with minimal quality degradation.

---

### Takeaway

The MobileLLM-R1-950M port required systematically addressing architectural mismatches (MLP variant detection, GQA handling, QK-norm implementation, RoPE configuration) and developing a metadata-driven quantization approach. Once these were resolved, the model became fully functional in MLX with both float and quantized inference paths.