v0.2.0: Add Qwen3.5-35B-A3B support (5.78 tok/s, 19.5 GB on 16 GB RAM)
Browse files- README.md +7 -8
- src/mlx_expert_sniper/config.py +10 -8
- src/mlx_expert_sniper/preprocess.py +10 -0
- src/mlx_expert_sniper/sniper.py +124 -24
- stream_preprocess_35b.py +228 -0
README.md
CHANGED
|
@@ -1,17 +1,16 @@
|
|
| 1 |
# CLI Agent — `mlx-expert-sniper`
|
| 2 |
|
| 3 |
Pip-installable CLI that wraps the Expert Sniper research into a production tool.
|
|
|
|
| 4 |
|
| 5 |
## Verified Results (M4 Mac Mini, 16 GB)
|
| 6 |
|
| 7 |
-
|
|
| 8 |
-
|--------|-------|
|
| 9 |
-
|
|
| 10 |
-
|
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
| RAM used | 0.87 GB pinned |
|
| 14 |
-
| Output | Coherent code, math, essays |
|
| 15 |
|
| 16 |
## Install
|
| 17 |
|
|
|
|
| 1 |
# CLI Agent — `mlx-expert-sniper`
|
| 2 |
|
| 3 |
Pip-installable CLI that wraps the Expert Sniper research into a production tool.
|
| 4 |
+
Run MoE models larger than your RAM on Apple Silicon.
|
| 5 |
|
| 6 |
## Verified Results (M4 Mac Mini, 16 GB)
|
| 7 |
|
| 8 |
+
| Model | Size | Standard mlx_lm | Sniper tok/s | RAM pinned | Cache hit |
|
| 9 |
+
|-------|------|-----------------|--------------|------------|-----------|
|
| 10 |
+
| Qwen3-30B-A3B | 17.2 GB | OOM | **4.33 tok/s** | 0.87 GB | 88.5% |
|
| 11 |
+
| Qwen3.5-35B-A3B | 19.5 GB | OOM | **5.78 tok/s** | 1.38 GB | 75% |
|
| 12 |
+
|
| 13 |
+
Both models exceed 16 GB RAM. Both produce coherent multi-paragraph output.
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## Install
|
| 16 |
|
src/mlx_expert_sniper/config.py
CHANGED
|
@@ -52,19 +52,21 @@ class SniperConfig:
|
|
| 52 |
with open(config_path) as f:
|
| 53 |
model_config = json.load(f)
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
cfg = cls(
|
| 58 |
sniper_dir=sniper_dir,
|
| 59 |
bits=quant.get("bits", 4),
|
| 60 |
group_size=quant.get("group_size", 64),
|
| 61 |
-
num_hidden_layers=
|
| 62 |
-
num_experts=
|
| 63 |
-
num_experts_per_tok=
|
| 64 |
-
hidden_size=
|
| 65 |
-
moe_intermediate_size=
|
| 66 |
-
vocab_size=
|
| 67 |
-
norm_topk_prob=
|
| 68 |
model_type=model_config.get("model_type", ""),
|
| 69 |
tokenizer_name=model_config.get("_name_or_path", ""),
|
| 70 |
)
|
|
|
|
| 52 |
with open(config_path) as f:
|
| 53 |
model_config = json.load(f)
|
| 54 |
|
| 55 |
+
# Handle nested configs (qwen3_5_moe has text_config)
|
| 56 |
+
quant = model_config.get("quantization", model_config.get("quantization_config", {}))
|
| 57 |
+
text_cfg = model_config.get("text_config", model_config)
|
| 58 |
|
| 59 |
cfg = cls(
|
| 60 |
sniper_dir=sniper_dir,
|
| 61 |
bits=quant.get("bits", 4),
|
| 62 |
group_size=quant.get("group_size", 64),
|
| 63 |
+
num_hidden_layers=text_cfg.get("num_hidden_layers", 0),
|
| 64 |
+
num_experts=text_cfg.get("num_experts", 0),
|
| 65 |
+
num_experts_per_tok=text_cfg.get("num_experts_per_tok", 0),
|
| 66 |
+
hidden_size=text_cfg.get("hidden_size", 0),
|
| 67 |
+
moe_intermediate_size=text_cfg.get("moe_intermediate_size", 0),
|
| 68 |
+
vocab_size=text_cfg.get("vocab_size", 0),
|
| 69 |
+
norm_topk_prob=text_cfg.get("norm_topk_prob", True),
|
| 70 |
model_type=model_config.get("model_type", ""),
|
| 71 |
tokenizer_name=model_config.get("_name_or_path", ""),
|
| 72 |
)
|
src/mlx_expert_sniper/preprocess.py
CHANGED
|
@@ -163,6 +163,16 @@ def preprocess_model(model_dir: str, output_dir: str, verbose: bool = True):
|
|
| 163 |
layer = int(key.split(".layers.")[1].split(".")[0])
|
| 164 |
short = key.split(".switch_mlp.")[1]
|
| 165 |
layer_experts.setdefault(layer, {})[short] = tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
else:
|
| 167 |
pinned[key] = tensor
|
| 168 |
|
|
|
|
| 163 |
layer = int(key.split(".layers.")[1].split(".")[0])
|
| 164 |
short = key.split(".switch_mlp.")[1]
|
| 165 |
layer_experts.setdefault(layer, {})[short] = tensor
|
| 166 |
+
elif "experts.gate_up_proj" in key:
|
| 167 |
+
# qwen3_5_moe: fused gate+up proj — split into separate tensors
|
| 168 |
+
layer = int(key.split(".layers.")[1].split(".")[0])
|
| 169 |
+
gate_up = tensor
|
| 170 |
+
mid = gate_up.shape[-2] // 2
|
| 171 |
+
layer_experts.setdefault(layer, {})["gate_proj.weight"] = gate_up[..., :mid, :]
|
| 172 |
+
layer_experts.setdefault(layer, {})["up_proj.weight"] = gate_up[..., mid:, :]
|
| 173 |
+
elif "experts.down_proj" in key:
|
| 174 |
+
layer = int(key.split(".layers.")[1].split(".")[0])
|
| 175 |
+
layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
|
| 176 |
else:
|
| 177 |
pinned[key] = tensor
|
| 178 |
|
src/mlx_expert_sniper/sniper.py
CHANGED
|
@@ -108,8 +108,11 @@ class SniperEngine:
|
|
| 108 |
if mt in ("qwen3_moe", "qwen2_moe"):
|
| 109 |
from mlx_lm.models.qwen3_moe import Model, ModelArgs
|
| 110 |
return Model, ModelArgs
|
|
|
|
|
|
|
|
|
|
| 111 |
raise ValueError(f"Unsupported model_type: {mt}. "
|
| 112 |
-
f"Currently supported: qwen3_moe,
|
| 113 |
|
| 114 |
def _quantize_model(self, model_config: dict, quant: dict):
|
| 115 |
"""Apply quantization matching the stored format."""
|
|
@@ -129,35 +132,87 @@ class SniperEngine:
|
|
| 129 |
class_predicate=class_predicate,
|
| 130 |
)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
def reset_kv_cache(self):
|
| 133 |
"""Reset the KV cache for a new conversation."""
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def forward_token(self, input_ids: mx.array) -> mx.array:
|
| 138 |
"""Run one forward pass with expert sniping.
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
"""
|
| 143 |
from mlx_lm.models.base import create_attention_mask
|
| 144 |
|
| 145 |
cfg = self.config
|
| 146 |
bits = cfg.bits
|
| 147 |
group_size = cfg.group_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
mask = create_attention_mask(h, self.kv_cache[0])
|
| 151 |
-
|
| 152 |
-
for i, layer in enumerate(self.model.model.layers):
|
| 153 |
# ── Attention (pinned weights, always in RAM) ──
|
| 154 |
normed = layer.input_layernorm(h)
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
h = h + attn_out
|
| 157 |
mx.eval(h) # Must eval before router (data-dependent)
|
| 158 |
|
| 159 |
-
# ──
|
| 160 |
normed = layer.post_attention_layernorm(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
gates = layer.mlp.gate(normed)
|
| 162 |
gates = mx.softmax(gates, axis=-1, precise=True)
|
| 163 |
k = layer.mlp.top_k
|
|
@@ -226,11 +281,19 @@ class SniperEngine:
|
|
| 226 |
do = do.squeeze(-2)
|
| 227 |
|
| 228 |
# Weighted sum of expert outputs
|
| 229 |
-
|
| 230 |
del gw, gs, gb, uw, us, ub, dw, ds, db
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
def generate(self, prompt, max_tokens=None, temperature=None,
|
| 236 |
chat_messages=None):
|
|
@@ -252,15 +315,27 @@ class SniperEngine:
|
|
| 252 |
temperature = temperature if temperature is not None else self.config.temperature
|
| 253 |
|
| 254 |
# Tokenize
|
| 255 |
-
if chat_messages:
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
else:
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
tokens = self.tokenizer.encode(text)
|
| 266 |
input_ids = mx.array([tokens])
|
|
@@ -275,14 +350,39 @@ class SniperEngine:
|
|
| 275 |
# Sample first token
|
| 276 |
next_token = self._sample(logits[:, -1, :], temperature)
|
| 277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
# Autoregressive generation
|
| 279 |
for _ in range(max_tokens):
|
| 280 |
-
if next_token in
|
| 281 |
break
|
| 282 |
word = self.tokenizer.decode([next_token])
|
| 283 |
if "<|im_end|>" in word or "<|endoftext|>" in word:
|
| 284 |
break
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
input_ids = mx.array([[next_token]])
|
| 288 |
logits = self.forward_token(input_ids)
|
|
|
|
| 108 |
if mt in ("qwen3_moe", "qwen2_moe"):
|
| 109 |
from mlx_lm.models.qwen3_moe import Model, ModelArgs
|
| 110 |
return Model, ModelArgs
|
| 111 |
+
if mt in ("qwen3_5_moe",):
|
| 112 |
+
from mlx_lm.models.qwen3_5_moe import Model, ModelArgs
|
| 113 |
+
return Model, ModelArgs
|
| 114 |
raise ValueError(f"Unsupported model_type: {mt}. "
|
| 115 |
+
f"Currently supported: qwen3_moe, qwen3_5_moe")
|
| 116 |
|
| 117 |
def _quantize_model(self, model_config: dict, quant: dict):
|
| 118 |
"""Apply quantization matching the stored format."""
|
|
|
|
| 132 |
class_predicate=class_predicate,
|
| 133 |
)
|
| 134 |
|
| 135 |
+
def _get_layers_and_head(self):
|
| 136 |
+
"""Get (layers, embed_tokens, norm, lm_head) for the model architecture."""
|
| 137 |
+
mt = self.config.model_type
|
| 138 |
+
if mt in ("qwen3_5_moe",):
|
| 139 |
+
lm = self.model.language_model
|
| 140 |
+
return lm.model.layers, lm.model.embed_tokens, lm.model.norm, lm.lm_head
|
| 141 |
+
return self.model.model.layers, self.model.model.embed_tokens, self.model.model.norm, self.model.lm_head
|
| 142 |
+
|
| 143 |
+
def _is_moe_layer(self, layer) -> bool:
|
| 144 |
+
"""Check if a layer has a MoE MLP (vs dense MLP)."""
|
| 145 |
+
return hasattr(layer.mlp, "gate") and hasattr(layer.mlp, "switch_mlp")
|
| 146 |
+
|
| 147 |
+
def _has_shared_expert(self, layer) -> bool:
|
| 148 |
+
"""Check if the MoE block has a shared expert."""
|
| 149 |
+
return hasattr(layer.mlp, "shared_expert")
|
| 150 |
+
|
| 151 |
def reset_kv_cache(self):
    """Reset the KV cache for a new conversation.

    Prefer the model's own `make_cache()` factory when it exists (hybrid
    architectures need per-layer cache types); otherwise fall back to
    mlx_lm's generic prompt cache.
    """
    model = self.model
    if hasattr(model, "make_cache"):
        self.kv_cache = model.make_cache()
        return
    from mlx_lm.models.cache import make_prompt_cache
    self.kv_cache = make_prompt_cache(model)
