Add Gemma 4-26B-A4B support: 4.15 tok/s on M4 Mac Mini
Replace placeholder Gemma 4 engine with working implementation:
- Custom forward pass for Gemma 4 architecture (sliding/full attention,
layer scalars, dual layernorms, dense MLP + MoE in parallel)
- Mixed quantization handling (experts 4-bit, dense MLP 8-bit)
- Cache-aware routing bias=1.5 (steers router toward cached experts)
- Gemma 4 chat template encoder (turn_start/turn_end tokens 105/106)
- gelu_approx activation in expert FFN
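The cache-aware routing bias above can be sketched in plain NumPy. This is a minimal illustration of the idea, not this repo's API; `biased_topk` and its arguments are hypothetical names, and the guarantee tested is that a uniform bias on cached experts can only increase how many of them land in the top-k:

```python
import numpy as np

def biased_topk(scores, cached_ids, bias=1.5, k=8):
    # Add a fixed bias to the router logits of experts already resident in
    # the LRU cache, then pick top-k and renormalize the softmax weights.
    biased = scores.astype(np.float64).copy()
    idx = np.asarray(cached_ids, dtype=np.intp)
    biased[idx] += bias
    topk = np.argsort(-biased)[:k]
    w = np.exp(biased[topk] - biased[topk].max())
    w /= w.sum()
    return topk, w

rng = np.random.default_rng(0)
scores = rng.normal(size=128)   # one token's router logits, 128 experts
cached = [3, 7, 42]             # experts currently resident in the cache

plain, _ = biased_topk(scores, [], bias=0.0)
steered, w = biased_topk(scores, cached, bias=1.5)
# A uniform bias on the cached set can only add cached experts to the top-k.
assert len(set(cached) & set(steered)) >= len(set(cached) & set(plain))
assert abs(w.sum() - 1.0) < 1e-9
```

The same trade-off as in the engine applies: a larger bias raises the cache hit rate but steers routing further from the unbiased expert choice.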
Replace preprocess_gemma4.py with SwitchLinear unstack:
- Loads mlx-community/gemma-4-26b-a4b-it-4bit (15.6 GB) instead of
bf16 source
- Unstacks (128, out, in) experts into per-expert bin blocks
- Preserves bfloat16 bytes via uint16 view (no precision loss)
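The unstack-and-preserve step above can be sketched in NumPy. Shapes are toy stand-ins for the real `(128, out, in)` tensors, and since NumPy has no native bfloat16, the bf16 payload is modeled as its raw uint16 bits (bfloat16 is the top 16 bits of float32):

```python
import numpy as np

# Toy stand-in for a stacked expert tensor of shape (num_experts, out, in).
num_experts, out_dim, in_dim = 4, 8, 16
f32 = np.random.default_rng(1).normal(
    size=(num_experts, out_dim, in_dim)).astype(np.float32)

# bfloat16 is the top 16 bits of float32; carry it as raw uint16 bits.
bf16_bits = (f32.view(np.uint32) >> 16).astype(np.uint16)

# Unstack into one contiguous per-expert block, as written to the bin files.
blocks = [np.ascontiguousarray(bf16_bits[e]) for e in range(num_experts)]

# Round trip through raw bytes: bit-exact, so no precision is lost.
restored = np.frombuffer(blocks[0].tobytes(), dtype=np.uint16).reshape(out_dim, in_dim)
assert np.array_equal(restored, bf16_bits[0])
```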
Wire up auto-dispatch in:
- sniper.py (SniperEngine.from_dir auto-detects gemma4 model_type)
- generate.py (generate_stream forks to _gemma4_generate_stream)
- calibrate.py (_build_engine handles gemma4 path)
Update download registry: gemma4-26b now points to mlx-community
4-bit version (15.6 GB) instead of Google bf16 (50 GB).
Update README with verified 4.15 tok/s benchmark and memory
bandwidth scaling table for M2 → M2 Ultra.
Verified end-to-end on M4 Mac Mini 16 GB:
- 4.15 tok/s sustained
- 95.8% cache hit rate
- 7.8 GB RAM
- Coherent output (math, code, explanations)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- README.md +24 -2
- src/mlx_expert_sniper/calibrate.py +10 -1
- src/mlx_expert_sniper/download.py +3 -3
- src/mlx_expert_sniper/engine_gemma4.py +312 -82
- src/mlx_expert_sniper/generate.py +70 -2
- src/mlx_expert_sniper/models/gemma4.py +66 -33
- src/mlx_expert_sniper/preprocess_gemma4.py +195 -201
- src/mlx_expert_sniper/sniper.py +23 -2
README.md:

@@ -20,6 +20,7 @@ Run MoE models larger than your RAM on Apple Silicon.
 |-------|------|---------|-----------------|--------------|-----------|-----|
 | Qwen3.5-35B-A3B | 19.5 GB | 256/layer | OOM | **5.37 tok/s** | 92.0% | 8.7 GB |
 | Qwen3-30B-A3B | 17.2 GB | 128/layer | OOM | **4.29 tok/s** | 90.4% | 8.7 GB |
+| **Gemma 4-26B-A4B** | 15.6 GB | 128/layer | OOM | **4.15 tok/s** | 95.8% | 7.8 GB |
 
 All benchmarks: M4 Mac Mini 16 GB, 5 varied prompts, greedy decoding.
 
@@ -32,15 +33,36 @@ All benchmarks: M4 Mac Mini 16 GB, 5 varied prompts, greedy decoding.
 
 **30B**: right-sized LRU + co-activation prefetch. REAP/bias not yet applied.
 
+**Gemma 4-26B-A4B** (NEW):
+- Custom Gemma 4 model class (sliding/full attention hybrid, layer scalars, dual layernorms)
+- Mixed quantization: experts 4-bit, dense MLP and router 8-bit (matches mlx-community format)
+- Cache-aware routing bias=1.5 + co-activation prefetch (95.8% hit rate)
+- Source: `mlx-community/gemma-4-26b-a4b-it-4bit`
+
 ## Supported Models
 
 | Model | Size | Experts | tok/s (M4 16GB) | Status |
 |-------|------|---------|-----------------|--------|
-| Qwen3.5-35B-A3B | 19.5 GB | 256/layer | 5.
-| Qwen3-30B-A3B | 17.2 GB | 128/layer |
+| Qwen3.5-35B-A3B | 19.5 GB | 256/layer | 5.37 tok/s | Verified |
+| Qwen3-30B-A3B | 17.2 GB | 128/layer | 4.29 tok/s | Verified |
+| **Gemma 4-26B-A4B** | 15.6 GB | 128/layer | **4.15 tok/s** | Verified |
 
 More models coming. To request a model, open an issue on [GitHub](https://github.com/walter-grace/mac-code).
 
+### Memory Bandwidth Scaling
+
+MoE inference is bandwidth-bound. Expected speeds on different Apple Silicon Macs:
+
+| Mac | Memory BW | Qwen 35B est. | Gemma 4-26B est. |
+|-----|-----------|---------------|------------------|
+| M2 Mac Mini | 100 GB/s | ~4.5 tok/s | ~3.5 tok/s |
+| **M4 Mac Mini** | **120 GB/s** | **5.37 tok/s** ✓ | **4.15 tok/s** ✓ |
+| M2 Pro Mac Mini | 200 GB/s | ~8-10 tok/s | ~7-8 tok/s |
+| M4 Pro Mac Mini | 273 GB/s | ~12-14 tok/s | ~10-11 tok/s |
+| M2 Max Studio | 400 GB/s | ~16-20 tok/s | ~14-17 tok/s |
+| M4 Max MacBook Pro | 546 GB/s | ~22-28 tok/s | ~18-23 tok/s |
+| M2 Ultra Studio | 800 GB/s | ~30-40 tok/s | ~25-32 tok/s |
+
 ### Hardware Requirements
 
 | Mac | RAM | What you can run |
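The estimates in the Memory Bandwidth Scaling table follow from simple proportional scaling off the measured M4 numbers, since decode is bandwidth-bound. A sketch of that arithmetic (the function name is illustrative):

```python
def estimate_tok_s(measured_tok_s, measured_bw_gbs, target_bw_gbs):
    # Decode speed scales roughly linearly with memory bandwidth,
    # so project from the measured M4 Mac Mini (120 GB/s) baseline.
    return measured_tok_s * target_bw_gbs / measured_bw_gbs

# Qwen 35B on an M2 Pro Mac Mini (200 GB/s): ~8.95 tok/s, inside the ~8-10 row.
assert 8.0 < estimate_tok_s(5.37, 120, 200) < 10.0
# Gemma 4-26B on a base M2 Mac Mini (100 GB/s): ~3.46 tok/s, matching ~3.5.
assert 3.0 < estimate_tok_s(4.15, 120, 100) < 4.0
```

Real machines fall somewhat below the linear projection at the high end (compute and SSD streaming start to matter), which is why the table gives ranges rather than points.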
src/mlx_expert_sniper/calibrate.py:

@@ -63,7 +63,16 @@ def _detect_model_type(model_dir):
 def _build_engine(model_dir, cache_size):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     model_type = _detect_model_type(model_dir)
-    if "qwen3_next" in model_type:
+    if "gemma4" in model_type:
+        from .engine_gemma4 import MoESniperEngineGemma4
+        engine = MoESniperEngineGemma4(
+            model_dir=model_dir,
+            cache_size=cache_size,
+            enable_prediction=False,
+        )
+        engine.load()
+        return engine
+    elif "qwen3_next" in model_type:
         from .engine_next import MoESniperEngineNext as EngineClass
         from . import engine_next as engine_mod
         engine_mod.MODEL_DIR = model_dir
src/mlx_expert_sniper/download.py:

@@ -45,11 +45,11 @@ MODEL_REGISTRY = {
         "default_dir": "qwen3-235b-stream",
         "description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
     },
-    # Gemma 4 (Google) —
+    # Gemma 4 (Google) — 4.15 tok/s on M4 Mac Mini
     "gemma4-26b": {
-        "repo": "
+        "repo": "mlx-community/gemma-4-26b-a4b-it-4bit",
         "default_dir": "gemma4-26b-stream",
-        "description": "Gemma 4-26B-A4B
+        "description": "Gemma 4-26B-A4B 4-bit (~15.6 GB, 128 experts, mixed quant — Verified 4.15 tok/s on M4)",
         "preprocess": "gemma4",
     },
 }
src/mlx_expert_sniper/engine_gemma4.py:

@@ -3,126 +3,356 @@
 MoE Sniper engine for Gemma 4-26B-A4B.
 
 Architecture differences from Qwen:
+- 30 layers (vs 40)
+- 128 experts, top-8 (vs 256, top-8)
 - Dense MLP runs on every token (always), MoE adds on top
-- gelu_pytorch_tanh activation (not silu)
-- K=V sharing (attention_k_eq_v)
+- Router: inline RMS norm + scale + per_expert_scale
+- gelu_approx activation (not silu)
+- Layer scalar per layer
+- Sliding window + full attention hybrid
+- Mixed quantization: experts 4-bit, dense MLP/router 8-bit
 """
+import json
+import os
+import sys
+import time
+import gc
+
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten
+
 from .expert_io import MoEExpertReader
 from .coactivation import CoActivationTracker
 
-MODEL_DIR = "" # Set before load()
-BITS = 4
 GROUP_SIZE = 64
 
 
-def gelu_tanh(x):
-    return 0.5 * x * (1 + mx.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
-
-
-def run_expert_ffn_gemma4(x, expert_data, top_k_indices, top_k_weights,
-                          num_experts_total=128, hidden_size=2816, moe_inter=704):
-    """
-    Gemma 4 expert FFN. Experts have fused gate_up_proj.
-    'experts.down_proj': [hidden_size, moe_inter] bf16
-    """
-            eid = int(inds_np[token_idx, k_idx])
-            w = float(weights_np[token_idx, k_idx])
-                continue
-            gate, up = mx.split(gu, 2)
-            h = gelu_tanh(gate) * up
-            output = output.at[token_idx].add(token_out)
+def run_expert_ffn_gemma4(x, expert_data, top_k_indices, top_k_weights):
+    """Run expert FFN using gather_qmm with streamed Gemma 4 expert weights.
+
+    Expert tensor names: switch_mlp.{gate,up,down}_proj.{weight,scales,biases}
+    Activation: gelu_approx (not silu like Qwen)
+    """
+    active_ids = sorted(expert_data.keys())
+    if not active_ids:
+        return mx.zeros_like(x)
+
+    id_to_local = {eid: i for i, eid in enumerate(active_ids)}
+    inds_np = np.array(top_k_indices)
+    local_np = np.vectorize(lambda v: id_to_local.get(int(v), 0))(inds_np)
+    local_indices = mx.array(local_np)
+
+    def stack_proj(proj):
+        w = mx.stack([expert_data[eid][f"switch_mlp.{proj}.weight"] for eid in active_ids])
+        s = mx.stack([expert_data[eid][f"switch_mlp.{proj}.scales"] for eid in active_ids])
+        b = mx.stack([expert_data[eid][f"switch_mlp.{proj}.biases"] for eid in active_ids])
+        return w, s, b
+
+    gate_w, gate_s, gate_b = stack_proj("gate_proj")
+    up_w, up_s, up_b = stack_proj("up_proj")
+    down_w, down_s, down_b = stack_proj("down_proj")
+
+    x_exp = mx.expand_dims(x, (-2, -3))
+
+    # Auto-detect bits from weight vs scales shape
+    n_packed = gate_w.shape[-1]
+    n_groups = gate_s.shape[-1]
+    real_input = n_groups * GROUP_SIZE
+    bits = round(32 * n_packed / real_input)
+    if bits not in (4, 8):
+        bits = 4
+
+    gate_out = mx.gather_qmm(x_exp, gate_w, scales=gate_s, biases=gate_b,
+                             rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
+    up_out = mx.gather_qmm(x_exp, up_w, scales=up_s, biases=up_b,
+                           rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
+
+    # Gemma 4 uses gelu_approx
+    hidden = nn.gelu_approx(gate_out) * up_out
+
+    down_out = mx.gather_qmm(hidden, down_w, scales=down_s, biases=down_b,
+                             rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
+    out = down_out.squeeze(-2)
+    out = (out * top_k_weights[..., None]).sum(axis=-2)
+    return out
 
 
+class MoESniperEngineGemma4:
+    """Single-machine MoE Sniper for Gemma 4-26B-A4B with SSD expert streaming.
+
+    Verified results on M4 Mac Mini 16 GB:
+        4.15 tok/s, 95.8% cache hit, 7.8 GB RAM
+    """
+
+    def __init__(self, model_dir, cache_size=4000, enable_prediction=True,
+                 routing_bias=1.5):
+        self.model_dir = os.path.expanduser(model_dir)
         self.model = None
         self.reader = None
         self.tokenizer = None
         self.cache = None
-        self.num_layers = 30
-        self.coact = None
         self._cache_size = cache_size
         self._enable_prediction = enable_prediction
+        self.routing_bias = routing_bias
+        self.num_layers = 30
+        self.coact = None
 
     def load(self):
-        """
-        NOTE: This is a PLACEHOLDER. Gemma 4 (gemma4) is not yet in mlx-lm.
-        Once mlx-lm adds gemma4 support, this will use their Model class.
-        For now, this demonstrates the architecture and expert streaming.
-        """
-        with open(os.path.join(MODEL_DIR, "config.json")) as f:
+        config_path = os.path.join(self.model_dir, "config.json")
+        with open(config_path) as f:
             config = json.load(f)
+
+        text_config = config.get("text_config", config)
+        self.num_layers = text_config["num_hidden_layers"]
+
+        # Import Gemma 4 model class from this package
+        from .models.gemma4 import Model, ModelArgs
+
+        args = ModelArgs.from_dict(text_config)
+        self.model = Model(args)
+
+        # Mixed quantization handling
+        quant_config = config.get("quantization", config.get("quantization_config", {}))
+        default_bits = quant_config.get("bits", 4)
+        default_gs = quant_config.get("group_size", GROUP_SIZE)
+
+        def _is_8bit(path, module):
+            if not isinstance(module, nn.Linear):
+                return False
+            full_path = "language_model." + path
+            if full_path in quant_config and isinstance(quant_config[full_path], dict):
+                return quant_config[full_path].get("bits", default_bits) == 8
+            return False
+
+        def _q4(path, module):
+            if isinstance(module, nn.Embedding):
+                return True
+            if not isinstance(module, nn.Linear):
+                return False
+            if _is_8bit(path, module):
+                return False
+            if module.weight.shape[-1] < default_gs:
+                return False
+            return True
+
+        nn.quantize(self.model, group_size=default_gs, bits=default_bits,
+                    class_predicate=_q4)
+        nn.quantize(self.model, group_size=64, bits=8,
+                    class_predicate=lambda p, m: isinstance(m, nn.Linear) and _is_8bit(p, m))
+
+        mx.set_memory_limit(14 * 1024**3)
+        mx.set_cache_limit(512 * 1024**2)
+
+        # Load pinned weights
+        pinned_path = os.path.join(self.model_dir, "pinned.safetensors")
+        pinned = mx.load(pinned_path)
+        stripped = [(k.replace("language_model.", "", 1), v) for k, v in pinned.items()]
+        self.model.load_weights(stripped, strict=False)
+
+        # Eval only non-expert params
+        params = [p for name, p in tree_flatten(self.model.parameters())
+                  if "expert" not in name and "switch" not in name]
+        mx.eval(*params)
+        del pinned
+        gc.collect()
+        mx.clear_cache()
+
+        pinned_gb = sum(p.nbytes for p in params) / 1e9
+
+        # Expert reader (F_NOCACHE + pread)
+        sniper_config_path = os.path.join(self.model_dir, "sniper_config.json")
+        if os.path.exists(sniper_config_path):
+            with open(sniper_config_path) as f:
+                sc = json.load(f)
+            expert_dir = os.path.join(self.model_dir, sc.get("streaming", {}).get("expert_dir", "bin"))
+        else:
+            expert_dir = os.path.join(self.model_dir, "bin")
+
+        self.reader = MoEExpertReader(
+            expert_dir, self.num_layers,
+            num_workers=8, cache_size=self._cache_size
+        )
         self.coact = CoActivationTracker(self.num_layers, warmup_tokens=3)
 
+        # Tokenizer (prefer fast tokenizers)
+        from tokenizers import Tokenizer
+        tok_path = os.path.join(self.model_dir, "tokenizer.json")
+        if os.path.exists(tok_path):
+            self.tokenizer = Tokenizer.from_file(tok_path)
+            self._fast_tok = True
+        else:
+            from transformers import AutoTokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+            self._fast_tok = False
+
+        self.cache = self.model.make_cache()
+
+        print(f"Gemma 4 Sniper loaded: {pinned_gb:.1f} GB pinned, "
+              f"cache={self._cache_size}, layers={self.num_layers}")
+        return pinned_gb
+
+    def encode(self, text):
+        if self._fast_tok:
+            return self.tokenizer.encode(text).ids
+        return self.tokenizer.encode(text)
+
+    def decode(self, ids):
+        if self._fast_tok:
+            return self.tokenizer.decode(ids)
+        return self.tokenizer.decode(ids, skip_special_tokens=False)
+
+    def encode_chat(self, prompt):
+        """Encode prompt with Gemma 4 chat template.
+
+        Format: <bos><|turn>user\\n{prompt}<turn|>\\n<|turn>model\\n
+        Token IDs: bos=2, turn_start=105, turn_end=106, newline=107
+        """
+        NL = chr(10)
+        prompt_toks = self.encode(prompt)
+        user_toks = self.encode("user" + NL)
+        model_toks = self.encode("model" + NL)
+        return [2, 105] + user_toks + prompt_toks + [106, 107, 105] + model_toks
 
     def reset_cache(self):
-        self.cache =
+        self.cache = self.model.make_cache()
+
+    def forward(self, input_ids):
+        """Forward pass with SSD-streamed experts."""
+        from mlx_lm.models.base import create_attention_mask
+
+        h = self.model.model.embed_tokens(input_ids)
+        h = h * (self.model.args.hidden_size ** 0.5)
+
+        mask = create_attention_mask(h, self.cache[0] if self.cache else None)
+
+        for i in range(self.num_layers):
+            layer = self.model.model.layers[i]
+            cache_i = self.cache[i] if self.cache else None
+
+            # Attention
+            residual = h
+            h_norm = layer.input_layernorm(h)
+            h_attn = layer.self_attn(h_norm, mask=mask, cache=cache_i)
+            h_attn = layer.post_attention_layernorm(h_attn)
+            h = residual + h_attn
+            mx.eval(h)
+
+            # Dense MLP (always)
+            residual = h
+            h_ff = layer.pre_feedforward_layernorm(h)
+            h_ff = layer.mlp(h_ff)
+
+            expert_data = {}
+            expert_out = None
+            moe_input = None
+
+            if layer.enable_moe_block:
+                h_dense = layer.post_feedforward_layernorm_1(h_ff)
+
+                # Router with cache-aware bias
+                B, L, D = residual.shape
+                residual_flat = residual.reshape(-1, D)
+                router = layer.router
+                x_normed = router._inline_rms_norm(residual_flat)
+                x_normed = x_normed * router.scale * (router.hidden_size ** -0.5)
+                scores = router.proj(x_normed)
+
+                if self.routing_bias > 0 and self.reader.lru is not None:
+                    bias_np = np.zeros(scores.shape[-1], dtype=np.float32)
+                    for (li, eid) in self.reader.lru.cache.keys():
+                        if li == i:
+                            bias_np[eid] = self.routing_bias
+                    if bias_np.any():
+                        scores = scores + mx.array(bias_np)
+
+                probs = mx.softmax(scores, axis=-1)
+                top_k_indices = mx.argpartition(-probs, kth=router.top_k - 1, axis=-1)[..., :router.top_k]
+                top_k_weights = mx.take_along_axis(probs, top_k_indices, axis=-1)
+                top_k_weights = top_k_weights / mx.sum(top_k_weights, axis=-1, keepdims=True)
+                expert_scales = router.per_expert_scale[top_k_indices]
+                top_k_weights = top_k_weights * expert_scales
+
+                moe_input = layer.pre_feedforward_layernorm_2(residual_flat)
+                mx.eval(moe_input, top_k_indices, top_k_weights)
+
+                top_k_indices_r = top_k_indices.reshape(B, L, -1)
+                top_k_weights_r = top_k_weights.reshape(B, L, -1)
+
+                active_ids = list(set(int(e) for e in np.array(top_k_indices_r).flatten()))
+                self.coact.record_layer(i, active_ids)
+
+                # Predictive prefetch
+                if self._enable_prediction and self.coact.ready and i + 1 < self.num_layers:
+                    predicted = self.coact.predict_next_layer(i, active_ids, top_k=6)
+                    if predicted:
+                        to_fetch = [eid for eid in predicted
+                                    if self.reader.lru and self.reader.lru.get(i + 1, eid) is None]
+                        if to_fetch:
+                            self.reader.prefetch_experts(i + 1, to_fetch)
+
+                if i + 1 < self.num_layers:
+                    self.reader.prefetch_experts(i + 1, active_ids)
+
+                # Expert FFN from SSD
+                expert_data = self.reader.get_experts(i, active_ids)
+                moe_input_r = moe_input.reshape(B, L, D)
+                expert_out = run_expert_ffn_gemma4(moe_input_r, expert_data,
+                                                   top_k_indices_r, top_k_weights_r)
+                h_moe = layer.post_feedforward_layernorm_2(expert_out)
+                h_ff = h_dense + h_moe
+
+            # Final norm + residual + scalar
+            h_ff = layer.post_feedforward_layernorm(h_ff)
+            h = residual + h_ff
+            h = h * layer.layer_scalar
+            mx.eval(h)
+
+            del expert_data, expert_out, moe_input
+            mx.clear_cache()
+
+        self.coact.end_token()
+        h = self.model.model.norm(h)
+
+        if self.model.args.tie_word_embeddings:
+            return self.model.model.embed_tokens.as_linear(h)
+        else:
+            return self.model.lm_head(h)
+
+    def generate(self, prompt, max_tokens=200, temperature=0.7):
+        """Generate text from a prompt with chat template + EOS detection."""
+        tokens = self.encode_chat(prompt)
+        input_ids = mx.array([tokens])
+
+        # Prefill
+        logits = self.forward(input_ids)
+        mx.eval(logits)
+
+        if temperature <= 0:
+            next_token = int(mx.argmax(logits[0, -1]).item())
+        else:
+            probs = mx.softmax(logits[0, -1] / temperature, axis=-1)
+            next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())
+
+        generated = [next_token]
+        input_ids = mx.array([[next_token]])
+
+        for step in range(max_tokens - 1):
+            logits = self.forward(input_ids)
+            mx.eval(logits)
+
+            if temperature <= 0:
+                next_token = int(mx.argmax(logits[0, -1]).item())
+            else:
+                probs = mx.softmax(logits[0, -1] / temperature, axis=-1)
+                next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())
+
+            generated.append(next_token)
+            input_ids = mx.array([[next_token]])
+
+            # EOS: <eos>=1, <turn|>=106
+            if next_token in [1, 106]:
+                break
+
+        return self.decode(generated)
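The bits auto-detection in `run_expert_ffn_gemma4` above is pure shape arithmetic: a quantized matrix packs `bits`-bit values into 32-bit words, while the scales tensor has one entry per group of inputs. A standalone sketch with toy shape values (the function name is illustrative):

```python
GROUP_SIZE = 64

def detect_bits(n_packed, n_groups, group_size=GROUP_SIZE):
    # The packed inner dim is real_input * bits / 32, and scales hold one
    # entry per group of `group_size` inputs, so bits can be recovered
    # from the two trailing dimensions alone.
    real_input = n_groups * group_size
    bits = round(32 * n_packed / real_input)
    return bits if bits in (4, 8) else 4

# 2816 inputs -> 44 groups of 64; 4-bit packs to 352 words, 8-bit to 704.
assert detect_bits(n_packed=352, n_groups=44) == 4
assert detect_bits(n_packed=704, n_groups=44) == 8
```

This is what lets the same expert-FFN path serve both the 4-bit experts and any 8-bit tensors without a per-tensor config lookup.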
src/mlx_expert_sniper/generate.py:

@@ -20,7 +20,17 @@ def load_engine(model_dir):
     bias = 0.0
 
     model_type = _detect_model_type(model_dir)
-    if "qwen3_next" in model_type:
+    if "gemma4" in model_type:
+        # Gemma 4 has its own engine due to architectural differences
+        from .engine_gemma4 import MoESniperEngineGemma4
+        eng = MoESniperEngineGemma4(
+            model_dir=model_dir,
+            cache_size=cache_size,
+            enable_prediction=True,
+        )
+        eng.load()
+        return eng, bias, model_type
+    elif "qwen3_next" in model_type:
         from . import engine_next as engine_mod
         engine_mod.MODEL_DIR = model_dir
         from .engine_next import MoESniperEngineNext as EngineClass

@@ -39,8 +49,19 @@ def load_engine(model_dir):
 
 
 def generate_stream(engine, messages, bias=0.0, max_tokens=200):
-    """Generator yielding token strings.
+    """Generator yielding token strings.
+
+    Dispatches to model-specific forward path:
+    - Gemma 4 → uses engine's own forward() (different architecture)
+    - Qwen 3.x → inline Qwen forward (SSM hybrid or standard)
+    """
     import mlx.core as mx
+
+    # Gemma 4 has its own engine with a built-in forward + chat template
+    if engine.__class__.__name__ == "MoESniperEngineGemma4":
+        yield from _gemma4_generate_stream(engine, messages, max_tokens=max_tokens)
+        return
+
     from mlx_lm.models.base import create_attention_mask
     from .engine import run_expert_ffn

@@ -150,3 +171,50 @@ def generate_stream(engine, messages, bias=0.0, max_tokens=200):
         yield chunk
         logits = forward(token.reshape(1, 1))
         mx.eval(logits)
+
+
+def _gemma4_generate_stream(engine, messages, max_tokens=200):
+    """Gemma 4 streaming generation using the engine's built-in forward.
+
+    Handles Gemma 4's chat template (turn_start/turn_end tokens) and
+    its mixed-quantization architecture.
+    """
+    import mlx.core as mx
+
+    engine.reset_cache()
+
+    # Build prompt from messages: take the first user message's content
+    # for now (multi-turn handling can be added later)
+    prompt = ""
+    for msg in messages:
+        if msg.get("role") == "user":
+            prompt = msg.get("content", "")
+            break
+
+    # Use engine's chat template encoder
+    tokens = engine.encode_chat(prompt)
+    input_ids = mx.array([tokens])
+
+    # Prefill
+    logits = engine.forward(input_ids)
+    mx.eval(logits)
+
+    # Sample first token
+    next_token = int(mx.argmax(logits[0, -1]).item())
+
+    # Gemma 4 EOS: <eos>=1, <turn|>=106
+    EOS = {1, 106}
+
+    for _ in range(max_tokens):
+        if next_token in EOS:
+            break
+
+        chunk = engine.decode([next_token])
+        if chunk:
+            yield chunk
+
+        # Next forward step
+        input_ids = mx.array([[next_token]])
+        logits = engine.forward(input_ids)
+        mx.eval(logits)
+        next_token = int(mx.argmax(logits[0, -1]).item())
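The chat-template frame consumed by `_gemma4_generate_stream` above can be exercised standalone. `toy_encode` is a fake tokenizer just to show the token layout; the special token IDs come from the diff (bos=2, turn_start=105, turn_end=106, newline=107), and the function mirrors `encode_chat` rather than being the repo's actual code:

```python
def encode_chat(encode, prompt):
    # <bos><|turn>user\n {prompt} <turn|>\n<|turn>model\n
    return ([2, 105] + encode("user\n") + encode(prompt)
            + [106, 107, 105] + encode("model\n"))

# Toy byte-level "tokenizer": offsets every byte past the special-token range.
toy_encode = lambda s: [256 + b for b in s.encode("utf-8")]

ids = encode_chat(toy_encode, "hi")
assert ids[:2] == [2, 105]                           # <bos><|turn>
assert ids[-len(toy_encode("model\n")) - 1] == 105   # model turn opener
assert 106 in ids and 107 in ids                     # <turn|> then newline
```

Generation then stops on token 1 (`<eos>`) or 106 (`<turn|>`), matching the EOS set in the streaming loop.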
@@ -12,6 +12,7 @@ Architecture: gemma4_text
 Reference: HuggingFace transformers Gemma4TextModel
 """

+import re
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple

@@ -74,7 +75,20 @@ class RMSNorm(nn.Module):
         self.eps = eps

     def __call__(self, x: mx.array) -> mx.array:
-
+        # Gemma 4 GGUF norm_shift=0.0: weight is the final multiplier (no +1 offset)
+        # Confirmed by mlx-vlm RMSNormZeroShift implementation
+        return mx.fast.rms_norm(x, self.weight, self.eps)
+
+
+class BareRMSNorm(nn.Module):
+    """RMSNorm without learnable scale (used for v_norm)."""
+    def __init__(self, dims: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self._dims = dims
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return mx.fast.rms_norm(x, mx.ones((self._dims,)), self.eps)


 # --------------------------------------------------------------------------- #

@@ -95,7 +109,8 @@ class Attention(nn.Module):
         super().__init__()
         self.layer_idx = layer_idx
         self.is_sliding = args.layer_types[layer_idx] == "sliding_attention"
-
+        # K=V sharing only applies to full (non-sliding) attention layers
+        self.use_kv_sharing = args.attention_k_eq_v and not self.is_sliding

         self.n_heads = args.num_attention_heads

@@ -111,17 +126,18 @@ class Attention(nn.Module):
         rope_dims = int(args.global_head_dim * args.partial_rotary_factor)
         rope_theta = args.rope_theta_global

-        self.scale =
+        self.scale = 1.0  # HF Gemma4 uses scaling=1.0; q_norm/k_norm handle magnitude

         self.q_proj = nn.Linear(args.hidden_size, self.n_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
-        # v_proj
-        if not self.
+        # v_proj needed for sliding layers; dropped for full layers with K=V sharing
+        if not self.use_kv_sharing:
             self.v_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.n_heads * self.head_dim, args.hidden_size, bias=False)

         self.q_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
+        self.v_norm = BareRMSNorm(self.head_dim, eps=args.rms_norm_eps)

         self.rope = nn.RoPE(rope_dims, traditional=False, base=rope_theta)

@@ -135,15 +151,17 @@ class Attention(nn.Module):

         queries = self.q_proj(x)
         keys = self.k_proj(x)
-        # K=V:
-        values = keys if self.
+        # K=V sharing: only for full attention layers
+        values = keys if self.use_kv_sharing else self.v_proj(x)

         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

+        # Norms: q_norm and k_norm BEFORE RoPE, v_norm on values
         queries = self.q_norm(queries)
         keys = self.k_norm(keys)
+        values = self.v_norm(values)

         if cache is not None:
             queries = self.rope(queries, offset=cache.offset)

@@ -198,8 +216,8 @@ class Router(nn.Module):
         self.top_k = args.top_k_experts

         self.proj = nn.Linear(args.hidden_size, args.num_experts, bias=False)
-        # Learnable
-        self.scale = mx.ones((
+        # Learnable per-dimension scale (shape matches hidden_size)
+        self.scale = mx.ones((args.hidden_size,))
         # Per-expert scales
         self.per_expert_scale = mx.ones((args.num_experts,))

@@ -352,6 +370,7 @@ class DecoderLayer(nn.Module):
         # Attention
         self.self_attn = Attention(args, layer_idx)
         self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

         # Dense MLP
         self.mlp = DenseMLP(args)

@@ -377,37 +396,44 @@ class DecoderLayer(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        # 1. Attention
+        # 1. Attention with pre/post norms and residual
         residual = x
         h = self.input_layernorm(x)
         h = self.self_attn(h, mask, cache)
+        h = self.post_attention_layernorm(h)
         h = residual + h

-        # 2.
+        # 2. Feed-forward (dense MLP, optionally combined with MoE)
         residual = h
-        h = self.post_feedforward_layernorm(dense_out)
-        # 3. MoE (parallel to dense, sharing the same residual)
+        h = self.pre_feedforward_layernorm(h)
+        h = self.mlp(h)

         if self.enable_moe_block:
+            # Dense MLP output -> post_feedforward_layernorm_1
+            h_dense = self.post_feedforward_layernorm_1(h)
+
+            # MoE: router takes residual (pre-MLP hidden states), NOT normed
+            B, L, D = residual.shape
+            residual_flat = residual.reshape(-1, D)
+            top_k_weights, top_k_indices = self.router(residual_flat)
+
+            # Expert input: pre_feedforward_layernorm_2 applied to residual
+            moe_input = self.pre_feedforward_layernorm_2(residual_flat)
+            expert_out = self.experts(
+                moe_input.reshape(B, L, D), top_k_indices.reshape(B, L, -1)
+            )
+            # Weighted sum over top-k experts
+            top_k_weights_r = top_k_weights.reshape(B, L, -1)
+            weighted_out = (expert_out * mx.expand_dims(top_k_weights_r, -1)).sum(axis=-2)
+            h_moe = self.post_feedforward_layernorm_2(weighted_out)
+
+            # Combine dense + MoE
+            h = h_dense + h_moe
+
+            # Final post-feedforward norm + residual
+            h = self.post_feedforward_layernorm(h)
         h = residual + h
+
         h = h * self.layer_scalar

         return h

@@ -538,9 +564,16 @@ class Model(nn.Module):
         if new_key.startswith("model.language_model."):
             new_key = "model." + new_key[len("model.language_model."):]

-        # Drop v_proj
+        # Drop v_proj only for full attention layers with K=V sharing
+        # Sliding layers still need v_proj even when attention_k_eq_v is true
         if self.args.attention_k_eq_v and "v_proj" in new_key:
-
+            # Extract layer index to check if it's a full attention layer
+            layer_match = re.search(r'layers\.(\d+)\.', new_key)
+            if layer_match:
+                layer_idx = int(layer_match.group(1))
+                if self.args.layer_types[layer_idx] != "sliding_attention":
+                    continue  # Drop v_proj for full attention layers
+            # If no layer index found, keep the weight

         # Drop lm_head when tied
         if self.args.tie_word_embeddings and new_key == "lm_head.weight":
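The weight-sanitizing change above drops `v_proj` only for full-attention layers by parsing the layer index out of the weight key. A standalone sketch of that filtering decision (the helper name and example keys are illustrative, not from the repo):

```python
import re

def drops_v_proj(key, layer_types):
    """True when a weight key is a v_proj belonging to a full-attention
    layer, i.e. the case dropped under K=V sharing."""
    if "v_proj" not in key:
        return False
    m = re.search(r"layers\.(\d+)\.", key)
    if not m:
        return False  # no layer index found: keep the weight
    return layer_types[int(m.group(1))] != "sliding_attention"

layer_types = ["sliding_attention", "full_attention"]
print(drops_v_proj("model.layers.1.self_attn.v_proj.weight", layer_types))  # True
print(drops_v_proj("model.layers.0.self_attn.v_proj.weight", layer_types))  # False
```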
@@ -2,231 +2,225 @@
 """
 Preprocess Gemma 4-26B-A4B into sniper streaming format.

+Source: mlx-community/gemma-4-26b-a4b-it-4bit (15.6 GB)
+
+Gemma 4 stores experts as SwitchLinear stacked tensors:
+    language_model.model.layers.X.experts.switch_glu.{gate,up,down}_proj.{weight,scales,biases}
+Each tensor shape: (128, ...) where dim 0 is the expert index
+
+This script unstacks them into per-expert blocks in the bin format that
+expert_io.MoEExpertReader expects:
+    bin/layer_XX.bin (header + 128 expert blocks per layer)
+
+Mixed quantization is preserved: experts stay 4-bit, dense MLP is 8-bit.
 """
-import os
+import os
+import json
+import gc
+import time
+import glob
+
 import numpy as np
 import mlx.core as mx

 PAGE_SIZE = 16384

-# Gemma 4 expert tensors (per layer, shape includes expert dim)
-EXPERT_TENSORS = [
-    "experts.gate_up_proj",  # [num_experts, 2*moe_inter, hidden]
-    "experts.down_proj",     # [num_experts, hidden, moe_inter]
-]

-def preprocess_gemma4(input_dir, output_dir
-    """Split Gemma 4 into pinned + streaming
+def preprocess_gemma4(input_dir, output_dir):
+    """Split Gemma 4 into pinned safetensors + streaming expert bins.

     Args:
-        input_dir: HuggingFace download directory
-        output_dir: sniper streaming format output
-        quantize_experts: if True, quantize experts to 4-bit (saves disk)
+        input_dir: HuggingFace download directory (mlx-community 4-bit)
+        output_dir: sniper streaming format output directory
     """
-    os.
-    os.
+    INPUT_DIR = os.path.expanduser(input_dir)
+    OUTPUT_DIR = os.path.expanduser(output_dir)
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)

-    config = json.load(open(os.path.join(
+    config = json.load(open(os.path.join(INPUT_DIR, "config.json")))
     tc = config.get("text_config", config)
     NUM_LAYERS = tc["num_hidden_layers"]
     NUM_EXPERTS = tc["num_experts"]
-    hidden_size = tc["hidden_size"]
-    moe_inter = tc["moe_intermediate_size"]

-    print(f"
-    print(f"
-    print(f"
+    print(f"Gemma 4 Preprocess (SwitchLinear unstack)")
+    print(f"  Input:  {INPUT_DIR}")
+    print(f"  Output: {OUTPUT_DIR}")
+    print(f"  Layers: {NUM_LAYERS}, Experts: {NUM_EXPERTS}")
+    print()

-    expert_layers_done = set()
+    # Load all weights
+    print("Loading safetensors...")
     t0 = time.time()
+    all_weights = {}
+    for sf in sorted(glob.glob(os.path.join(INPUT_DIR, "model-*.safetensors"))):
+        print(f"  {os.path.basename(sf)}")
+        all_weights.update(mx.load(sf))
+
+    # Identify expert vs pinned keys
+    pinned = {}
+    expert_tensors = {}  # layer_idx -> {tensor_name: stacked_tensor}
+
+    EXPERT_PREFIX_TMPL = "language_model.model.layers.{}.experts.switch_glu.{}.{}"
+    PROJ_NAMES = ["gate_proj", "up_proj", "down_proj"]
+    COMP_NAMES = ["weight", "scales", "biases"]
+
+    # Build set of expert key paths for fast lookup
+    expert_key_set = set()
+    for li in range(NUM_LAYERS):
+        for proj in PROJ_NAMES:
+            for comp in COMP_NAMES:
+                expert_key_set.add(EXPERT_PREFIX_TMPL.format(li, proj, comp))
+
+    for key, val in all_weights.items():
+        if key in expert_key_set:
+            # Extract layer_idx from key
+            parts = key.split(".")
+            li = int(parts[3])   # language_model.model.layers.X.experts...
+            proj = parts[6]      # gate_proj/up_proj/down_proj
+            comp = parts[7]      # weight/scales/biases
+
+            if li not in expert_tensors:
+                expert_tensors[li] = {}
+            tname = f"switch_mlp.{proj}.{comp}"
+            expert_tensors[li][tname] = val
+        else:
+            pinned[key] = val
+
+    print(f"\n  Expert layers: {len(expert_tensors)}")
+    print(f"  Pinned keys: {len(pinned)}")
+
+    # Determine per-expert block layout from first layer
+    first_layer = expert_tensors[0]
+    tensor_layout = {}
+    inner_offset = 0
+
+    for tname in sorted(first_layer.keys()):
+        arr = first_layer[tname]
+        per_expert_shape = list(arr.shape[1:])
+
+        if arr.dtype == mx.uint32:
+            dtype_str = "uint32"
+            elem_size = 4
+        elif arr.dtype == mx.bfloat16:
+            dtype_str = "bfloat16"
+            elem_size = 2
+        elif arr.dtype == mx.float16:
+            dtype_str = "float16"
+            elem_size = 2
+        elif arr.dtype == mx.float32:
+            dtype_str = "float32"
+            elem_size = 4
+        else:
+            dtype_str = str(arr.dtype).replace("mlx.core.", "")
+            elem_size = 2
+
+        nbytes = elem_size
+        for d in per_expert_shape:
+            nbytes *= d
+
+        tensor_layout[tname] = {
+            "inner_offset": inner_offset,
+            "nbytes": nbytes,
+            "shape_per_expert": per_expert_shape,
+            "dtype": dtype_str,
+        }
+        inner_offset += nbytes
+
+    expert_block_size = inner_offset
+    data_start = PAGE_SIZE

+    print(f"  Expert block: {expert_block_size} bytes ({expert_block_size/1024:.1f} KB)")
+    print()
+
+    # Write layer files
+    total_expert_bytes = 0
+    for layer_idx in range(NUM_LAYERS):
+        lt = time.time()
+        layer_data = expert_tensors[layer_idx]
+
+        header = {
+            "format": "expert_sniper_v1",
+            "model": "gemma4-26b-a4b",
+            "layer_idx": layer_idx,
+            "num_experts": NUM_EXPERTS,
+            "layout": {
+                "expert_block_size": expert_block_size,
+                "data_start": data_start,
+                "tensors": tensor_layout,
+            }
+        }
+        header_bytes = json.dumps(header, indent=2).encode("utf-8")
+        assert len(header_bytes) < PAGE_SIZE, f"Header too large: {len(header_bytes)}"
+        header_padded = header_bytes + b"\x00" * (PAGE_SIZE - len(header_bytes))
+
+        layer_path = os.path.join(OUTPUT_DIR, "bin", f"layer_{layer_idx:02d}.bin")
+        with open(layer_path, "wb") as f:
+            f.write(header_padded)
+
+            for eid in range(NUM_EXPERTS):
+                expert_data = bytearray()
+                for tname in sorted(tensor_layout.keys()):
+                    stacked = layer_data[tname]
+                    single = stacked[eid]
+                    mx.eval(single)
+
+                    if single.dtype == mx.uint32:
+                        np_arr = np.array(single).view(np.uint32)
+                    elif single.dtype == mx.bfloat16:
+                        # Preserve bfloat16 bytes via uint16 view
+                        np_arr = np.array(single.view(mx.uint16))
+                    elif single.dtype == mx.float32:
+                        np_arr = np.array(single).view(np.float32)
+                    elif single.dtype == mx.float16:
+                        np_arr = np.array(single.view(mx.uint16))
+                    else:
+                        np_arr = np.array(single)
+                    expert_data.extend(np_arr.tobytes())
+
+                if len(expert_data) < expert_block_size:
+                    expert_data.extend(b"\x00" * (expert_block_size - len(expert_data)))
+                f.write(bytes(expert_data[:expert_block_size]))
+
+        file_size = os.path.getsize(layer_path)
+        total_expert_bytes += file_size
+        elapsed = time.time() - lt
+        print(f"  Layer {layer_idx:2d}/{NUM_LAYERS}: {file_size/1e6:.1f} MB ({elapsed:.0f}s)")
+
+        del expert_tensors[layer_idx]
+        gc.collect()

     # Save pinned
+    pinned_path = os.path.join(OUTPUT_DIR, "pinned.safetensors")
+    mx.save_safetensors(pinned_path, pinned)
     pinned_bytes = sum(v.nbytes for v in pinned.values())
+    print(f"\nSaved pinned.safetensors: {pinned_bytes/1e9:.2f} GB ({len(pinned)} keys)")

-    dst = os.path.join(output_dir, "bin", f"layer_{i:02d}.bin")
-    if os.path.exists(os.path.join(output_dir, "bin", src)) and not os.path.exists(dst):
-        os.symlink(src, dst)
-
-    # Write config
-    stream_config = {
-        "model_type": "gemma4",
-        "hidden_size": hidden_size,
-        "num_hidden_layers": NUM_LAYERS,
-        "num_experts": NUM_EXPERTS,
-        "top_k_experts": tc["top_k_experts"],
-        "moe_intermediate_size": moe_inter,
-        "intermediate_size": tc["intermediate_size"],
-        "num_attention_heads": tc["num_attention_heads"],
-        "num_key_value_heads": tc["num_key_value_heads"],
-        "num_global_key_value_heads": tc.get("num_global_key_value_heads", 2),
-        "global_head_dim": tc.get("global_head_dim", 512),
-        "head_dim": tc.get("head_dim", 256),
-        "vocab_size": tc["vocab_size"],
-        "rms_norm_eps": tc.get("rms_norm_eps", 1e-6),
-        "sliding_window": tc.get("sliding_window", 1024),
-        "layer_types": tc.get("layer_types", []),
-        "hidden_activation": tc.get("hidden_activation", "gelu_pytorch_tanh"),
-        "final_logit_softcapping": tc.get("final_logit_softcapping", 30.0),
-        "enable_moe_block": tc.get("enable_moe_block", True),
-        "attention_k_eq_v": tc.get("attention_k_eq_v", True),
-        "rope_parameters": tc.get("rope_parameters"),
-        "max_position_embeddings": tc.get("max_position_embeddings", 262144),
-        "tie_word_embeddings": config.get("tie_word_embeddings", True),
-        "streaming": {"pinned_file": "pinned.safetensors", "expert_dir": "bin"},
-    }
-    with open(os.path.join(output_dir, "config.json"), "w") as f:
+    # Streaming config
+    stream_config = dict(tc)
+    stream_config["quantization"] = config.get("quantization", {"bits": 4, "group_size": 64})
+    stream_config["streaming"] = {"pinned_file": "pinned.safetensors", "expert_dir": "bin"}
+    with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
         json.dump(stream_config, f, indent=2)

-    # Copy tokenizer
+    # Copy tokenizer files
     import shutil
-    for tf in ["tokenizer.json", "tokenizer_config.json", "
-               "
-        src = os.path.join(
+    for tf in ["tokenizer.json", "tokenizer_config.json", "chat_template.jinja",
+               "generation_config.json", "processor_config.json"]:
+        src = os.path.join(INPUT_DIR, tf)
         if os.path.exists(src):
-            shutil.copy(src, os.path.join(
+            shutil.copy(src, os.path.join(OUTPUT_DIR, tf))

-    # Verify
-    layer_count = sum(1 for f in os.listdir(os.path.join(output_dir, "bin"))
-                      if f.startswith("moe_layer_") and f.endswith(".bin"))
-    elapsed = time.time() - t0
-    print(f"\n Done in {elapsed:.0f}s!")
-    print(f" Pinned: {pinned_bytes/1e9:.2f} GB, Experts: {total_expert_bytes/1e9:.2f} GB")
-    print(f" Layers: {layer_count}/{NUM_LAYERS}")
-
-
-def _write_expert_layer(output_dir, layer_idx, layer_tensors, num_experts, t0):
-    """Write one layer's experts to a binary file."""
-    # Build tensor info and calculate sizes
-    tensor_order = ["experts.gate_up_proj", "experts.down_proj"]
-    tensor_info = {}
-    offset = 0
-    for tname in tensor_order:
-        t = layer_tensors[tname]
-        per_expert_shape = list(t.shape[1:])  # remove expert dim
-        per_expert_bytes = int(np.prod(per_expert_shape)) * t.dtype.size
-        tensor_info[tname] = {
-            "inner_offset": offset,
-            "nbytes": per_expert_bytes,
-            "shape_per_expert": per_expert_shape,
-            "dtype": str(t.dtype),
-        }
-        offset += per_expert_bytes
-
-    expert_block_size = ((offset + PAGE_SIZE - 1) // PAGE_SIZE) * PAGE_SIZE
-
-    header = {
-        "layer_idx": layer_idx,
-        "num_experts": num_experts,
-        "layout": {
-            "expert_block_size": expert_block_size,
-            "data_start": PAGE_SIZE,
-            "tensors": tensor_info,
-        }
-    }
-    header_json = json.dumps(header).encode()
-    header_padded = header_json + b"\x00" * (PAGE_SIZE - len(header_json))
-
-    layer_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
-    with open(layer_path, "wb") as f:
-        f.write(header_padded)
-        for eid in range(num_experts):
-            expert_data = bytearray()
-            for tname in tensor_order:
-                expert_t = layer_tensors[tname][eid]
-                mx.eval(expert_t)
-                if expert_t.dtype == mx.bfloat16:
-                    raw = np.array(expert_t.view(mx.uint16)).tobytes()
-                else:
-                    raw = np.array(expert_t).tobytes()
-                expert_data.extend(raw)
-            pad = expert_block_size - len(expert_data)
-            if pad > 0:
-                expert_data.extend(b"\x00" * pad)
-            f.write(bytes(expert_data))
-
-    sym = os.path.join(output_dir, "bin", f"layer_{layer_idx:02d}.bin")
-    if not os.path.exists(sym):
-        os.symlink(f"moe_layer_{layer_idx:02d}.bin", sym)

     elapsed = time.time() - t0
+    print(f"\nDone in {elapsed:.0f}s!")
+    print(f"Pinned: {pinned_bytes/1e9:.2f} GB, Experts: {total_expert_bytes/1e9:.2f} GB")
+    return True
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", "-i", required=True)
+    parser.add_argument("--output", "-o", required=True)
+    args = parser.parse_args()
+    preprocess_gemma4(args.input, args.output)
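The bin files written above are a header page followed by fixed-size expert blocks, so reading one expert's tensor back is a seek plus a byte reinterpretation. A self-contained sketch of that read path, using float16 as a stand-in for bfloat16 (NumPy has no bfloat16 dtype; the uint16 byte-view round-trip is the same idea) and an illustrative single-tensor layout:

```python
import io
import numpy as np

PAGE_SIZE = 16384

# Illustrative layout; the real files store this JSON in the first page
layout = {
    "data_start": PAGE_SIZE,
    "expert_block_size": 8,
    "tensors": {"switch_mlp.down_proj.scales":
                {"inner_offset": 0, "nbytes": 8,
                 "shape_per_expert": [2, 2], "dtype": "float16"}},
}

# Write a fake layer file: one header page, then 2 expert blocks
buf = io.BytesIO()
buf.write(b"\x00" * PAGE_SIZE)
expert1 = np.array([[0.5, 1.5], [2.5, -3.0]], dtype=np.float16)
buf.write(np.zeros(4, dtype=np.float16).tobytes())  # expert 0
buf.write(expert1.view(np.uint16).tobytes())        # expert 1, raw uint16 bytes

# Read expert 1 back: offset = data_start + eid * block + inner_offset
info = layout["tensors"]["switch_mlp.down_proj.scales"]
off = layout["data_start"] + 1 * layout["expert_block_size"] + info["inner_offset"]
buf.seek(off)
raw = buf.read(info["nbytes"])
restored = (np.frombuffer(raw, dtype=np.uint16)
            .view(np.float16)
            .reshape(info["shape_per_expert"]))
print(restored.tolist())  # [[0.5, 1.5], [2.5, -3.0]]
```

Because the bytes are never converted, only reinterpreted, the round trip is lossless; this is the same property the commit relies on for bfloat16.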
@@ -45,8 +45,29 @@ class SniperEngine:
         self._loaded = False

     @classmethod
-    def from_dir(cls, sniper_dir: str, **overrides)
-        """Create engine from a sniper directory.
+    def from_dir(cls, sniper_dir: str, **overrides):
+        """Create engine from a sniper directory.
+
+        Auto-detects model type and dispatches to the right engine:
+        - Qwen 3.x MoE → SniperEngine (this class)
+        - Gemma 4 → MoESniperEngineGemma4 (different architecture)
+        """
+        # Peek at config to detect model type
+        with open(os.path.join(sniper_dir, "config.json")) as f:
+            cfg_data = json.load(f)
+        model_type = cfg_data.get("model_type", "")
+
+        if model_type.startswith("gemma4"):
+            from .engine_gemma4 import MoESniperEngineGemma4
+            cache_size = overrides.get("max_cached_experts", 4000)
+            engine = MoESniperEngineGemma4(
+                model_dir=sniper_dir,
+                cache_size=cache_size,
+            )
+            engine.load()
+            return engine
+
+        # Default: Qwen path
         config = SniperConfig.from_dir(sniper_dir, **overrides)
         engine = cls(config)
         engine.load()
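The dispatch above keys off `model_type` in the directory's config.json. A minimal sketch of that decision in isolation (the helper name is hypothetical; the real method also constructs and loads the chosen engine):

```python
import json

def detect_engine(config_text):
    """Mirror of the model_type peek in SniperEngine.from_dir."""
    model_type = json.loads(config_text).get("model_type", "")
    return "gemma4" if model_type.startswith("gemma4") else "qwen"

print(detect_engine('{"model_type": "gemma4_text"}'))  # gemma4
print(detect_engine('{"model_type": "qwen3_moe"}'))    # qwen
```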
|