waltgrace committed on
Commit
d14a3c2
·
verified ·
1 Parent(s): 2d8a9bb

v0.2.0: Add Qwen3.5-35B-A3B support (5.78 tok/s, 19.5 GB on 16 GB RAM)

Browse files
README.md CHANGED
@@ -1,17 +1,16 @@
1
  # CLI Agent — `mlx-expert-sniper`
2
 
3
  Pip-installable CLI that wraps the Expert Sniper research into a production tool.
 
4
 
5
  ## Verified Results (M4 Mac Mini, 16 GB)
6
 
7
- | Metric | Value |
8
- |--------|-------|
9
- | Model | Qwen3-30B-A3B, 17.2 GB at 4-bit |
10
- | Standard mlx_lm | OOM |
11
- | **Sniper steady-state** | **4.22–4.68 tok/s** |
12
- | Cache hit rate | 85% (cold start) 88.5% (warm) |
13
- | RAM used | 0.87 GB pinned |
14
- | Output | Coherent code, math, essays |
15
 
16
  ## Install
17
 
 
1
  # CLI Agent — `mlx-expert-sniper`
2
 
3
  Pip-installable CLI that wraps the Expert Sniper research into a production tool.
4
+ Run MoE models larger than your RAM on Apple Silicon.
5
 
6
  ## Verified Results (M4 Mac Mini, 16 GB)
7
 
8
+ | Model | Size | Standard mlx_lm | Sniper tok/s | RAM pinned | Cache hit |
9
+ |-------|------|-----------------|--------------|------------|-----------|
10
+ | Qwen3-30B-A3B | 17.2 GB | OOM | **4.33 tok/s** | 0.87 GB | 88.5% |
11
+ | Qwen3.5-35B-A3B | 19.5 GB | OOM | **5.78 tok/s** | 1.38 GB | 75% |
12
+
13
+ Both models exceed 16 GB RAM. Both produce coherent multi-paragraph output.
 
 
14
 
15
  ## Install
16
 
src/mlx_expert_sniper/config.py CHANGED
@@ -52,19 +52,21 @@ class SniperConfig:
52
  with open(config_path) as f:
53
  model_config = json.load(f)
54
 
55
- quant = model_config.get("quantization", {})
 
 
56
 
57
  cfg = cls(
58
  sniper_dir=sniper_dir,
59
  bits=quant.get("bits", 4),
60
  group_size=quant.get("group_size", 64),
61
- num_hidden_layers=model_config.get("num_hidden_layers", 0),
62
- num_experts=model_config.get("num_experts", 0),
63
- num_experts_per_tok=model_config.get("num_experts_per_tok", 0),
64
- hidden_size=model_config.get("hidden_size", 0),
65
- moe_intermediate_size=model_config.get("moe_intermediate_size", 0),
66
- vocab_size=model_config.get("vocab_size", 0),
67
- norm_topk_prob=model_config.get("norm_topk_prob", True),
68
  model_type=model_config.get("model_type", ""),
69
  tokenizer_name=model_config.get("_name_or_path", ""),
70
  )
 
52
  with open(config_path) as f:
53
  model_config = json.load(f)
54
 
55
+ # Handle nested configs (qwen3_5_moe has text_config)
56
+ quant = model_config.get("quantization", model_config.get("quantization_config", {}))
57
+ text_cfg = model_config.get("text_config", model_config)
58
 
59
  cfg = cls(
60
  sniper_dir=sniper_dir,
61
  bits=quant.get("bits", 4),
62
  group_size=quant.get("group_size", 64),
63
+ num_hidden_layers=text_cfg.get("num_hidden_layers", 0),
64
+ num_experts=text_cfg.get("num_experts", 0),
65
+ num_experts_per_tok=text_cfg.get("num_experts_per_tok", 0),
66
+ hidden_size=text_cfg.get("hidden_size", 0),
67
+ moe_intermediate_size=text_cfg.get("moe_intermediate_size", 0),
68
+ vocab_size=text_cfg.get("vocab_size", 0),
69
+ norm_topk_prob=text_cfg.get("norm_topk_prob", True),
70
  model_type=model_config.get("model_type", ""),
71
  tokenizer_name=model_config.get("_name_or_path", ""),
72
  )