|
| 158 |
|
| 159 |
def forward_token(self, input_ids: mx.array) -> mx.array:
|
| 160 |
"""Run one forward pass with expert sniping.
|
| 161 |
|
| 162 |
+
Supports both qwen3_moe (standard attention) and qwen3_5_moe
|
| 163 |
+
(hybrid linear/full attention with shared experts).
|
| 164 |
"""
|
| 165 |
from mlx_lm.models.base import create_attention_mask
|
| 166 |
|
| 167 |
cfg = self.config
|
| 168 |
bits = cfg.bits
|
| 169 |
group_size = cfg.group_size
|
| 170 |
+
is_hybrid = cfg.model_type in ("qwen3_5_moe",)
|
| 171 |
+
|
| 172 |
+
layers, embed_tokens, norm, lm_head = self._get_layers_and_head()
|
| 173 |
+
h = embed_tokens(input_ids)
|
| 174 |
+
|
| 175 |
+
# Create masks
|
| 176 |
+
if is_hybrid:
|
| 177 |
+
from mlx_lm.models.base import create_ssm_mask
|
| 178 |
+
# Find first full-attention and first linear-attention layer cache
|
| 179 |
+
fa_cache = None
|
| 180 |
+
ssm_cache = None
|
| 181 |
+
for li, layer in enumerate(layers):
|
| 182 |
+
if hasattr(layer, "is_linear"):
|
| 183 |
+
if layer.is_linear and ssm_cache is None:
|
| 184 |
+
ssm_cache = self.kv_cache[li]
|
| 185 |
+
elif not layer.is_linear and fa_cache is None:
|
| 186 |
+
fa_cache = self.kv_cache[li]
|
| 187 |
+
fa_mask = create_attention_mask(h, fa_cache) if fa_cache is not None else None
|
| 188 |
+
ssm_mask = create_ssm_mask(h, ssm_cache) if ssm_cache is not None else None
|
| 189 |
+
else:
|
| 190 |
+
mask = create_attention_mask(h, self.kv_cache[0])
|
| 191 |
|
| 192 |
+
for i, layer in enumerate(layers):
|
|
|
|
|
|
|
|
|
|
| 193 |
# ── Attention (pinned weights, always in RAM) ──
|
| 194 |
normed = layer.input_layernorm(h)
|
| 195 |
+
|
| 196 |
+
if is_hybrid and hasattr(layer, "is_linear") and layer.is_linear:
|
| 197 |
+
attn_out = layer.linear_attn(normed, mask=ssm_mask, cache=self.kv_cache[i])
|
| 198 |
+
elif is_hybrid and hasattr(layer, "self_attn"):
|
| 199 |
+
attn_out = layer.self_attn(normed, mask=fa_mask, cache=self.kv_cache[i])
|
| 200 |
+
else:
|
| 201 |
+
attn_out = layer.self_attn(normed, mask=mask, cache=self.kv_cache[i])
|
| 202 |
+
|
| 203 |
h = h + attn_out
|
| 204 |
mx.eval(h) # Must eval before router (data-dependent)
|
| 205 |
|
| 206 |
+
# ── Post-attention norm ──
|
| 207 |
normed = layer.post_attention_layernorm(h)
|
| 208 |
+
|
| 209 |
+
# ── Check if this is an MoE layer ──
|
| 210 |
+
if not self._is_moe_layer(layer):
|
| 211 |
+
# Dense MLP — just run it (pinned weights)
|
| 212 |
+
h = h + layer.mlp(normed)
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
# ── Router: compute expert scores ──
|
| 216 |
gates = layer.mlp.gate(normed)
|
| 217 |
gates = mx.softmax(gates, axis=-1, precise=True)
|
| 218 |
k = layer.mlp.top_k
|
|
|
|
| 281 |
do = do.squeeze(-2)
|
| 282 |
|
| 283 |
# Weighted sum of expert outputs
|
| 284 |
+
moe_out = (do * scores[..., None]).sum(axis=-2)
|
| 285 |
del gw, gs, gb, uw, us, ub, dw, ds, db
|
| 286 |
|
| 287 |
+
# ── Shared expert (pinned, always runs) ──
|
| 288 |
+
if self._has_shared_expert(layer):
|
| 289 |
+
shared_out = layer.mlp.shared_expert(normed)
|
| 290 |
+
shared_out = mx.sigmoid(layer.mlp.shared_expert_gate(normed)) * shared_out
|
| 291 |
+
h = h + moe_out + shared_out
|
| 292 |
+
else:
|
| 293 |
+
h = h + moe_out
|
| 294 |
+
|
| 295 |
+
h = norm(h)
|
| 296 |
+
return lm_head(h)
|
| 297 |
|
| 298 |
def generate(self, prompt, max_tokens=None, temperature=None,
|
| 299 |
chat_messages=None):
|
|
|
|
| 315 |
temperature = temperature if temperature is not None else self.config.temperature
|
| 316 |
|
| 317 |
# Tokenize
|
| 318 |
+
messages = chat_messages if chat_messages else [{"role": "user", "content": prompt}]
|
| 319 |
+
if self.tokenizer.chat_template:
|
| 320 |
+
try:
|
| 321 |
+
text = self.tokenizer.apply_chat_template(
|
| 322 |
+
messages, tokenize=False,
|
| 323 |
+
add_generation_prompt=True, enable_thinking=False)
|
| 324 |
+
except TypeError:
|
| 325 |
+
text = self.tokenizer.apply_chat_template(
|
| 326 |
+
messages, tokenize=False,
|
| 327 |
+
add_generation_prompt=True)
|
| 328 |
else:
|
| 329 |
+
# Fallback: Qwen/ChatML format
|
| 330 |
+
parts = []
|
| 331 |
+
for m in messages:
|
| 332 |
+
parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
|
| 333 |
+
parts.append("<|im_start|>assistant\n")
|
| 334 |
+
text = "\n".join(parts)
|
| 335 |
+
|
| 336 |
+
# Strip thinking from output by tracking state
|
| 337 |
+
self._in_thinking = False
|
| 338 |
+
self._thinking_buffer = ""
|
| 339 |
|
| 340 |
tokens = self.tokenizer.encode(text)
|
| 341 |
input_ids = mx.array([tokens])
|
|
|
|
| 350 |
# Sample first token
|
| 351 |
next_token = self._sample(logits[:, -1, :], temperature)
|
| 352 |
|
| 353 |
+
# Build EOS set from tokenizer config
|
| 354 |
+
eos_ids = set()
|
| 355 |
+
if hasattr(self.tokenizer, "eos_token_id"):
|
| 356 |
+
eid = self.tokenizer.eos_token_id
|
| 357 |
+
if isinstance(eid, list):
|
| 358 |
+
eos_ids.update(eid)
|
| 359 |
+
elif eid is not None:
|
| 360 |
+
eos_ids.add(eid)
|
| 361 |
+
eos_ids.update({151643, 151645, 248044, 248046}) # Qwen3 + Qwen3.5 EOS
|
| 362 |
+
|
| 363 |
# Autoregressive generation
|
| 364 |
for _ in range(max_tokens):
|
| 365 |
+
if next_token in eos_ids:
|
| 366 |
break
|
| 367 |
word = self.tokenizer.decode([next_token])
|
| 368 |
if "<|im_end|>" in word or "<|endoftext|>" in word:
|
| 369 |
break
|
| 370 |
+
|
| 371 |
+
# Filter out <think>...</think> blocks
|
| 372 |
+
self._thinking_buffer += word
|
| 373 |
+
if "<think>" in self._thinking_buffer:
|
| 374 |
+
self._in_thinking = True
|
| 375 |
+
self._thinking_buffer = ""
|
| 376 |
+
elif "</think>" in self._thinking_buffer:
|
| 377 |
+
self._in_thinking = False
|
| 378 |
+
# Yield anything after </think>
|
| 379 |
+
after = self._thinking_buffer.split("</think>", 1)[-1].lstrip("\n")
|
| 380 |
+
self._thinking_buffer = ""
|
| 381 |
+
if after:
|
| 382 |
+
yield after
|
| 383 |
+
elif not self._in_thinking:
|
| 384 |
+
yield self._thinking_buffer
|
| 385 |
+
self._thinking_buffer = ""
|
| 386 |
|
| 387 |
input_ids = mx.array([[next_token]])
|
| 388 |
logits = self.forward_token(input_ids)
|
stream_preprocess_35b.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Stream-preprocess Qwen3.5-35B-A3B-4bit: download one shard, process, delete."""
|
| 3 |
+
|
| 4 |
+
import os, sys, json, time, gc, shutil, glob
|
| 5 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
|
| 6 |
+
import numpy as np
|
| 7 |
+
import mlx.core as mx
|
| 8 |
+
|
| 9 |
+
REPO = "mlx-community/Qwen3.5-35B-A3B-4bit"
|
| 10 |
+
OUTPUT_DIR = os.path.expanduser("~/models/qwen35-35b")
|
| 11 |
+
PAGE_SIZE = 16384
|
| 12 |
+
|
| 13 |
+
TENSOR_NAMES = [
|
| 14 |
+
"gate_proj.weight", "gate_proj.scales", "gate_proj.biases",
|
| 15 |
+
"up_proj.weight", "up_proj.scales", "up_proj.biases",
|
| 16 |
+
"down_proj.weight", "down_proj.scales", "down_proj.biases",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Pack one MoE layer's expert tensors into a single page-aligned .bin file.

    Layout: a JSON header zero-padded to PAGE_SIZE bytes, followed by
    `num_experts` contiguous expert blocks. Each block holds the tensors
    listed in TENSOR_NAMES (those present in `layer_data`) in a fixed
    order, so one expert can later be read with a single offset computation.

    Args:
        layer_data: dict mapping tensor short-names (e.g. "gate_proj.weight")
            to stacked mx arrays with a leading expert axis.
        layer_idx: decoder layer index; used only in the output filename.
        num_experts: number of experts stacked along axis 0.
        output_dir: model output root; the file goes to <output_dir>/bin/.

    Returns:
        Size in bytes of the written file.

    Raises:
        ValueError: if the JSON header does not fit inside one page.
    """
    import math

    # Describe the per-expert layout: shape, dtype, byte size and offset
    # of each tensor inside one expert block.
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        if name not in layer_data:
            continue
        t = layer_data[name]
        per_expert_shape = list(t.shape[1:])  # drop the leading expert axis
        # On-disk element sizes: 16-bit floats are written as 2 bytes
        # (bf16 is re-encoded to f16 below); uint32 quant words and any
        # other dtype are assumed to occupy 4 bytes.
        if t.dtype in (mx.bfloat16, mx.float16):
            elem_size = 2
        else:
            elem_size = 4
        nbytes = elem_size * math.prod(per_expert_shape)
        # NOTE(review): for bfloat16 tensors the payload below is re-encoded
        # as float16, yet "dtype" still records "bfloat16" — the reader must
        # account for this mismatch. Confirm against the loader before
        # changing either side.
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            "dtype": str(t.dtype).replace("mlx.core.", ""),
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes

    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        },
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # Explicit check rather than `assert`: asserts vanish under `python -O`,
    # and an oversized header would silently corrupt the file layout.
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"layer {layer_idx}: header is {len(header_bytes)} bytes, "
            f"must fit in one {PAGE_SIZE}-byte page")
    header_bytes += b"\x00" * (PAGE_SIZE - len(header_bytes))

    out_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
    with open(out_path, "wb") as f:
        f.write(header_bytes)
        # Expert-major order: all tensors of expert 0, then expert 1, ...
        for expert_id in range(num_experts):
            for name in TENSOR_NAMES:
                if name not in layer_data:
                    continue
                t = layer_data[name][expert_id]
                if t.dtype == mx.bfloat16:
                    # NumPy has no bfloat16 — round-trip via float16.
                    raw = np.array(t.astype(mx.float16)).astype(np.float16).tobytes()
                elif t.dtype == mx.uint32:
                    raw = np.array(t).astype(np.uint32).tobytes()
                else:
                    raw = np.array(t).tobytes()
                f.write(raw)

    return os.path.getsize(out_path)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main():
    """Stream-convert the 35B checkpoint within limited disk space.

    Downloads one safetensors shard at a time, splits its tensors into
    per-layer expert .bin files (via convert_layer_to_bin) and a "pinned"
    set kept in RAM at inference time, then deletes the shard before
    fetching the next one. Layers whose tensors span shard boundaries are
    buffered in `partial_layers` until complete. Re-runs resume by skipping
    layers whose .bin file already exists.
    """
    from huggingface_hub import hf_hub_download

    print("=" * 55)
    print(" Stream Preprocess Qwen3.5-35B-A3B-4bit")
    print(f" Output: {OUTPUT_DIR}")
    print("=" * 55)

    os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)

    # Download config + tokenizer (best-effort: some repos omit a file)
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir="/tmp/sniper_dl_35b")
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f" Downloaded {fname}")
        except Exception as e:
            print(f" Skipped {fname}: {e}")

    with open(os.path.join(OUTPUT_DIR, "config.json")) as f:
        config = json.load(f)
    # qwen3_5_moe nests the decoder settings under "text_config"
    text_cfg = config.get("text_config", config)
    num_layers = text_cfg.get("num_hidden_layers", 40)
    print(f" Layers: {num_layers}, Experts: {text_cfg.get('num_experts', 0)}")

    idx_path = hf_hub_download(REPO, "model.safetensors.index.json",
                               local_dir="/tmp/sniper_dl_35b")
    with open(idx_path) as f:
        idx = json.load(f)
    shards = sorted(set(idx["weight_map"].values()))
    print(f" {len(shards)} shards")

    # Resume support: note layers whose .bin file already exists
    existing = set()
    for entry in os.listdir(os.path.join(OUTPUT_DIR, "bin")):
        if entry.startswith("moe_layer_") and entry.endswith(".bin"):
            existing.add(int(entry.split("_")[2].split(".")[0]))
    if existing:
        print(f" Already done: {sorted(existing)}")

    pinned = {}                 # non-expert tensors, saved at the end
    layers_done = set(existing)
    partial_layers = {}         # layers whose tensors span shard boundaries

    for si, shard_name in enumerate(shards):
        print(f"\n [{si+1}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()
        shard_path = hf_hub_download(REPO, shard_name, local_dir="/tmp/sniper_dl_35b")
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f" {shard_size:.1f} GB in {dl_time:.0f}s")

        data = mx.load(shard_path)
        print(f" {len(data)} tensors")

        layer_experts = {}
        for key, tensor in data.items():
            # Skip vision tower (text-only inference)
            if "vision_tower" in key or "model.visual" in key:
                continue

            if "switch_mlp" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                short = key.split(".switch_mlp.")[1]
                layer_experts.setdefault(layer, {})[short] = tensor
            elif "experts.gate_up_proj" in key and ".layers." in key:
                # Fused gate+up projection — split into separate tensors
                layer = int(key.split(".layers.")[1].split(".")[0])
                gate_up = tensor
                mid = gate_up.shape[-2] // 2
                ld = layer_experts.setdefault(layer, {})
                ld["gate_proj.weight"] = gate_up[..., :mid, :]
                ld["up_proj.weight"] = gate_up[..., mid:, :]
            elif "experts.down_proj" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
            else:
                pinned[key] = tensor

        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue
            if layer_idx in partial_layers:
                # Merge with tensors collected from earlier shards
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers[layer_idx]

            # A layer is complete once gate/up/down weights are all present
            # (scales/biases accompany weights for quantized checkpoints)
            n_keys = len(tensors)
            has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])

            if not has_all:
                partial_layers[layer_idx] = tensors
                print(f" Layer {layer_idx}: partial ({n_keys} tensors)")
                continue

            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            if layer_idx in partial_layers:
                del partial_layers[layer_idx]
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB ({num_experts} experts)")

        # Free this shard's tensors before the next multi-GB download
        del data, layer_experts
        gc.collect()
        mx.clear_cache()

        try:
            os.remove(shard_path)
            print(f" Deleted shard ({shard_size:.1f} GB freed)")
        except OSError:
            # Best-effort cleanup only. Narrowed from a bare `except:`,
            # which would also swallow KeyboardInterrupt/SystemExit.
            pass

    # Flush any layers that were still partial after the last shard
    for layer_idx, tensors in partial_layers.items():
        if layer_idx in layers_done:
            continue
        has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
        if has_all:
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB (merged)")

    # Save pinned tensors (attention, norms, embeddings, router gates)
    if pinned:
        print(f"\n Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(os.path.join(OUTPUT_DIR, "pinned.safetensors"), pinned)
        psz = os.path.getsize(os.path.join(OUTPUT_DIR, "pinned.safetensors")) / 1e9
        print(f" Pinned: {psz:.2f} GB")
    else:
        psz = 0

    shutil.rmtree("/tmp/sniper_dl_35b", ignore_errors=True)

    # Final report: verify every decoder layer produced a .bin file
    bin_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "bin", "moe_layer_*.bin")))
    total = sum(os.path.getsize(f) for f in bin_files)
    missing = set(range(num_layers)) - layers_done
    print(f"\n Expert layers: {len(bin_files)}/{num_layers}")
    print(f" Expert total: {total/1e9:.2f} GB")
    print(f" Pinned: {psz:.2f} GB")
    if missing:
        print(f" WARNING: Missing layers: {sorted(missing)}")
    else:
        print(f"\n All {num_layers} layers converted!")
        print(f" Test: mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Script entry point — run the streaming preprocessor.
if __name__ == "__main__":
    main()
|