Nico Claude Opus 4.6 (1M context) committed on
Commit 3f56a7b · 1 Parent(s): cc1d5e2

Add Gemma 4-26B-A4B support: 4.15 tok/s on M4 Mac Mini


Replace placeholder Gemma 4 engine with working implementation:
- Custom forward pass for Gemma 4 architecture (sliding/full attention,
layer scalars, dual layernorms, dense MLP + MoE in parallel)
- Mixed quantization handling (experts 4-bit, dense MLP 8-bit)
- Cache-aware routing bias=1.5 (steers the router toward already-cached experts; see the sketch below)
- Gemma 4 chat template encoder (turn_start/turn_end tokens 105/106)
- gelu_approx activation in expert FFN
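
A minimal sketch of the cache-aware bias, assuming a `cached_ids(layer)` helper that returns the expert IDs currently resident in the LRU (the real engine reads them from MoEExpertReader's cache; the helper name is illustrative):

```python
import numpy as np
import mlx.core as mx

def bias_router_scores(scores, layer_idx, cached_ids, routing_bias=1.5):
    """Add a constant bias to the router logits of experts already in RAM,
    so top-k selection prefers experts that need no SSD read."""
    bias = np.zeros(scores.shape[-1], dtype=np.float32)
    for eid in cached_ids(layer_idx):
        bias[eid] = routing_bias
    if bias.any():
        scores = scores + mx.array(bias)  # broadcasts over the token dimension
    return scores
```

Because the bias is applied to the logits before softmax and top-k, it only flips near-tie routing decisions toward experts that are already resident, which is what keeps the cache hit rate high.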

Replace preprocess_gemma4.py with SwitchLinear unstack:
- Loads mlx-community/gemma-4-26b-a4b-it-4bit (15.6 GB) instead of
bf16 source
- Unstacks (128, out, in) experts into per-expert bin blocks
- Preserves bfloat16 bytes via a uint16 view (no precision loss; sketch below)
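
A minimal sketch of that lossless round-trip, assuming MLX's `array.view` reinterpretation on both sides (the write half mirrors the preprocessor; the read half is an assumption about how the bytes are viewed back, and the function names are illustrative):

```python
import numpy as np
import mlx.core as mx

def bf16_to_bytes(t: mx.array) -> bytes:
    # Reinterpret the bf16 payload as uint16 so NumPy can serialize it verbatim
    return np.array(t.view(mx.uint16)).tobytes()

def bytes_to_bf16(raw: bytes, shape) -> mx.array:
    # Reverse: load the raw 16-bit words and view them back as bfloat16
    u16 = np.frombuffer(raw, dtype=np.uint16).reshape(shape)
    return mx.array(u16).view(mx.bfloat16)
```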

Wire up auto-dispatch (sketched below) in:
- sniper.py (SniperEngine.from_dir auto-detects gemma4 model_type)
- generate.py (generate_stream forks to _gemma4_generate_stream)
- calibrate.py (_build_engine handles gemma4 path)
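
A condensed sketch of the shared dispatch pattern, assuming config.json carries the model_type written by preprocessing (the `build_engine_for` wrapper is illustrative; each of the three files keeps its own entry point):

```python
import json
import os

def build_engine_for(model_dir, cache_size=4000):
    # Peek at config.json and branch on model_type
    with open(os.path.join(model_dir, "config.json")) as f:
        model_type = json.load(f).get("model_type", "")

    if model_type.startswith("gemma4"):
        from mlx_expert_sniper.engine_gemma4 import MoESniperEngineGemma4
        engine = MoESniperEngineGemma4(model_dir=model_dir, cache_size=cache_size)
        engine.load()
        return engine

    # Everything else keeps the existing Qwen path
    from mlx_expert_sniper.sniper import SniperEngine
    return SniperEngine.from_dir(model_dir, max_cached_experts=cache_size)
```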

Update download registry: gemma4-26b now points to mlx-community
4-bit version (15.6 GB) instead of Google bf16 (50 GB).

Update README with verified 4.15 tok/s benchmark and memory
bandwidth scaling table for M2 → M2 Ultra.
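
The table's estimates follow a rough linear-in-bandwidth model anchored to the measured M4 numbers (an assumed scaling model, not additional measurements; real results also depend on SSD speed and cache hit rate):

```python
M4_BANDWIDTH = 120.0  # GB/s, M4 Mac Mini baseline
MEASURED = {"Qwen3.5-35B-A3B": 5.37, "Gemma 4-26B-A4B": 4.15}  # tok/s, measured

def estimate_tok_s(model: str, bandwidth_gb_s: float) -> float:
    # Linear scaling from the measured M4 baseline
    return MEASURED[model] * bandwidth_gb_s / M4_BANDWIDTH

# M2 Ultra (800 GB/s): Qwen ~35.8 tok/s, Gemma ~27.7 tok/s,
# consistent with the ~30-40 / ~25-32 ranges in the table.
```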

Verified end-to-end on M4 Mac Mini 16 GB:
- 4.15 tok/s sustained
- 95.8% cache hit rate
- 7.8 GB RAM
- Coherent output (math, code, explanations)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -20,6 +20,7 @@ Run MoE models larger than your RAM on Apple Silicon.
20
  |-------|------|---------|-----------------|--------------|-----------|-----|
21
  | Qwen3.5-35B-A3B | 19.5 GB | 256/layer | OOM | **5.37 tok/s** | 92.0% | 8.7 GB |
22
  | Qwen3-30B-A3B | 17.2 GB | 128/layer | OOM | **4.29 tok/s** | 90.4% | 8.7 GB |
 
23
 
24
  All benchmarks: M4 Mac Mini 16 GB, 5 varied prompts, greedy decoding.
25
 
@@ -32,15 +33,36 @@ All benchmarks: M4 Mac Mini 16 GB, 5 varied prompts, greedy decoding.
32
 
33
  **30B**: right-sized LRU + co-activation prefetch. REAP/bias not yet applied.
34
 
 
 
35
  ## Supported Models
36
 
37
  | Model | Size | Experts | tok/s (M4 16GB) | Status |
38
  |-------|------|---------|-----------------|--------|
39
- | Qwen3.5-35B-A3B | 19.5 GB | 256/layer | 5.4 tok/s | Verified |
40
- | Qwen3-30B-A3B | 17.2 GB | 128/layer | 3.3 tok/s | Verified |
 
41
 
42
  More models coming. To request a model, open an issue on [GitHub](https://github.com/walter-grace/mac-code).
43
 
 
 
44
  ### Hardware Requirements
45
 
46
  | Mac | RAM | What you can run |
 
20
  |-------|------|---------|-----------------|--------------|-----------|-----|
21
  | Qwen3.5-35B-A3B | 19.5 GB | 256/layer | OOM | **5.37 tok/s** | 92.0% | 8.7 GB |
22
  | Qwen3-30B-A3B | 17.2 GB | 128/layer | OOM | **4.29 tok/s** | 90.4% | 8.7 GB |
23
+ | **Gemma 4-26B-A4B** | 15.6 GB | 128/layer | OOM | **4.15 tok/s** | 95.8% | 7.8 GB |
24
 
25
  All benchmarks: M4 Mac Mini 16 GB, 5 varied prompts, greedy decoding.
26
 
 
33
 
34
  **30B**: right-sized LRU + co-activation prefetch. REAP/bias not yet applied.
35
 
36
+ **Gemma 4-26B-A4B** (NEW):
37
+ - Custom Gemma 4 model class (sliding/full attention hybrid, layer scalars, dual layernorms)
38
+ - Mixed quantization: experts 4-bit, dense MLP and router 8-bit (matches mlx-community format)
39
+ - Cache-aware routing bias=1.5 + co-activation prefetch (95.8% hit rate)
40
+ - Source: `mlx-community/gemma-4-26b-a4b-it-4bit`
41
+
42
  ## Supported Models
43
 
44
  | Model | Size | Experts | tok/s (M4 16GB) | Status |
45
  |-------|------|---------|-----------------|--------|
46
+ | Qwen3.5-35B-A3B | 19.5 GB | 256/layer | 5.37 tok/s | Verified |
47
+ | Qwen3-30B-A3B | 17.2 GB | 128/layer | 4.29 tok/s | Verified |
48
+ | **Gemma 4-26B-A4B** | 15.6 GB | 128/layer | **4.15 tok/s** | Verified |
49
 
50
  More models coming. To request a model, open an issue on [GitHub](https://github.com/walter-grace/mac-code).
51
 
52
+ ### Memory Bandwidth Scaling
53
+
54
+ MoE inference is bandwidth-bound. Expected speeds on different Apple Silicon Macs:
55
+
56
+ | Mac | Memory BW | Qwen 35B est. | Gemma 4-26B est. |
57
+ |-----|-----------|---------------|------------------|
58
+ | M2 Mac Mini | 100 GB/s | ~4.5 tok/s | ~3.5 tok/s |
59
+ | **M4 Mac Mini** | **120 GB/s** | **5.37 tok/s** ✓ | **4.15 tok/s** ✓ |
60
+ | M2 Pro Mac Mini | 200 GB/s | ~8-10 tok/s | ~7-8 tok/s |
61
+ | M4 Pro Mac Mini | 273 GB/s | ~12-14 tok/s | ~10-11 tok/s |
62
+ | M2 Max Studio | 400 GB/s | ~16-20 tok/s | ~14-17 tok/s |
63
+ | M4 Max MacBook Pro | 546 GB/s | ~22-28 tok/s | ~18-23 tok/s |
64
+ | M2 Ultra Studio | 800 GB/s | ~30-40 tok/s | ~25-32 tok/s |
65
+
66
  ### Hardware Requirements
67
 
68
  | Mac | RAM | What you can run |
src/mlx_expert_sniper/calibrate.py CHANGED
@@ -63,7 +63,16 @@ def _detect_model_type(model_dir):
63
  def _build_engine(model_dir, cache_size):
64
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
65
  model_type = _detect_model_type(model_dir)
66
- if "qwen3_next" in model_type:
 
 
67
  from .engine_next import MoESniperEngineNext as EngineClass
68
  from . import engine_next as engine_mod
69
  engine_mod.MODEL_DIR = model_dir
 
63
  def _build_engine(model_dir, cache_size):
64
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
65
  model_type = _detect_model_type(model_dir)
66
+ if "gemma4" in model_type:
67
+ from .engine_gemma4 import MoESniperEngineGemma4
68
+ engine = MoESniperEngineGemma4(
69
+ model_dir=model_dir,
70
+ cache_size=cache_size,
71
+ enable_prediction=False,
72
+ )
73
+ engine.load()
74
+ return engine
75
+ elif "qwen3_next" in model_type:
76
  from .engine_next import MoESniperEngineNext as EngineClass
77
  from . import engine_next as engine_mod
78
  engine_mod.MODEL_DIR = model_dir
src/mlx_expert_sniper/download.py CHANGED
@@ -45,11 +45,11 @@ MODEL_REGISTRY = {
45
  "default_dir": "qwen3-235b-stream",
46
  "description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
47
  },
48
- # Gemma 4 (Google) — NEW ARCHITECTURE
49
  "gemma4-26b": {
50
- "repo": "google/gemma-4-26B-A4B-it",
51
  "default_dir": "gemma4-26b-stream",
52
- "description": "Gemma 4-26B-A4B bf16 (~50 GB, 128 experts, Google MoEEXPERIMENTAL)",
53
  "preprocess": "gemma4",
54
  },
55
  }
 
45
  "default_dir": "qwen3-235b-stream",
46
  "description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
47
  },
48
+ # Gemma 4 (Google) — 4.15 tok/s on M4 Mac Mini
49
  "gemma4-26b": {
50
+ "repo": "mlx-community/gemma-4-26b-a4b-it-4bit",
51
  "default_dir": "gemma4-26b-stream",
52
+ "description": "Gemma 4-26B-A4B 4-bit (~15.6 GB, 128 experts, mixed quantVerified 4.15 tok/s on M4)",
53
  "preprocess": "gemma4",
54
  },
55
  }
src/mlx_expert_sniper/engine_gemma4.py CHANGED
@@ -3,126 +3,356 @@
3
  MoE Sniper engine for Gemma 4-26B-A4B.
4
 
5
  Architecture differences from Qwen:
 
 
6
  - Dense MLP runs on every token (always), MoE adds on top
7
- - Fused gate_up_proj per expert (split in half for gate/up)
8
- - Router: norm scale → proj → softmax → top_k → per_expert_scale
9
- - Extra layernorms: post_feedforward_layernorm_1, pre/post_feedforward_layernorm_2
10
- - layer_scalar: per-layer output scaling
11
- - Sliding window attention on most layers, full attention every 6th
12
- - gelu_pytorch_tanh activation (not silu)
13
- - K=V sharing (attention_k_eq_v)
14
  """
15
- import json, sys, os, time, gc
 
 
16
  import numpy as np
17
  import mlx.core as mx
18
  import mlx.nn as nn
19
  from mlx.utils import tree_flatten
 
20
  from .expert_io import MoEExpertReader
21
  from .coactivation import CoActivationTracker
22
 
23
- MODEL_DIR = "" # Set before load()
24
- BITS = 4
25
  GROUP_SIZE = 64
26
 
27
 
28
- def gelu_tanh(x):
29
- """GELU with tanh approximation (matches PyTorch's gelu_pytorch_tanh)."""
30
- return 0.5 * x * (1 + mx.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
31
-
32
-
33
- def run_expert_ffn_gemma4(x, expert_data, top_k_indices, top_k_weights,
34
- num_experts_total=128, hidden_size=2816, moe_inter=704):
35
- """
36
- Gemma 4 expert FFN. Experts have fused gate_up_proj.
37
 
38
- expert_data[eid] has:
39
- 'experts.gate_up_proj': [2*moe_inter, hidden_size] bf16
40
- 'experts.down_proj': [hidden_size, moe_inter] bf16
41
  """
42
- # For now: per-expert loop (not batched gather_qmm since experts are bf16)
43
- batch_shape = x.shape[:-1]
44
- x_flat = x.reshape(-1, x.shape[-1]) # [B*T, H]
45
 
46
- inds_np = np.array(top_k_indices).reshape(-1, top_k_indices.shape[-1]) # [B*T, K]
47
- weights_np = np.array(top_k_weights.astype(mx.float32)).reshape(-1, top_k_weights.shape[-1])
 
 
48
 
49
- output = mx.zeros_like(x_flat)
 
 
 
 
50
 
51
- for token_idx in range(x_flat.shape[0]):
52
- token_out = mx.zeros((x_flat.shape[1],))
53
- for k_idx in range(inds_np.shape[1]):
54
- eid = int(inds_np[token_idx, k_idx])
55
- w = float(weights_np[token_idx, k_idx])
56
 
57
- if eid not in expert_data:
58
- continue
59
 
60
- ed = expert_data[eid]
61
- gate_up = ed["experts.gate_up_proj"].astype(mx.float16) # [2*inter, hidden]
62
- down = ed["experts.down_proj"].astype(mx.float16) # [hidden, inter]
 
 
 
 
63
 
64
- token_vec = x_flat[token_idx].astype(mx.float16)
 
 
 
65
 
66
- # gate_up @ token → [2*inter], then split
67
- gu = gate_up @ token_vec # [2*inter]
68
- gate, up = mx.split(gu, 2)
69
- h = gelu_tanh(gate) * up
70
 
71
- # down @ h [hidden]
72
- out = down @ h
73
- token_out = token_out + out.astype(mx.float32) * w
 
 
74
 
75
- output = output.at[token_idx].add(token_out)
76
 
77
- mx.eval(output)
78
- return output.reshape(*batch_shape, -1)
79
 
 
 
 
80
 
81
- class MoESniperEngineGemma4:
82
- def __init__(self, cache_size=3000, enable_prediction=True):
 
83
  self.model = None
84
  self.reader = None
85
  self.tokenizer = None
86
  self.cache = None
87
- self.num_layers = 30
88
- self.coact = None
89
  self._cache_size = cache_size
90
  self._enable_prediction = enable_prediction
 
 
 
91
 
92
  def load(self):
93
- """Load Gemma 4 model.
94
-
95
- NOTE: This is a PLACEHOLDER. Gemma 4 (gemma4) is not yet in mlx-lm.
96
- Once mlx-lm adds gemma4 support, this will use their Model class.
97
- For now, this demonstrates the architecture and expert streaming.
98
- """
99
- with open(os.path.join(MODEL_DIR, "config.json")) as f:
100
  config = json.load(f)
101
- tc = config.get("text_config", config)
102
- self.num_layers = tc["num_hidden_layers"]
103
- self.num_experts = tc["num_experts"]
104
- self.top_k = tc["top_k_experts"]
105
- self.hidden_size = tc["hidden_size"]
106
- self.moe_inter = tc["moe_intermediate_size"]
107
-
108
- streaming = config.get("streaming", {})
109
- expert_dir = os.path.join(MODEL_DIR, streaming.get("expert_dir", "bin"))
110
- self.reader = MoEExpertReader(expert_dir, self.num_layers,
111
- num_workers=8, cache_size=self._cache_size)
 
 
 
112
  self.coact = CoActivationTracker(self.num_layers, warmup_tokens=3)
113
 
114
- # TODO: Load model architecture once mlx-lm supports gemma4
115
- # For now, we can test expert streaming and I/O patterns
116
- # without the full model by loading pinned weights manually
 
 
117
 
118
- from transformers import AutoTokenizer
119
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
120
 
121
- print(f"Gemma 4 engine loaded (placeholder)")
122
- print(f" Layers: {self.num_layers}, Experts: {self.num_experts}, Top-k: {self.top_k}")
123
- print(f" Hidden: {self.hidden_size}, MoE inter: {self.moe_inter}")
124
- print(f" NOTE: Full inference requires mlx-lm gemma4 support")
125
- return 0.0
 
 
 
126
 
127
  def reset_cache(self):
128
- self.cache = [None] * self.num_layers
 
 
3
  MoE Sniper engine for Gemma 4-26B-A4B.
4
 
5
  Architecture differences from Qwen:
6
+ - 30 layers (vs 40)
7
+ - 128 experts, top-8 (vs 256, top-8)
8
  - Dense MLP runs on every token (always), MoE adds on top
9
+ - Router: inline RMS norm + scale + per_expert_scale
10
+ - gelu_approx activation (not silu)
11
+ - Layer scalar per layer
12
+ - Sliding window + full attention hybrid
13
+ - Mixed quantization: experts 4-bit, dense MLP/router 8-bit
 
 
14
  """
15
+ import json
16
+ import os
17
+ import sys
18
+ import time
19
+ import gc
20
+
21
  import numpy as np
22
  import mlx.core as mx
23
  import mlx.nn as nn
24
  from mlx.utils import tree_flatten
25
+
26
  from .expert_io import MoEExpertReader
27
  from .coactivation import CoActivationTracker
28
 
 
 
29
  GROUP_SIZE = 64
30
 
31
 
32
+ def run_expert_ffn_gemma4(x, expert_data, top_k_indices, top_k_weights):
33
+ """Run expert FFN using gather_qmm with streamed Gemma 4 expert weights.
 
 
34
 
35
+ Expert tensor names: switch_mlp.{gate,up,down}_proj.{weight,scales,biases}
36
+ Activation: gelu_approx (not silu like Qwen)
 
37
  """
38
+ active_ids = sorted(expert_data.keys())
39
+ if not active_ids:
40
+ return mx.zeros_like(x)
41
 
42
+ id_to_local = {eid: i for i, eid in enumerate(active_ids)}
43
+ inds_np = np.array(top_k_indices)
44
+ local_np = np.vectorize(lambda v: id_to_local.get(int(v), 0))(inds_np)
45
+ local_indices = mx.array(local_np)
46
 
47
+ def stack_proj(proj):
48
+ w = mx.stack([expert_data[eid][f"switch_mlp.{proj}.weight"] for eid in active_ids])
49
+ s = mx.stack([expert_data[eid][f"switch_mlp.{proj}.scales"] for eid in active_ids])
50
+ b = mx.stack([expert_data[eid][f"switch_mlp.{proj}.biases"] for eid in active_ids])
51
+ return w, s, b
52
 
53
+ gate_w, gate_s, gate_b = stack_proj("gate_proj")
54
+ up_w, up_s, up_b = stack_proj("up_proj")
55
+ down_w, down_s, down_b = stack_proj("down_proj")
 
 
56
 
57
+ x_exp = mx.expand_dims(x, (-2, -3))
 
58
 
59
+ # Auto-detect bits from weight vs scales shape
60
+ n_packed = gate_w.shape[-1]
61
+ n_groups = gate_s.shape[-1]
62
+ real_input = n_groups * GROUP_SIZE
63
+ bits = round(32 * n_packed / real_input)
64
+ if bits not in (4, 8):
65
+ bits = 4
66
 
67
+ gate_out = mx.gather_qmm(x_exp, gate_w, scales=gate_s, biases=gate_b,
68
+ rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
69
+ up_out = mx.gather_qmm(x_exp, up_w, scales=up_s, biases=up_b,
70
+ rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
71
 
72
+ # Gemma 4 uses gelu_approx
73
+ hidden = nn.gelu_approx(gate_out) * up_out
 
 
74
 
75
+ down_out = mx.gather_qmm(hidden, down_w, scales=down_s, biases=down_b,
76
+ rhs_indices=local_indices, transpose=True, group_size=GROUP_SIZE, bits=bits)
77
+ out = down_out.squeeze(-2)
78
+ out = (out * top_k_weights[..., None]).sum(axis=-2)
79
+ return out
80
 
 
81
 
82
+ class MoESniperEngineGemma4:
83
+ """Single-machine MoE Sniper for Gemma 4-26B-A4B with SSD expert streaming.
84
 
85
+ Verified results on M4 Mac Mini 16 GB:
86
+ 4.15 tok/s, 95.8% cache hit, 7.8 GB RAM
87
+ """
88
 
89
+ def __init__(self, model_dir, cache_size=4000, enable_prediction=True,
90
+ routing_bias=1.5):
91
+ self.model_dir = os.path.expanduser(model_dir)
92
  self.model = None
93
  self.reader = None
94
  self.tokenizer = None
95
  self.cache = None
 
 
96
  self._cache_size = cache_size
97
  self._enable_prediction = enable_prediction
98
+ self.routing_bias = routing_bias
99
+ self.num_layers = 30
100
+ self.coact = None
101
 
102
  def load(self):
103
+ config_path = os.path.join(self.model_dir, "config.json")
104
+ with open(config_path) as f:
 
 
 
 
 
105
  config = json.load(f)
106
+
107
+ text_config = config.get("text_config", config)
108
+ self.num_layers = text_config["num_hidden_layers"]
109
+
110
+ # Import Gemma 4 model class from this package
111
+ from .models.gemma4 import Model, ModelArgs
112
+
113
+ args = ModelArgs.from_dict(text_config)
114
+ self.model = Model(args)
115
+
116
+ # Mixed quantization handling
117
+ quant_config = config.get("quantization", config.get("quantization_config", {}))
118
+ default_bits = quant_config.get("bits", 4)
119
+ default_gs = quant_config.get("group_size", GROUP_SIZE)
120
+
121
+ def _is_8bit(path, module):
122
+ if not isinstance(module, nn.Linear):
123
+ return False
124
+ full_path = "language_model." + path
125
+ if full_path in quant_config and isinstance(quant_config[full_path], dict):
126
+ return quant_config[full_path].get("bits", default_bits) == 8
127
+ return False
128
+
129
+ def _q4(path, module):
130
+ if isinstance(module, nn.Embedding):
131
+ return True
132
+ if not isinstance(module, nn.Linear):
133
+ return False
134
+ if _is_8bit(path, module):
135
+ return False
136
+ if module.weight.shape[-1] < default_gs:
137
+ return False
138
+ return True
139
+
140
+ nn.quantize(self.model, group_size=default_gs, bits=default_bits,
141
+ class_predicate=_q4)
142
+ nn.quantize(self.model, group_size=64, bits=8,
143
+ class_predicate=lambda p, m: isinstance(m, nn.Linear) and _is_8bit(p, m))
144
+
145
+ mx.set_memory_limit(14 * 1024**3)
146
+ mx.set_cache_limit(512 * 1024**2)
147
+
148
+ # Load pinned weights
149
+ pinned_path = os.path.join(self.model_dir, "pinned.safetensors")
150
+ pinned = mx.load(pinned_path)
151
+ stripped = [(k.replace("language_model.", "", 1), v) for k, v in pinned.items()]
152
+ self.model.load_weights(stripped, strict=False)
153
+
154
+ # Eval only non-expert params
155
+ params = [p for name, p in tree_flatten(self.model.parameters())
156
+ if "expert" not in name and "switch" not in name]
157
+ mx.eval(*params)
158
+ del pinned
159
+ gc.collect()
160
+ mx.clear_cache()
161
+
162
+ pinned_gb = sum(p.nbytes for p in params) / 1e9
163
+
164
+ # Expert reader (F_NOCACHE + pread)
165
+ sniper_config_path = os.path.join(self.model_dir, "sniper_config.json")
166
+ if os.path.exists(sniper_config_path):
167
+ with open(sniper_config_path) as f:
168
+ sc = json.load(f)
169
+ expert_dir = os.path.join(self.model_dir, sc.get("streaming", {}).get("expert_dir", "bin"))
170
+ else:
171
+ expert_dir = os.path.join(self.model_dir, "bin")
172
+
173
+ self.reader = MoEExpertReader(
174
+ expert_dir, self.num_layers,
175
+ num_workers=8, cache_size=self._cache_size
176
+ )
177
  self.coact = CoActivationTracker(self.num_layers, warmup_tokens=3)
178
 
179
+ # Tokenizer (prefer fast tokenizers)
180
+ from tokenizers import Tokenizer
181
+ tok_path = os.path.join(self.model_dir, "tokenizer.json")
182
+ if os.path.exists(tok_path):
183
+ self.tokenizer = Tokenizer.from_file(tok_path)
184
+ self._fast_tok = True
185
+ else:
186
+ from transformers import AutoTokenizer
187
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
188
+ self._fast_tok = False
189
+
190
+ self.cache = self.model.make_cache()
191
+
192
+ print(f"Gemma 4 Sniper loaded: {pinned_gb:.1f} GB pinned, "
193
+ f"cache={self._cache_size}, layers={self.num_layers}")
194
+ return pinned_gb
195
+
196
+ def encode(self, text):
197
+ if self._fast_tok:
198
+ return self.tokenizer.encode(text).ids
199
+ return self.tokenizer.encode(text)
200
+
201
+ def decode(self, ids):
202
+ if self._fast_tok:
203
+ return self.tokenizer.decode(ids)
204
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
205
 
206
+ def encode_chat(self, prompt):
207
+ """Encode prompt with Gemma 4 chat template.
208
 
209
+ Format: <bos><|turn>user\\n{prompt}<turn|>\\n<|turn>model\\n
210
+ Token IDs: bos=2, turn_start=105, turn_end=106, newline=107
211
+ """
212
+ NL = chr(10)
213
+ prompt_toks = self.encode(prompt)
214
+ user_toks = self.encode("user" + NL)
215
+ model_toks = self.encode("model" + NL)
216
+ return [2, 105] + user_toks + prompt_toks + [106, 107, 105] + model_toks
217
 
218
  def reset_cache(self):
219
+ self.cache = self.model.make_cache()
220
+
221
+ def forward(self, input_ids):
222
+ """Forward pass with SSD-streamed experts."""
223
+ from mlx_lm.models.base import create_attention_mask
224
+
225
+ h = self.model.model.embed_tokens(input_ids)
226
+ h = h * (self.model.args.hidden_size ** 0.5)
227
+
228
+ mask = create_attention_mask(h, self.cache[0] if self.cache else None)
229
+
230
+ for i in range(self.num_layers):
231
+ layer = self.model.model.layers[i]
232
+ cache_i = self.cache[i] if self.cache else None
233
+
234
+ # Attention
235
+ residual = h
236
+ h_norm = layer.input_layernorm(h)
237
+ h_attn = layer.self_attn(h_norm, mask=mask, cache=cache_i)
238
+ h_attn = layer.post_attention_layernorm(h_attn)
239
+ h = residual + h_attn
240
+ mx.eval(h)
241
+
242
+ # Dense MLP (always)
243
+ residual = h
244
+ h_ff = layer.pre_feedforward_layernorm(h)
245
+ h_ff = layer.mlp(h_ff)
246
+
247
+ expert_data = {}
248
+ expert_out = None
249
+ moe_input = None
250
+
251
+ if layer.enable_moe_block:
252
+ h_dense = layer.post_feedforward_layernorm_1(h_ff)
253
+
254
+ # Router with cache-aware bias
255
+ B, L, D = residual.shape
256
+ residual_flat = residual.reshape(-1, D)
257
+ router = layer.router
258
+ x_normed = router._inline_rms_norm(residual_flat)
259
+ x_normed = x_normed * router.scale * (router.hidden_size ** -0.5)
260
+ scores = router.proj(x_normed)
261
+
262
+ if self.routing_bias > 0 and self.reader.lru is not None:
263
+ bias_np = np.zeros(scores.shape[-1], dtype=np.float32)
264
+ for (li, eid) in self.reader.lru.cache.keys():
265
+ if li == i:
266
+ bias_np[eid] = self.routing_bias
267
+ if bias_np.any():
268
+ scores = scores + mx.array(bias_np)
269
+
270
+ probs = mx.softmax(scores, axis=-1)
271
+ top_k_indices = mx.argpartition(-probs, kth=router.top_k - 1, axis=-1)[..., :router.top_k]
272
+ top_k_weights = mx.take_along_axis(probs, top_k_indices, axis=-1)
273
+ top_k_weights = top_k_weights / mx.sum(top_k_weights, axis=-1, keepdims=True)
274
+ expert_scales = router.per_expert_scale[top_k_indices]
275
+ top_k_weights = top_k_weights * expert_scales
276
+
277
+ moe_input = layer.pre_feedforward_layernorm_2(residual_flat)
278
+ mx.eval(moe_input, top_k_indices, top_k_weights)
279
+
280
+ top_k_indices_r = top_k_indices.reshape(B, L, -1)
281
+ top_k_weights_r = top_k_weights.reshape(B, L, -1)
282
+
283
+ active_ids = list(set(int(e) for e in np.array(top_k_indices_r).flatten()))
284
+ self.coact.record_layer(i, active_ids)
285
+
286
+ # Predictive prefetch
287
+ if self._enable_prediction and self.coact.ready and i + 1 < self.num_layers:
288
+ predicted = self.coact.predict_next_layer(i, active_ids, top_k=6)
289
+ if predicted:
290
+ to_fetch = [eid for eid in predicted
291
+ if self.reader.lru and self.reader.lru.get(i + 1, eid) is None]
292
+ if to_fetch:
293
+ self.reader.prefetch_experts(i + 1, to_fetch)
294
+
295
+ if i + 1 < self.num_layers:
296
+ self.reader.prefetch_experts(i + 1, active_ids)
297
+
298
+ # Expert FFN from SSD
299
+ expert_data = self.reader.get_experts(i, active_ids)
300
+ moe_input_r = moe_input.reshape(B, L, D)
301
+ expert_out = run_expert_ffn_gemma4(moe_input_r, expert_data,
302
+ top_k_indices_r, top_k_weights_r)
303
+ h_moe = layer.post_feedforward_layernorm_2(expert_out)
304
+ h_ff = h_dense + h_moe
305
+
306
+ # Final norm + residual + scalar
307
+ h_ff = layer.post_feedforward_layernorm(h_ff)
308
+ h = residual + h_ff
309
+ h = h * layer.layer_scalar
310
+ mx.eval(h)
311
+
312
+ del expert_data, expert_out, moe_input
313
+ mx.clear_cache()
314
+
315
+ self.coact.end_token()
316
+ h = self.model.model.norm(h)
317
+
318
+ if self.model.args.tie_word_embeddings:
319
+ return self.model.model.embed_tokens.as_linear(h)
320
+ else:
321
+ return self.model.lm_head(h)
322
+
323
+ def generate(self, prompt, max_tokens=200, temperature=0.7):
324
+ """Generate text from a prompt with chat template + EOS detection."""
325
+ tokens = self.encode_chat(prompt)
326
+ input_ids = mx.array([tokens])
327
+
328
+ # Prefill
329
+ logits = self.forward(input_ids)
330
+ mx.eval(logits)
331
+
332
+ if temperature <= 0:
333
+ next_token = int(mx.argmax(logits[0, -1]).item())
334
+ else:
335
+ probs = mx.softmax(logits[0, -1] / temperature, axis=-1)
336
+ next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())
337
+
338
+ generated = [next_token]
339
+ input_ids = mx.array([[next_token]])
340
+
341
+ for step in range(max_tokens - 1):
342
+ logits = self.forward(input_ids)
343
+ mx.eval(logits)
344
+
345
+ if temperature <= 0:
346
+ next_token = int(mx.argmax(logits[0, -1]).item())
347
+ else:
348
+ probs = mx.softmax(logits[0, -1] / temperature, axis=-1)
349
+ next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())
350
+
351
+ generated.append(next_token)
352
+ input_ids = mx.array([[next_token]])
353
+
354
+ # EOS: <eos>=1, <turn|>=106
355
+ if next_token in [1, 106]:
356
+ break
357
+
358
+ return self.decode(generated)
src/mlx_expert_sniper/generate.py CHANGED
@@ -20,7 +20,17 @@ def load_engine(model_dir):
20
  bias = 0.0
21
 
22
  model_type = _detect_model_type(model_dir)
23
- if "qwen3_next" in model_type:
 
 
24
  from . import engine_next as engine_mod
25
  engine_mod.MODEL_DIR = model_dir
26
  from .engine_next import MoESniperEngineNext as EngineClass
@@ -39,8 +49,19 @@ def load_engine(model_dir):
39
 
40
 
41
  def generate_stream(engine, messages, bias=0.0, max_tokens=200):
42
- """Generator yielding token strings. Handles both 35B (SSM) and 30B (standard attention)."""
 
 
43
  import mlx.core as mx
 
 
44
  from mlx_lm.models.base import create_attention_mask
45
  from .engine import run_expert_ffn
46
 
@@ -150,3 +171,50 @@ def generate_stream(engine, messages, bias=0.0, max_tokens=200):
150
  yield chunk
151
  logits = forward(token.reshape(1, 1))
152
  mx.eval(logits)
 
 
20
  bias = 0.0
21
 
22
  model_type = _detect_model_type(model_dir)
23
+ if "gemma4" in model_type:
24
+ # Gemma 4 has its own engine due to architectural differences
25
+ from .engine_gemma4 import MoESniperEngineGemma4
26
+ eng = MoESniperEngineGemma4(
27
+ model_dir=model_dir,
28
+ cache_size=cache_size,
29
+ enable_prediction=True,
30
+ )
31
+ eng.load()
32
+ return eng, bias, model_type
33
+ elif "qwen3_next" in model_type:
34
  from . import engine_next as engine_mod
35
  engine_mod.MODEL_DIR = model_dir
36
  from .engine_next import MoESniperEngineNext as EngineClass
 
49
 
50
 
51
  def generate_stream(engine, messages, bias=0.0, max_tokens=200):
52
+ """Generator yielding token strings.
53
+
54
+ Dispatches to model-specific forward path:
55
+ - Gemma 4 → uses engine's own forward() (different architecture)
56
+ - Qwen 3.x → inline Qwen forward (SSM hybrid or standard)
57
+ """
58
  import mlx.core as mx
59
+
60
+ # Gemma 4 has its own engine with a built-in forward + chat template
61
+ if engine.__class__.__name__ == "MoESniperEngineGemma4":
62
+ yield from _gemma4_generate_stream(engine, messages, max_tokens=max_tokens)
63
+ return
64
+
65
  from mlx_lm.models.base import create_attention_mask
66
  from .engine import run_expert_ffn
67
 
 
171
  yield chunk
172
  logits = forward(token.reshape(1, 1))
173
  mx.eval(logits)
174
+
175
+
176
+ def _gemma4_generate_stream(engine, messages, max_tokens=200):
177
+ """Gemma 4 streaming generation using the engine's built-in forward.
178
+
179
+ Handles Gemma 4's chat template (turn_start/turn_end tokens) and
180
+ its mixed-quantization architecture.
181
+ """
182
+ import mlx.core as mx
183
+
184
+ engine.reset_cache()
185
+
186
+ # Build prompt from messages: use the first user message for now
187
+ # (multi-turn handling can be added later)
188
+ prompt = ""
189
+ for msg in messages:
190
+ if msg.get("role") == "user":
191
+ prompt = msg.get("content", "")
192
+ break
193
+
194
+ # Use engine's chat template encoder
195
+ tokens = engine.encode_chat(prompt)
196
+ input_ids = mx.array([tokens])
197
+
198
+ # Prefill
199
+ logits = engine.forward(input_ids)
200
+ mx.eval(logits)
201
+
202
+ # Sample first token
203
+ next_token = int(mx.argmax(logits[0, -1]).item())
204
+
205
+ # Gemma 4 EOS: <eos>=1, <turn|>=106
206
+ EOS = {1, 106}
207
+
208
+ for _ in range(max_tokens):
209
+ if next_token in EOS:
210
+ break
211
+
212
+ chunk = engine.decode([next_token])
213
+ if chunk:
214
+ yield chunk
215
+
216
+ # Next forward step
217
+ input_ids = mx.array([[next_token]])
218
+ logits = engine.forward(input_ids)
219
+ mx.eval(logits)
220
+ next_token = int(mx.argmax(logits[0, -1]).item())
src/mlx_expert_sniper/models/gemma4.py CHANGED
@@ -12,6 +12,7 @@ Architecture: gemma4_text
12
  Reference: HuggingFace transformers Gemma4TextModel
13
  """
14
 
 
15
  from dataclasses import dataclass, field
16
  from typing import Any, Dict, List, Optional, Tuple
17
 
@@ -74,7 +75,20 @@ class RMSNorm(nn.Module):
74
  self.eps = eps
75
 
76
  def __call__(self, x: mx.array) -> mx.array:
77
- return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 
 
78
 
79
 
80
  # --------------------------------------------------------------------------- #
@@ -95,7 +109,8 @@ class Attention(nn.Module):
95
  super().__init__()
96
  self.layer_idx = layer_idx
97
  self.is_sliding = args.layer_types[layer_idx] == "sliding_attention"
98
- self.attention_k_eq_v = args.attention_k_eq_v
 
99
 
100
  self.n_heads = args.num_attention_heads
101
 
@@ -111,17 +126,18 @@ class Attention(nn.Module):
111
  rope_dims = int(args.global_head_dim * args.partial_rotary_factor)
112
  rope_theta = args.rope_theta_global
113
 
114
- self.scale = self.head_dim ** -0.5
115
 
116
  self.q_proj = nn.Linear(args.hidden_size, self.n_heads * self.head_dim, bias=False)
117
  self.k_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
118
- # v_proj exists for weight loading but K=V means we use k_proj output for V too
119
- if not self.attention_k_eq_v:
120
  self.v_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
121
  self.o_proj = nn.Linear(self.n_heads * self.head_dim, args.hidden_size, bias=False)
122
 
123
  self.q_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
124
  self.k_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
 
125
 
126
  self.rope = nn.RoPE(rope_dims, traditional=False, base=rope_theta)
127
 
@@ -135,15 +151,17 @@ class Attention(nn.Module):
135
 
136
  queries = self.q_proj(x)
137
  keys = self.k_proj(x)
138
- # K=V: use key projection output as values too
139
- values = keys if self.attention_k_eq_v else self.v_proj(x)
140
 
141
  queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
142
  keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
143
  values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
144
 
 
145
  queries = self.q_norm(queries)
146
  keys = self.k_norm(keys)
 
147
 
148
  if cache is not None:
149
  queries = self.rope(queries, offset=cache.offset)
@@ -198,8 +216,8 @@ class Router(nn.Module):
198
  self.top_k = args.top_k_experts
199
 
200
  self.proj = nn.Linear(args.hidden_size, args.num_experts, bias=False)
201
- # Learnable scalar scale
202
- self.scale = mx.ones((1,))
203
  # Per-expert scales
204
  self.per_expert_scale = mx.ones((args.num_experts,))
205
 
@@ -352,6 +370,7 @@ class DecoderLayer(nn.Module):
352
  # Attention
353
  self.self_attn = Attention(args, layer_idx)
354
  self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
355
 
356
  # Dense MLP
357
  self.mlp = DenseMLP(args)
@@ -377,37 +396,44 @@ class DecoderLayer(nn.Module):
377
  mask: Optional[mx.array] = None,
378
  cache: Optional[Any] = None,
379
  ) -> mx.array:
380
- # 1. Attention
381
  residual = x
382
  h = self.input_layernorm(x)
383
  h = self.self_attn(h, mask, cache)
 
384
  h = residual + h
385
 
386
- # 2. Dense MLP
387
  residual = h
388
- dense_in = self.pre_feedforward_layernorm(h)
389
- dense_out = self.mlp(dense_in)
390
- h = self.post_feedforward_layernorm(dense_out)
391
 
392
- # 3. MoE (parallel to dense, sharing the same residual)
393
  if self.enable_moe_block:
394
- # MoE input: norm applied to (residual + raw dense_out), not the post-normed version
395
- moe_input = self.pre_feedforward_layernorm_2(residual + dense_out)
396
-
397
- # Route
398
- top_k_weights, top_k_indices = self.router(moe_input)
399
-
400
- # Expert forward
401
- expert_out = self.experts(moe_input, top_k_indices)
402
- # Weighted sum over top-k experts: [B, L, top_k, D] * [B, L, top_k, 1] -> [B, L, D]
403
- weighted_out = (expert_out * mx.expand_dims(top_k_weights, -1)).sum(axis=-2)
404
- moe_out = self.post_feedforward_layernorm_2(weighted_out)
405
-
406
- # Combine: dense (post-normed again) + moe
407
- h = self.post_feedforward_layernorm_1(h) + moe_out
408
-
409
- # 4. Residual + layer scalar
 
 
410
  h = residual + h
 
411
  h = h * self.layer_scalar
412
 
413
  return h
@@ -538,9 +564,16 @@ class Model(nn.Module):
538
  if new_key.startswith("model.language_model."):
539
  new_key = "model." + new_key[len("model.language_model."):]
540
 
541
- # Drop v_proj when K=V (weights are identical to k_proj)
 
542
  if self.args.attention_k_eq_v and "v_proj" in new_key:
543
- continue
 
 
544
 
545
  # Drop lm_head when tied
546
  if self.args.tie_word_embeddings and new_key == "lm_head.weight":
 
12
  Reference: HuggingFace transformers Gemma4TextModel
13
  """
14
 
15
+ import re
16
  from dataclasses import dataclass, field
17
  from typing import Any, Dict, List, Optional, Tuple
18
 
 
75
  self.eps = eps
76
 
77
  def __call__(self, x: mx.array) -> mx.array:
78
+ # Gemma 4 GGUF norm_shift=0.0: weight is the final multiplier (no +1 offset)
79
+ # Confirmed by mlx-vlm RMSNormZeroShift implementation
80
+ return mx.fast.rms_norm(x, self.weight, self.eps)
81
+
82
+
83
+ class BareRMSNorm(nn.Module):
84
+ """RMSNorm without learnable scale (used for v_norm)."""
85
+ def __init__(self, dims: int, eps: float = 1e-6):
86
+ super().__init__()
87
+ self.eps = eps
88
+ self._dims = dims
89
+
90
+ def __call__(self, x: mx.array) -> mx.array:
91
+ return mx.fast.rms_norm(x, mx.ones((self._dims,)), self.eps)
92
 
93
 
94
  # --------------------------------------------------------------------------- #
 
109
  super().__init__()
110
  self.layer_idx = layer_idx
111
  self.is_sliding = args.layer_types[layer_idx] == "sliding_attention"
112
+ # K=V sharing only applies to full (non-sliding) attention layers
113
+ self.use_kv_sharing = args.attention_k_eq_v and not self.is_sliding
114
 
115
  self.n_heads = args.num_attention_heads
116
 
 
126
  rope_dims = int(args.global_head_dim * args.partial_rotary_factor)
127
  rope_theta = args.rope_theta_global
128
 
129
+ self.scale = 1.0 # HF Gemma4 uses scaling=1.0; q_norm/k_norm handle magnitude
130
 
131
  self.q_proj = nn.Linear(args.hidden_size, self.n_heads * self.head_dim, bias=False)
132
  self.k_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
133
+ # v_proj needed for sliding layers; dropped for full layers with K=V sharing
134
+ if not self.use_kv_sharing:
135
  self.v_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
136
  self.o_proj = nn.Linear(self.n_heads * self.head_dim, args.hidden_size, bias=False)
137
 
138
  self.q_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
139
  self.k_norm = RMSNorm(self.head_dim, eps=args.rms_norm_eps)
140
+ self.v_norm = BareRMSNorm(self.head_dim, eps=args.rms_norm_eps)
141
 
142
  self.rope = nn.RoPE(rope_dims, traditional=False, base=rope_theta)
143
 
 
151
 
152
  queries = self.q_proj(x)
153
  keys = self.k_proj(x)
154
+ # K=V sharing: only for full attention layers
155
+ values = keys if self.use_kv_sharing else self.v_proj(x)
156
 
157
  queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
158
  keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
159
  values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
160
 
161
+ # Norms: q_norm and k_norm BEFORE RoPE, v_norm on values
162
  queries = self.q_norm(queries)
163
  keys = self.k_norm(keys)
164
+ values = self.v_norm(values)
165
 
166
  if cache is not None:
167
  queries = self.rope(queries, offset=cache.offset)
 
216
  self.top_k = args.top_k_experts
217
 
218
  self.proj = nn.Linear(args.hidden_size, args.num_experts, bias=False)
219
+ # Learnable per-dimension scale (shape matches hidden_size)
220
+ self.scale = mx.ones((args.hidden_size,))
221
  # Per-expert scales
222
  self.per_expert_scale = mx.ones((args.num_experts,))
223
 
 
370
  # Attention
371
  self.self_attn = Attention(args, layer_idx)
372
  self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
373
+ self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
374
 
375
  # Dense MLP
376
  self.mlp = DenseMLP(args)
 
396
  mask: Optional[mx.array] = None,
397
  cache: Optional[Any] = None,
398
  ) -> mx.array:
399
+ # 1. Attention with pre/post norms and residual
400
  residual = x
401
  h = self.input_layernorm(x)
402
  h = self.self_attn(h, mask, cache)
403
+ h = self.post_attention_layernorm(h)
404
  h = residual + h
405
 
406
+ # 2. Feed-forward (dense MLP, optionally combined with MoE)
407
  residual = h
408
+ h = self.pre_feedforward_layernorm(h)
409
+ h = self.mlp(h)
 
410
 
 
411
  if self.enable_moe_block:
412
+ # Dense MLP output -> post_feedforward_layernorm_1
413
+ h_dense = self.post_feedforward_layernorm_1(h)
414
+
415
+ # MoE: router takes residual (pre-MLP hidden states), NOT normed
416
+ B, L, D = residual.shape
417
+ residual_flat = residual.reshape(-1, D)
418
+ top_k_weights, top_k_indices = self.router(residual_flat)
419
+
420
+ # Expert input: pre_feedforward_layernorm_2 applied to residual
421
+ moe_input = self.pre_feedforward_layernorm_2(residual_flat)
422
+ expert_out = self.experts(
423
+ moe_input.reshape(B, L, D), top_k_indices.reshape(B, L, -1)
424
+ )
425
+ # Weighted sum over top-k experts
426
+ top_k_weights_r = top_k_weights.reshape(B, L, -1)
427
+ weighted_out = (expert_out * mx.expand_dims(top_k_weights_r, -1)).sum(axis=-2)
428
+ h_moe = self.post_feedforward_layernorm_2(weighted_out)
429
+
430
+ # Combine dense + MoE
431
+ h = h_dense + h_moe
432
+
433
+ # Final post-feedforward norm + residual
434
+ h = self.post_feedforward_layernorm(h)
435
  h = residual + h
436
+
437
  h = h * self.layer_scalar
438
 
439
  return h
 
564
  if new_key.startswith("model.language_model."):
565
  new_key = "model." + new_key[len("model.language_model."):]
566
 
567
+ # Drop v_proj only for full attention layers with K=V sharing
568
+ # Sliding layers still need v_proj even when attention_k_eq_v is true
569
  if self.args.attention_k_eq_v and "v_proj" in new_key:
570
+ # Extract layer index to check if it's a full attention layer
571
+ layer_match = re.search(r'layers\.(\d+)\.', new_key)
572
+ if layer_match:
573
+ layer_idx = int(layer_match.group(1))
574
+ if self.args.layer_types[layer_idx] != "sliding_attention":
575
+ continue # Drop v_proj for full attention layers
576
+ # If no layer index found, keep the weight
577
 
578
  # Drop lm_head when tied
579
  if self.args.tie_word_embeddings and new_key == "lm_head.weight":
src/mlx_expert_sniper/preprocess_gemma4.py CHANGED
@@ -2,231 +2,225 @@
2
  """
3
  Preprocess Gemma 4-26B-A4B into sniper streaming format.
4
 
5
- Expert tensor naming (different from Qwen):
6
- Qwen: layers.N.mlp.switch_mlp.{gate,up,down}_proj.{weight,scales,biases}
7
- Gemma4: layers.N.experts.gate_up_proj (fused, [128, 1408, 2816] bf16)
8
- layers.N.experts.down_proj ([128, 2816, 704] bf16)
9
- layers.N.router.{proj.weight, scale, per_expert_scale}
10
-
11
- The experts are stored as bf16 (not quantized at source).
12
- We can optionally quantize during preprocessing for smaller disk footprint.
 
 
 
13
  """
14
- import os, json, gc, time, re, glob
 
 
15
  import numpy as np
16
  import mlx.core as mx
17
 
18
  PAGE_SIZE = 16384
19
 
20
- # Gemma 4 expert tensors (per layer, shape includes expert dim)
21
- EXPERT_TENSORS = [
22
- "experts.gate_up_proj", # [num_experts, 2*moe_inter, hidden]
23
- "experts.down_proj", # [num_experts, hidden, moe_inter]
24
- ]
25
-
26
 
27
- def preprocess_gemma4(input_dir, output_dir, quantize_experts=False):
28
- """Split Gemma 4 into pinned + streaming experts.
29
 
30
  Args:
31
- input_dir: HuggingFace download directory
32
- output_dir: sniper streaming format output
33
- quantize_experts: if True, quantize experts to 4-bit (saves disk)
34
  """
35
- os.makedirs(output_dir, exist_ok=True)
36
- os.makedirs(os.path.join(output_dir, "bin"), exist_ok=True)
 
 
37
 
38
- config = json.load(open(os.path.join(input_dir, "config.json")))
39
  tc = config.get("text_config", config)
40
  NUM_LAYERS = tc["num_hidden_layers"]
41
  NUM_EXPERTS = tc["num_experts"]
42
- hidden_size = tc["hidden_size"]
43
- moe_inter = tc["moe_intermediate_size"]
44
 
45
- shard_files = sorted(glob.glob(os.path.join(input_dir, "model-*.safetensors")))
46
- print(f"Gemma 4: {NUM_LAYERS} layers, {NUM_EXPERTS} experts, {len(shard_files)} shards")
47
- print(f" Hidden: {hidden_size}, MoE inter: {moe_inter}")
48
- print(f" Expert storage: bf16 (not quantized)")
 
49
 
50
- pinned = {}
51
- expert_keys = {} # layer -> {name: tensor}
52
- expert_layers_done = set()
53
  t0 = time.time()
54
- total_expert_bytes = 0
 
 
55
 
56
- for si, sf in enumerate(shard_files):
57
- shard_name = os.path.basename(sf)
58
- print(f"\n Shard {si+1}/{len(shard_files)}: {shard_name}")
59
- w = mx.load(sf)
60
-
61
- for k, v in w.items():
62
- # Strip language_model. prefix
63
- clean_k = k.replace("model.language_model.", "")
64
-
65
- # Check if this is an expert tensor
66
- is_expert = False
67
- for et in EXPERT_TENSORS:
68
- if et in clean_k:
69
- is_expert = True
70
- break
71
-
72
- if is_expert:
73
- m = re.search(r"layers\.(\d+)\.", clean_k)
74
- if m:
75
- layer_idx = int(m.group(1))
76
- # Local name: just the part after "layers.N."
77
- local_name = clean_k.split(f"layers.{layer_idx}.")[-1]
78
- if layer_idx not in expert_keys:
79
- expert_keys[layer_idx] = {}
80
- expert_keys[layer_idx][local_name] = v
81
- else:
82
- # Skip vision tower for pinned
83
- if "vision_tower" not in k and "embed_vision" not in k:
84
- pinned[clean_k] = v
85
-
86
- # Write complete expert layers
87
- for layer_idx in sorted(expert_keys.keys()):
88
- if layer_idx in expert_layers_done:
89
- continue
90
- if len(expert_keys[layer_idx]) < len(EXPERT_TENSORS):
91
- continue
92
-
93
- lt = expert_keys[layer_idx]
94
- _write_expert_layer(output_dir, layer_idx, lt, NUM_EXPERTS, t0)
95
- total_expert_bytes += os.path.getsize(
96
- os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin"))
97
- expert_layers_done.add(layer_idx)
98
- del expert_keys[layer_idx]
99
-
100
- del w; gc.collect()
101
- os.remove(sf)
102
- print(f" Deleted {shard_name}")
103
-
104
- # Handle any remaining cross-shard layers
105
- for layer_idx in sorted(expert_keys.keys()):
106
- if layer_idx in expert_layers_done:
107
- continue
108
- lt = expert_keys[layer_idx]
109
- if len(lt) < len(EXPERT_TENSORS):
110
- print(f" WARNING: Layer {layer_idx} incomplete ({len(lt)} tensors)")
111
- continue
112
- _write_expert_layer(output_dir, layer_idx, lt, NUM_EXPERTS, t0)
113
- total_expert_bytes += os.path.getsize(
114
- os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin"))
115
 
116
  # Save pinned
 
 
117
  pinned_bytes = sum(v.nbytes for v in pinned.values())
118
- mx.save_safetensors(os.path.join(output_dir, "pinned.safetensors"), pinned)
119
- print(f"\n Saved pinned.safetensors: {pinned_bytes/1e9:.2f} GB ({len(pinned)} keys)")
120
- del pinned; gc.collect()
121
-
122
- # Symlinks
123
- for i in range(NUM_LAYERS):
124
- src = f"moe_layer_{i:02d}.bin"
125
- dst = os.path.join(output_dir, "bin", f"layer_{i:02d}.bin")
126
- if os.path.exists(os.path.join(output_dir, "bin", src)) and not os.path.exists(dst):
127
- os.symlink(src, dst)
128
-
129
- # Write config
130
- stream_config = {
131
- "model_type": "gemma4",
132
- "hidden_size": hidden_size,
133
- "num_hidden_layers": NUM_LAYERS,
134
- "num_experts": NUM_EXPERTS,
135
- "top_k_experts": tc["top_k_experts"],
136
- "moe_intermediate_size": moe_inter,
137
- "intermediate_size": tc["intermediate_size"],
138
- "num_attention_heads": tc["num_attention_heads"],
139
- "num_key_value_heads": tc["num_key_value_heads"],
140
- "num_global_key_value_heads": tc.get("num_global_key_value_heads", 2),
141
- "global_head_dim": tc.get("global_head_dim", 512),
142
- "head_dim": tc.get("head_dim", 256),
143
- "vocab_size": tc["vocab_size"],
144
- "rms_norm_eps": tc.get("rms_norm_eps", 1e-6),
145
- "sliding_window": tc.get("sliding_window", 1024),
146
- "layer_types": tc.get("layer_types", []),
147
- "hidden_activation": tc.get("hidden_activation", "gelu_pytorch_tanh"),
148
- "final_logit_softcapping": tc.get("final_logit_softcapping", 30.0),
149
- "enable_moe_block": tc.get("enable_moe_block", True),
150
- "attention_k_eq_v": tc.get("attention_k_eq_v", True),
151
- "rope_parameters": tc.get("rope_parameters"),
152
- "max_position_embeddings": tc.get("max_position_embeddings", 262144),
153
- "tie_word_embeddings": config.get("tie_word_embeddings", True),
154
- "streaming": {"pinned_file": "pinned.safetensors", "expert_dir": "bin"},
155
- }
156
- with open(os.path.join(output_dir, "config.json"), "w") as f:
157
  json.dump(stream_config, f, indent=2)
158
 
159
- # Copy tokenizer
160
  import shutil
161
- for tf in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
162
- "added_tokens.json", "tokenizer.model"]:
163
- src = os.path.join(input_dir, tf)
164
  if os.path.exists(src):
165
- shutil.copy(src, os.path.join(output_dir, tf))
166
-
167
- # Verify
168
- layer_count = sum(1 for f in os.listdir(os.path.join(output_dir, "bin"))
169
- if f.startswith("moe_layer_") and f.endswith(".bin"))
170
- elapsed = time.time() - t0
171
- print(f"\n Done in {elapsed:.0f}s!")
172
- print(f" Pinned: {pinned_bytes/1e9:.2f} GB, Experts: {total_expert_bytes/1e9:.2f} GB")
173
- print(f" Layers: {layer_count}/{NUM_LAYERS}")
174
-
175
-
176
- def _write_expert_layer(output_dir, layer_idx, layer_tensors, num_experts, t0):
177
- """Write one layer's experts to a binary file."""
178
- # Build tensor info and calculate sizes
179
- tensor_order = ["experts.gate_up_proj", "experts.down_proj"]
180
- tensor_info = {}
181
- offset = 0
182
- for tname in tensor_order:
183
- t = layer_tensors[tname]
184
- per_expert_shape = list(t.shape[1:]) # remove expert dim
185
- per_expert_bytes = int(np.prod(per_expert_shape)) * t.dtype.size
186
- tensor_info[tname] = {
187
- "inner_offset": offset,
188
- "nbytes": per_expert_bytes,
189
- "shape_per_expert": per_expert_shape,
190
- "dtype": str(t.dtype),
191
- }
192
- offset += per_expert_bytes
193
-
194
- expert_block_size = ((offset + PAGE_SIZE - 1) // PAGE_SIZE) * PAGE_SIZE
195
-
196
- header = {
197
- "layer_idx": layer_idx,
198
- "num_experts": num_experts,
199
- "layout": {
200
- "expert_block_size": expert_block_size,
201
- "data_start": PAGE_SIZE,
202
- "tensors": tensor_info,
203
- }
204
- }
205
- header_json = json.dumps(header).encode()
206
- header_padded = header_json + b"\x00" * (PAGE_SIZE - len(header_json))
207
-
208
- layer_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
209
- with open(layer_path, "wb") as f:
210
- f.write(header_padded)
211
- for eid in range(num_experts):
212
- expert_data = bytearray()
213
- for tname in tensor_order:
214
- expert_t = layer_tensors[tname][eid]
215
- mx.eval(expert_t)
216
- if expert_t.dtype == mx.bfloat16:
217
- raw = np.array(expert_t.view(mx.uint16)).tobytes()
218
- else:
219
- raw = np.array(expert_t).tobytes()
220
- expert_data.extend(raw)
221
- pad = expert_block_size - len(expert_data)
222
- if pad > 0:
223
- expert_data.extend(b"\x00" * pad)
224
- f.write(bytes(expert_data))
225
-
226
- sym = os.path.join(output_dir, "bin", f"layer_{layer_idx:02d}.bin")
227
- if not os.path.exists(sym):
228
- os.symlink(f"moe_layer_{layer_idx:02d}.bin", sym)
229
 
230
  elapsed = time.time() - t0
231
- layer_bytes = os.path.getsize(layer_path)
232
- print(f" Layer {layer_idx:2d}: {layer_bytes/1e6:.1f} MB ({elapsed:.0f}s)")
 
 
2
  """
3
  Preprocess Gemma 4-26B-A4B into sniper streaming format.
4
 
5
+ Source: mlx-community/gemma-4-26b-a4b-it-4bit (15.6 GB)
6
+
7
+ Gemma 4 stores experts as SwitchLinear stacked tensors:
8
+ language_model.model.layers.X.experts.switch_glu.{gate,up,down}_proj.{weight,scales,biases}
9
+ Each tensor shape: (128, ...) where dim 0 is the expert index
10
+
11
+ This script unstacks them into per-expert blocks in the bin format that
12
+ expert_io.MoEExpertReader expects:
13
+ bin/layer_XX.bin (header + 128 expert blocks per layer)
14
+
15
+ Mixed quantization is preserved: experts stay 4-bit, dense MLP is 8-bit.
16
  """
17
+ import os
18
+ import json
19
+ import gc
20
+ import time
21
+ import glob
22
+
23
  import numpy as np
24
  import mlx.core as mx
25
 
26
  PAGE_SIZE = 16384
27
 
 
28
 
29
+ def preprocess_gemma4(input_dir, output_dir):
30
+ """Split Gemma 4 into pinned safetensors + streaming expert bins.
31
 
32
  Args:
33
+ input_dir: HuggingFace download directory (mlx-community 4-bit)
34
+ output_dir: sniper streaming format output directory
 
35
  """
36
+ INPUT_DIR = os.path.expanduser(input_dir)
37
+ OUTPUT_DIR = os.path.expanduser(output_dir)
38
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
39
+ os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)
40
 
41
+ config = json.load(open(os.path.join(INPUT_DIR, "config.json")))
42
  tc = config.get("text_config", config)
43
  NUM_LAYERS = tc["num_hidden_layers"]
44
  NUM_EXPERTS = tc["num_experts"]
 
 
45
 
46
+ print(f"Gemma 4 Preprocess (SwitchLinear unstack)")
47
+ print(f" Input: {INPUT_DIR}")
48
+ print(f" Output: {OUTPUT_DIR}")
49
+ print(f" Layers: {NUM_LAYERS}, Experts: {NUM_EXPERTS}")
50
+ print()
51
 
52
+ # Load all weights
53
+ print("Loading safetensors...")
 
54
  t0 = time.time()
55
+ all_weights = {}
56
+ for sf in sorted(glob.glob(os.path.join(INPUT_DIR, "model-*.safetensors"))):
57
+ print(f" {os.path.basename(sf)}")
58
+ all_weights.update(mx.load(sf))
59
+
60
+ # Identify expert vs pinned keys
61
+ pinned = {}
62
+ expert_tensors = {} # layer_idx -> {tensor_name: stacked_tensor}
63
+
64
+ EXPERT_PREFIX_TMPL = "language_model.model.layers.{}.experts.switch_glu.{}.{}"
65
+ PROJ_NAMES = ["gate_proj", "up_proj", "down_proj"]
66
+ COMP_NAMES = ["weight", "scales", "biases"]
67
+
68
+ # Build set of expert key paths for fast lookup
69
+ expert_key_set = set()
70
+ for li in range(NUM_LAYERS):
71
+ for proj in PROJ_NAMES:
72
+ for comp in COMP_NAMES:
73
+ expert_key_set.add(EXPERT_PREFIX_TMPL.format(li, proj, comp))
74
+
75
+ for key, val in all_weights.items():
76
+ if key in expert_key_set:
77
+ # Extract layer_idx from key
78
+ parts = key.split(".")
79
+ li = int(parts[3]) # language_model.model.layers.X.experts...
80
+ proj = parts[6] # gate_proj/up_proj/down_proj
81
+ comp = parts[7] # weight/scales/biases
82
+
83
+ if li not in expert_tensors:
84
+ expert_tensors[li] = {}
85
+ tname = f"switch_mlp.{proj}.{comp}"
86
+ expert_tensors[li][tname] = val
87
+ else:
88
+ pinned[key] = val
89
+
90
+ print(f"\n Expert layers: {len(expert_tensors)}")
91
+ print(f" Pinned keys: {len(pinned)}")
92
+
93
+ # Determine per-expert block layout from first layer
94
+ first_layer = expert_tensors[0]
95
+ tensor_layout = {}
96
+ inner_offset = 0
97
+
98
+ for tname in sorted(first_layer.keys()):
99
+ arr = first_layer[tname]
100
+ per_expert_shape = list(arr.shape[1:])
101
+
102
+ if arr.dtype == mx.uint32:
103
+ dtype_str = "uint32"
104
+ elem_size = 4
105
+ elif arr.dtype == mx.bfloat16:
106
+ dtype_str = "bfloat16"
107
+ elem_size = 2
108
+ elif arr.dtype == mx.float16:
109
+ dtype_str = "float16"
110
+ elem_size = 2
111
+ elif arr.dtype == mx.float32:
112
+ dtype_str = "float32"
113
+ elem_size = 4
114
+ else:
115
+ dtype_str = str(arr.dtype).replace("mlx.core.", "")
116
+ elem_size = 2
117
+
118
+ nbytes = elem_size
119
+ for d in per_expert_shape:
120
+ nbytes *= d
121
+
122
+ tensor_layout[tname] = {
123
+ "inner_offset": inner_offset,
124
+ "nbytes": nbytes,
125
+ "shape_per_expert": per_expert_shape,
126
+ "dtype": dtype_str,
127
+ }
128
+ inner_offset += nbytes
129
+
130
+ expert_block_size = inner_offset
131
+ data_start = PAGE_SIZE
132
 
133
+ print(f" Expert block: {expert_block_size} bytes ({expert_block_size/1024:.1f} KB)")
134
+ print()
135
+
136
+ # Write layer files
137
+ total_expert_bytes = 0
138
+ for layer_idx in range(NUM_LAYERS):
139
+ lt = time.time()
140
+ layer_data = expert_tensors[layer_idx]
141
+
142
+ header = {
143
+ "format": "expert_sniper_v1",
144
+ "model": "gemma4-26b-a4b",
145
+ "layer_idx": layer_idx,
146
+ "num_experts": NUM_EXPERTS,
147
+ "layout": {
148
+ "expert_block_size": expert_block_size,
149
+ "data_start": data_start,
150
+ "tensors": tensor_layout,
151
+ }
152
+ }
153
+ header_bytes = json.dumps(header, indent=2).encode("utf-8")
154
+ assert len(header_bytes) < PAGE_SIZE, f"Header too large: {len(header_bytes)}"
155
+ header_padded = header_bytes + b"\x00" * (PAGE_SIZE - len(header_bytes))
156
+
157
+ layer_path = os.path.join(OUTPUT_DIR, "bin", f"layer_{layer_idx:02d}.bin")
158
+ with open(layer_path, "wb") as f:
159
+ f.write(header_padded)
160
+
161
+ for eid in range(NUM_EXPERTS):
162
+ expert_data = bytearray()
163
+ for tname in sorted(tensor_layout.keys()):
164
+ stacked = layer_data[tname]
165
+ single = stacked[eid]
166
+ mx.eval(single)
167
+
168
+ if single.dtype == mx.uint32:
169
+ np_arr = np.array(single).view(np.uint32)
170
+ elif single.dtype == mx.bfloat16:
171
+ # Preserve bfloat16 bytes via uint16 view
172
+ np_arr = np.array(single.view(mx.uint16))
173
+ elif single.dtype == mx.float32:
174
+ np_arr = np.array(single).view(np.float32)
175
+ elif single.dtype == mx.float16:
176
+ np_arr = np.array(single.view(mx.uint16))
177
+ else:
178
+ np_arr = np.array(single)
179
+ expert_data.extend(np_arr.tobytes())
180
+
181
+ if len(expert_data) < expert_block_size:
182
+ expert_data.extend(b"\x00" * (expert_block_size - len(expert_data)))
183
+ f.write(bytes(expert_data[:expert_block_size]))
184
+
185
+ file_size = os.path.getsize(layer_path)
186
+ total_expert_bytes += file_size
187
+ elapsed = time.time() - lt
188
+ print(f" Layer {layer_idx:2d}/{NUM_LAYERS}: {file_size/1e6:.1f} MB ({elapsed:.0f}s)")
189
+
190
+ del expert_tensors[layer_idx]
191
+ gc.collect()
192
 
193
  # Save pinned
194
+ pinned_path = os.path.join(OUTPUT_DIR, "pinned.safetensors")
195
+ mx.save_safetensors(pinned_path, pinned)
196
  pinned_bytes = sum(v.nbytes for v in pinned.values())
197
+ print(f"\nSaved pinned.safetensors: {pinned_bytes/1e9:.2f} GB ({len(pinned)} keys)")
198
+
199
+ # Streaming config
200
+ stream_config = dict(tc)
201
+ stream_config["quantization"] = config.get("quantization", {"bits": 4, "group_size": 64})
202
+ stream_config["streaming"] = {"pinned_file": "pinned.safetensors", "expert_dir": "bin"}
203
+ with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
 
 
204
  json.dump(stream_config, f, indent=2)
205
 
206
+ # Copy tokenizer files
207
  import shutil
208
+ for tf in ["tokenizer.json", "tokenizer_config.json", "chat_template.jinja",
209
+ "generation_config.json", "processor_config.json"]:
210
+ src = os.path.join(INPUT_DIR, tf)
211
  if os.path.exists(src):
212
+ shutil.copy(src, os.path.join(OUTPUT_DIR, tf))
 
 
213
 
214
  elapsed = time.time() - t0
215
+ print(f"\nDone in {elapsed:.0f}s!")
216
+ print(f"Pinned: {pinned_bytes/1e9:.2f} GB, Experts: {total_expert_bytes/1e9:.2f} GB")
217
+ return True
218
+
219
+
220
+ if __name__ == "__main__":
221
+ import argparse
222
+ parser = argparse.ArgumentParser()
223
+ parser.add_argument("--input", "-i", required=True)
224
+ parser.add_argument("--output", "-o", required=True)
225
+ args = parser.parse_args()
226
+ preprocess_gemma4(args.input, args.output)
src/mlx_expert_sniper/sniper.py CHANGED
@@ -45,8 +45,29 @@ class SniperEngine:
45
  self._loaded = False
46
 
47
  @classmethod
48
- def from_dir(cls, sniper_dir: str, **overrides) -> "SniperEngine":
49
- """Create engine from a sniper directory."""
 
 
50
  config = SniperConfig.from_dir(sniper_dir, **overrides)
51
  engine = cls(config)
52
  engine.load()
 
45
  self._loaded = False
46
 
47
  @classmethod
48
+ def from_dir(cls, sniper_dir: str, **overrides):
49
+ """Create engine from a sniper directory.
50
+
51
+ Auto-detects model type and dispatches to the right engine:
52
+ - Qwen 3.x MoE → SniperEngine (this class)
53
+ - Gemma 4 → MoESniperEngineGemma4 (different architecture)
54
+ """
55
+ # Peek at config to detect model type
56
+ with open(os.path.join(sniper_dir, "config.json")) as f:
57
+ cfg_data = json.load(f)
58
+ model_type = cfg_data.get("model_type", "")
59
+
60
+ if model_type.startswith("gemma4"):
61
+ from .engine_gemma4 import MoESniperEngineGemma4
62
+ cache_size = overrides.get("max_cached_experts", 4000)
63
+ engine = MoESniperEngineGemma4(
64
+ model_dir=sniper_dir,
65
+ cache_size=cache_size,
66
+ )
67
+ engine.load()
68
+ return engine
69
+
70
+ # Default: Qwen path
71
  config = SniperConfig.from_dir(sniper_dir, **overrides)
72
  engine = cls(config)
73
  engine.load()