src/mlx_expert_sniper/preprocess.py CHANGED
@@ -163,6 +163,16 @@ def preprocess_model(model_dir: str, output_dir: str, verbose: bool = True):
163
  layer = int(key.split(".layers.")[1].split(".")[0])
164
  short = key.split(".switch_mlp.")[1]
165
  layer_experts.setdefault(layer, {})[short] = tensor
 
 
 
 
 
 
 
 
 
 
166
  else:
167
  pinned[key] = tensor
168
 
 
163
  layer = int(key.split(".layers.")[1].split(".")[0])
164
  short = key.split(".switch_mlp.")[1]
165
  layer_experts.setdefault(layer, {})[short] = tensor
166
+ elif "experts.gate_up_proj" in key:
167
+ # qwen3_5_moe: fused gate+up proj — split into separate tensors
168
+ layer = int(key.split(".layers.")[1].split(".")[0])
169
+ gate_up = tensor
170
+ mid = gate_up.shape[-2] // 2
171
+ layer_experts.setdefault(layer, {})["gate_proj.weight"] = gate_up[..., :mid, :]
172
+ layer_experts.setdefault(layer, {})["up_proj.weight"] = gate_up[..., mid:, :]
173
+ elif "experts.down_proj" in key:
174
+ layer = int(key.split(".layers.")[1].split(".")[0])
175
+ layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
176
  else:
177
  pinned[key] = tensor
178
 
src/mlx_expert_sniper/sniper.py CHANGED
@@ -108,8 +108,11 @@ class SniperEngine:
108
  if mt in ("qwen3_moe", "qwen2_moe"):
109
  from mlx_lm.models.qwen3_moe import Model, ModelArgs
110
  return Model, ModelArgs
 
 
 
111
  raise ValueError(f"Unsupported model_type: {mt}. "
112
- f"Currently supported: qwen3_moe, qwen2_moe")
113
 
114
  def _quantize_model(self, model_config: dict, quant: dict):
115
  """Apply quantization matching the stored format."""
@@ -129,35 +132,87 @@ class SniperEngine:
129
  class_predicate=class_predicate,
130
  )
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def reset_kv_cache(self):
133
  """Reset the KV cache for a new conversation."""
134
- from mlx_lm.models.cache import make_prompt_cache
135
- self.kv_cache = make_prompt_cache(self.model)
 
 
 
136
 
137
  def forward_token(self, input_ids: mx.array) -> mx.array:
138
  """Run one forward pass with expert sniping.
139
 
140
- This is the proven forward pass: attention (pinned) router →
141
- mx.eval(indices) → cache/pread experts gather_qmm → combine.
142
  """
143
  from mlx_lm.models.base import create_attention_mask
144
 
145
  cfg = self.config
146
  bits = cfg.bits
147
  group_size = cfg.group_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- h = self.model.model.embed_tokens(input_ids)
150
- mask = create_attention_mask(h, self.kv_cache[0])
151
-
152
- for i, layer in enumerate(self.model.model.layers):
153
  # ── Attention (pinned weights, always in RAM) ──
154
  normed = layer.input_layernorm(h)
155
- attn_out = layer.self_attn(normed, mask=mask, cache=self.kv_cache[i])
 
 
 
 
 
 
 
156
  h = h + attn_out
157
  mx.eval(h) # Must eval before router (data-dependent)
158
 
159
- # ── Router: compute expert scores ──
160
  normed = layer.post_attention_layernorm(h)
 
 
 
 
 
 
 
 
161
  gates = layer.mlp.gate(normed)
162
  gates = mx.softmax(gates, axis=-1, precise=True)
163
  k = layer.mlp.top_k
@@ -226,11 +281,19 @@ class SniperEngine:
226
  do = do.squeeze(-2)
227
 
228
  # Weighted sum of expert outputs
229
- h = h + (do * scores[..., None]).sum(axis=-2)
230
  del gw, gs, gb, uw, us, ub, dw, ds, db
231
 
232
- h = self.model.model.norm(h)
233
- return self.model.lm_head(h)
 
 
 
 
 
 
 
 
234
 
235
  def generate(self, prompt, max_tokens=None, temperature=None,
236
  chat_messages=None):
@@ -252,15 +315,27 @@ class SniperEngine:
252
  temperature = temperature if temperature is not None else self.config.temperature
253
 
254
  # Tokenize
255
- if chat_messages:
256
- text = self.tokenizer.apply_chat_template(
257
- chat_messages, tokenize=False,
258
- add_generation_prompt=True, enable_thinking=False)
 
 
 
 
 
 
259
  else:
260
- text = self.tokenizer.apply_chat_template(
261
- [{"role": "user", "content": prompt}],
262
- tokenize=False, add_generation_prompt=True,
263
- enable_thinking=False)
 
 
 
 
 
 
264
 
265
  tokens = self.tokenizer.encode(text)
266
  input_ids = mx.array([tokens])
@@ -275,14 +350,39 @@ class SniperEngine:
275
  # Sample first token
276
  next_token = self._sample(logits[:, -1, :], temperature)
277
 
 
 
 
 
 
 
 
 
 
 
278
  # Autoregressive generation
279
  for _ in range(max_tokens):
280
- if next_token in {151643, 151645}: # EOS tokens
281
  break
282
  word = self.tokenizer.decode([next_token])
283
  if "<|im_end|>" in word or "<|endoftext|>" in word:
284
  break
285
- yield word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  input_ids = mx.array([[next_token]])
288
  logits = self.forward_token(input_ids)
 
108
  if mt in ("qwen3_moe", "qwen2_moe"):
109
  from mlx_lm.models.qwen3_moe import Model, ModelArgs
110
  return Model, ModelArgs
111
+ if mt in ("qwen3_5_moe",):
112
+ from mlx_lm.models.qwen3_5_moe import Model, ModelArgs
113
+ return Model, ModelArgs
114
  raise ValueError(f"Unsupported model_type: {mt}. "
115
+ f"Currently supported: qwen3_moe, qwen3_5_moe")
116
 
117
  def _quantize_model(self, model_config: dict, quant: dict):
118
  """Apply quantization matching the stored format."""
 
132
  class_predicate=class_predicate,
133
  )
134
 
135
+ def _get_layers_and_head(self):
136
+ """Get (layers, embed_tokens, norm, lm_head) for the model architecture."""
137
+ mt = self.config.model_type
138
+ if mt in ("qwen3_5_moe",):
139
+ lm = self.model.language_model
140
+ return lm.model.layers, lm.model.embed_tokens, lm.model.norm, lm.lm_head
141
+ return self.model.model.layers, self.model.model.embed_tokens, self.model.model.norm, self.model.lm_head
142
+
143
+ def _is_moe_layer(self, layer) -> bool:
144
+ """Check if a layer has a MoE MLP (vs dense MLP)."""
145
+ return hasattr(layer.mlp, "gate") and hasattr(layer.mlp, "switch_mlp")
146
+
147
+ def _has_shared_expert(self, layer) -> bool:
148
+ """Check if the MoE block has a shared expert."""
149
+ return hasattr(layer.mlp, "shared_expert")
150
+
151
  def reset_kv_cache(self):
152
  """Reset the KV cache for a new conversation."""
153
+ if hasattr(self.model, "make_cache"):
154
+ self.kv_cache = self.model.make_cache()
155
+ else:
156
+ from mlx_lm.models.cache import make_prompt_cache
157
+ self.kv_cache = make_prompt_cache(self.model)
158
 
159
  def forward_token(self, input_ids: mx.array) -> mx.array:
160
  """Run one forward pass with expert sniping.
161
 
162
+ Supports both qwen3_moe (standard attention) and qwen3_5_moe
163
+ (hybrid linear/full attention with shared experts).
164
  """
165
  from mlx_lm.models.base import create_attention_mask
166
 
167
  cfg = self.config
168
  bits = cfg.bits
169
  group_size = cfg.group_size
170
+ is_hybrid = cfg.model_type in ("qwen3_5_moe",)
171
+
172
+ layers, embed_tokens, norm, lm_head = self._get_layers_and_head()
173
+ h = embed_tokens(input_ids)
174
+
175
+ # Create masks
176
+ if is_hybrid:
177
+ from mlx_lm.models.base import create_ssm_mask
178
+ # Find first full-attention and first linear-attention layer cache
179
+ fa_cache = None
180
+ ssm_cache = None
181
+ for li, layer in enumerate(layers):
182
+ if hasattr(layer, "is_linear"):
183
+ if layer.is_linear and ssm_cache is None:
184
+ ssm_cache = self.kv_cache[li]
185
+ elif not layer.is_linear and fa_cache is None:
186
+ fa_cache = self.kv_cache[li]
187
+ fa_mask = create_attention_mask(h, fa_cache) if fa_cache is not None else None
188
+ ssm_mask = create_ssm_mask(h, ssm_cache) if ssm_cache is not None else None
189
+ else:
190
+ mask = create_attention_mask(h, self.kv_cache[0])
191
 
192
+ for i, layer in enumerate(layers):
 
 
 
193
  # ── Attention (pinned weights, always in RAM) ──
194
  normed = layer.input_layernorm(h)
195
+
196
+ if is_hybrid and hasattr(layer, "is_linear") and layer.is_linear:
197
+ attn_out = layer.linear_attn(normed, mask=ssm_mask, cache=self.kv_cache[i])
198
+ elif is_hybrid and hasattr(layer, "self_attn"):
199
+ attn_out = layer.self_attn(normed, mask=fa_mask, cache=self.kv_cache[i])
200
+ else:
201
+ attn_out = layer.self_attn(normed, mask=mask, cache=self.kv_cache[i])
202
+
203
  h = h + attn_out
204
  mx.eval(h) # Must eval before router (data-dependent)
205
 
206
+ # ── Post-attention norm ──
207
  normed = layer.post_attention_layernorm(h)
208
+
209
+ # ── Check if this is an MoE layer ──
210
+ if not self._is_moe_layer(layer):
211
+ # Dense MLP — just run it (pinned weights)
212
+ h = h + layer.mlp(normed)
213
+ continue
214
+
215
+ # ── Router: compute expert scores ──
216
  gates = layer.mlp.gate(normed)
217
  gates = mx.softmax(gates, axis=-1, precise=True)
218
  k = layer.mlp.top_k
 
281
  do = do.squeeze(-2)
282
 
283
  # Weighted sum of expert outputs
284
+ moe_out = (do * scores[..., None]).sum(axis=-2)
285
  del gw, gs, gb, uw, us, ub, dw, ds, db
286
 
287
+ # ── Shared expert (pinned, always runs) ──
288
+ if self._has_shared_expert(layer):
289
+ shared_out = layer.mlp.shared_expert(normed)
290
+ shared_out = mx.sigmoid(layer.mlp.shared_expert_gate(normed)) * shared_out
291
+ h = h + moe_out + shared_out
292
+ else:
293
+ h = h + moe_out
294
+
295
+ h = norm(h)
296
+ return lm_head(h)
297
 
298
  def generate(self, prompt, max_tokens=None, temperature=None,
299
  chat_messages=None):
 
315
  temperature = temperature if temperature is not None else self.config.temperature
316
 
317
  # Tokenize
318
+ messages = chat_messages if chat_messages else [{"role": "user", "content": prompt}]
319
+ if self.tokenizer.chat_template:
320
+ try:
321
+ text = self.tokenizer.apply_chat_template(
322
+ messages, tokenize=False,
323
+ add_generation_prompt=True, enable_thinking=False)
324
+ except TypeError:
325
+ text = self.tokenizer.apply_chat_template(
326
+ messages, tokenize=False,
327
+ add_generation_prompt=True)
328
  else:
329
+ # Fallback: Qwen/ChatML format
330
+ parts = []
331
+ for m in messages:
332
+ parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
333
+ parts.append("<|im_start|>assistant\n")
334
+ text = "\n".join(parts)
335
+
336
+ # Strip thinking from output by tracking state
337
+ self._in_thinking = False
338
+ self._thinking_buffer = ""
339
 
340
  tokens = self.tokenizer.encode(text)
341
  input_ids = mx.array([tokens])
 
350
  # Sample first token
351
  next_token = self._sample(logits[:, -1, :], temperature)
352
 
353
+ # Build EOS set from tokenizer config
354
+ eos_ids = set()
355
+ if hasattr(self.tokenizer, "eos_token_id"):
356
+ eid = self.tokenizer.eos_token_id
357
+ if isinstance(eid, list):
358
+ eos_ids.update(eid)
359
+ elif eid is not None:
360
+ eos_ids.add(eid)
361
+ eos_ids.update({151643, 151645, 248044, 248046}) # Qwen3 + Qwen3.5 EOS
362
+
363
  # Autoregressive generation
364
  for _ in range(max_tokens):
365
+ if next_token in eos_ids:
366
  break
367
  word = self.tokenizer.decode([next_token])
368
  if "<|im_end|>" in word or "<|endoftext|>" in word:
369
  break
370
+
371
+ # Filter out <think>...</think> blocks
372
+ self._thinking_buffer += word
373
+ if "<think>" in self._thinking_buffer:
374
+ self._in_thinking = True
375
+ self._thinking_buffer = ""
376
+ elif "</think>" in self._thinking_buffer:
377
+ self._in_thinking = False
378
+ # Yield anything after </think>
379
+ after = self._thinking_buffer.split("</think>", 1)[-1].lstrip("\n")
380
+ self._thinking_buffer = ""
381
+ if after:
382
+ yield after
383
+ elif not self._in_thinking:
384
+ yield self._thinking_buffer
385
+ self._thinking_buffer = ""
386
 
387
  input_ids = mx.array([[next_token]])
388
  logits = self.forward_token(input_ids)
stream_preprocess_35b.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Stream-preprocess Qwen3.5-35B-A3B-4bit: download one shard, process, delete."""
3
+
4
+ import os, sys, json, time, gc, shutil, glob
5
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
6
+ import numpy as np
7
+ import mlx.core as mx
8
+
9
# Source repo on the Hugging Face Hub and the local output layout.
REPO = "mlx-community/Qwen3.5-35B-A3B-4bit"
OUTPUT_DIR = os.path.expanduser("~/models/qwen35-35b")
# Size of the JSON header page at the start of each .bin file; expert
# tensor data begins at exactly this offset (see convert_layer_to_bin).
PAGE_SIZE = 16384

# Per-expert tensor names written into each layer's .bin, in on-disk order.
# The .scales/.biases entries exist only for quantized tensors; names absent
# from a layer's tensor dict are simply skipped when serializing.
TENSOR_NAMES = [
    "gate_proj.weight", "gate_proj.scales", "gate_proj.biases",
    "up_proj.weight", "up_proj.scales", "up_proj.biases",
    "down_proj.weight", "down_proj.scales", "down_proj.biases",
]
18
+
19
+
20
def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Serialize one MoE layer's expert tensors into a page-aligned .bin file.

    File layout: a PAGE_SIZE JSON header (zero-padded) describing per-expert
    tensor shapes, dtypes, and inner offsets, followed by the raw tensor
    bytes in expert-major order (all of expert 0's tensors, then expert 1's,
    ...), each expert occupying exactly ``expert_block_size`` bytes.

    Args:
        layer_data: mapping from tensor short names (subset of TENSOR_NAMES)
            to mlx arrays whose leading axis is the expert index.
        layer_idx: layer number; determines the output filename.
        num_experts: number of experts to serialize from each tensor.
        output_dir: directory containing the ``bin`` subdirectory.

    Returns:
        Size in bytes of the written file.

    Raises:
        ValueError: if the JSON header does not fit within one page.
    """
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        if name not in layer_data:
            continue
        t = layer_data[name]
        per_expert_shape = list(t.shape[1:])
        # Element sizes: uint32 holds packed quantized weights (4 bytes);
        # bf16 is stored on disk as f16 (2 bytes); everything else is
        # assumed to be a 4-byte type.
        if t.dtype == mx.uint32:
            elem_size = 4
        elif t.dtype in (mx.bfloat16, mx.float16):
            elem_size = 2
        else:
            elem_size = 4
        nbytes = elem_size
        for s in per_expert_shape:
            nbytes *= s
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            "dtype": str(t.dtype).replace("mlx.core.", ""),
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes

    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        }
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # The header must fit in one page so the expert data starts page-aligned.
    # Validate explicitly rather than with `assert`, which is silently
    # stripped when Python runs with -O.
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"Layer {layer_idx} header is {len(header_bytes)} bytes; "
            f"must be smaller than PAGE_SIZE ({PAGE_SIZE})")
    header_bytes += b"\x00" * (PAGE_SIZE - len(header_bytes))

    out_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
    with open(out_path, "wb") as f:
        f.write(header_bytes)
        for expert_id in range(num_experts):
            for name in TENSOR_NAMES:
                if name not in layer_data:
                    continue
                t = layer_data[name][expert_id]
                if t.dtype == mx.bfloat16:
                    # numpy has no bfloat16: downcast to f16 for storage
                    # (matches the 2-byte elem_size recorded in the header).
                    raw = np.array(t.astype(mx.float16)).astype(np.float16).tobytes()
                elif t.dtype == mx.uint32:
                    raw = np.array(t).astype(np.uint32).tobytes()
                else:
                    raw = np.array(t).tobytes()
                f.write(raw)

    return os.path.getsize(out_path)
76
+
77
+
78
def main():
    """Stream-preprocess the model shard by shard to keep disk usage low.

    For each safetensors shard: download it, route expert tensors into
    per-layer .bin files (via convert_layer_to_bin), accumulate everything
    else as "pinned" weights, then delete the shard before fetching the
    next one. Layers whose experts span multiple shards are held in
    ``partial_layers`` until all three projections have arrived. The run
    is resumable: already-written ``moe_layer_*.bin`` files are skipped.
    """
    from huggingface_hub import hf_hub_download

    print("=" * 55)
    print(" Stream Preprocess Qwen3.5-35B-A3B-4bit")
    print(f" Output: {OUTPUT_DIR}")
    print("=" * 55)

    os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)

    # Download config + tokenizer (best-effort: some repos omit files)
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir="/tmp/sniper_dl_35b")
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f" Downloaded {fname}")
        except Exception as e:
            print(f" Skipped {fname}: {e}")

    with open(os.path.join(OUTPUT_DIR, "config.json")) as f:
        config = json.load(f)
    # qwen3_5_moe nests the LM settings under text_config; fall back to
    # the top-level dict for flat configs.
    text_cfg = config.get("text_config", config)
    num_layers = text_cfg.get("num_hidden_layers", 40)
    print(f" Layers: {num_layers}, Experts: {text_cfg.get('num_experts', 0)}")

    idx_path = hf_hub_download(REPO, "model.safetensors.index.json",
                               local_dir="/tmp/sniper_dl_35b")
    with open(idx_path) as f:
        idx = json.load(f)
    shards = sorted(set(idx["weight_map"].values()))
    print(f" {len(shards)} shards")

    # Resume support: collect layer indices already converted on disk.
    existing = set()
    for bin_name in os.listdir(os.path.join(OUTPUT_DIR, "bin")):
        if bin_name.startswith("moe_layer_") and bin_name.endswith(".bin"):
            existing.add(int(bin_name.split("_")[2].split(".")[0]))
    if existing:
        print(f" Already done: {sorted(existing)}")

    pinned = {}                 # non-expert tensors, saved at the end
    layers_done = set(existing)
    partial_layers = {}         # layer_idx -> tensors seen so far

    for si, shard_name in enumerate(shards):
        print(f"\n [{si+1}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()
        shard_path = hf_hub_download(REPO, shard_name, local_dir="/tmp/sniper_dl_35b")
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f" {shard_size:.1f} GB in {dl_time:.0f}s")

        data = mx.load(shard_path)
        print(f" {len(data)} tensors")

        layer_experts = {}
        for key, tensor in data.items():
            # Skip vision tower (text-only inference)
            if "vision_tower" in key or "model.visual" in key:
                continue

            if "switch_mlp" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                short = key.split(".switch_mlp.")[1]
                layer_experts.setdefault(layer, {})[short] = tensor
            elif "experts.gate_up_proj" in key and ".layers." in key:
                # Fused gate+up — split along the second-to-last axis
                layer = int(key.split(".layers.")[1].split(".")[0])
                gate_up = tensor
                mid = gate_up.shape[-2] // 2
                ld = layer_experts.setdefault(layer, {})
                ld["gate_proj.weight"] = gate_up[..., :mid, :]
                ld["up_proj.weight"] = gate_up[..., mid:, :]
            elif "experts.down_proj" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
            else:
                pinned[key] = tensor

        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue
            # Merge with tensors of this layer seen in earlier shards.
            if layer_idx in partial_layers:
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers[layer_idx]

            # A layer is complete once gate/up/down weights are all present
            # (quantized checkpoints additionally carry scales/biases).
            n_keys = len(tensors)
            has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])

            if not has_all:
                partial_layers[layer_idx] = tensors
                print(f" Layer {layer_idx}: partial ({n_keys} tensors)")
                continue

            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            if layer_idx in partial_layers:
                del partial_layers[layer_idx]
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB ({num_experts} experts)")

        # Free the shard's tensors before the next download to cap RAM.
        del data, layer_experts
        gc.collect()
        mx.clear_cache()

        try:
            os.remove(shard_path)
            print(f" Deleted shard ({shard_size:.1f} GB freed)")
        except OSError:
            # Best-effort cleanup only; was a bare `except:` which would
            # also have swallowed KeyboardInterrupt etc.
            pass

    # Handle layers that were still partial after the last shard.
    for layer_idx, tensors in partial_layers.items():
        if layer_idx in layers_done:
            continue
        has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
        if has_all:
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB (merged)")

    # Save pinned (attention, embeddings, norms, routers, ...)
    if pinned:
        print(f"\n Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(os.path.join(OUTPUT_DIR, "pinned.safetensors"), pinned)
        psz = os.path.getsize(os.path.join(OUTPUT_DIR, "pinned.safetensors")) / 1e9
        print(f" Pinned: {psz:.2f} GB")
    else:
        psz = 0

    shutil.rmtree("/tmp/sniper_dl_35b", ignore_errors=True)

    # Final summary and completeness check.
    bin_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "bin", "moe_layer_*.bin")))
    total = sum(os.path.getsize(f) for f in bin_files)
    missing = set(range(num_layers)) - layers_done
    print(f"\n Expert layers: {len(bin_files)}/{num_layers}")
    print(f" Expert total: {total/1e9:.2f} GB")
    print(f" Pinned: {psz:.2f} GB")
    if missing:
        print(f" WARNING: Missing layers: {sorted(missing)}")
    else:
        print(f"\n All {num_layers} layers converted!")
        print(f" Test: mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")
225
+
226
+
227
# Entry guard: run the conversion only when executed as a script, so the
# module can be imported without side effects.
if __name__ == "__main__":
    main()