# mlx-expert-sniper / stream_preprocess_35b.py  (uploaded by waltgrace)
# v0.2.0: Add Qwen3.5-35B-A3B support (5.78 tok/s, 19.5 GB on 16 GB RAM)
# HF commit d14a3c2 (verified)
#!/usr/bin/env python3
"""Stream-preprocess Qwen3.5-35B-A3B-4bit: download one shard, process, delete."""
import os, sys, json, time, gc, shutil, glob
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
import numpy as np
import mlx.core as mx
# Source Hugging Face repo to stream-convert.
REPO = "mlx-community/Qwen3.5-35B-A3B-4bit"
# Destination for the converted model (bin/ layer files + pinned.safetensors).
OUTPUT_DIR = os.path.expanduser("~/models/qwen35-35b")
# Size of the JSON header page at the start of each .bin file; the header is
# zero-padded to exactly this many bytes so tensor data starts page-aligned.
PAGE_SIZE = 16384
# Per-expert MoE tensors packed into each layer's .bin, in this fixed order.
# Quantized layers have all nine; non-quantized layers only the *.weight ones.
TENSOR_NAMES = [
    "gate_proj.weight", "gate_proj.scales", "gate_proj.biases",
    "up_proj.weight", "up_proj.scales", "up_proj.biases",
    "down_proj.weight", "down_proj.scales", "down_proj.biases",
]
def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Pack one MoE layer's expert tensors into bin/moe_layer_NN.bin.

    File layout: a PAGE_SIZE JSON header page (zero-padded), followed by
    raw tensor bytes grouped expert-major — for each expert, every tensor
    in TENSOR_NAMES order — so one expert's data is a single contiguous
    run of `expert_block_size` bytes starting at
    PAGE_SIZE + expert_id * expert_block_size.

    Args:
        layer_data: mapping of TENSOR_NAMES entries to mlx arrays whose
            leading axis is the expert axis (assumed — confirm in caller).
        layer_idx: layer number, used for the output filename.
        num_experts: number of experts to serialize from axis 0.
        output_dir: root directory containing the "bin/" subdirectory.

    Returns:
        Size in bytes of the written file.

    Raises:
        ValueError: if the JSON header does not fit in one header page.
    """
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        if name not in layer_data:
            continue
        t = layer_data[name]
        # Axis 0 is the expert axis; the header records per-expert shape.
        per_expert_shape = list(t.shape[1:])
        # Element sizes: packed quant words (uint32) -> 4 bytes,
        # half precision -> 2 bytes, anything else assumed 4 (e.g. float32).
        if t.dtype == mx.uint32:
            elem_size = 4
        elif t.dtype in (mx.bfloat16, mx.float16):
            elem_size = 2
        else:
            elem_size = 4
        nbytes = elem_size
        for s in per_expert_shape:
            nbytes *= s
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            "dtype": str(t.dtype).replace("mlx.core.", ""),
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes
    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        }
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # Explicit check (not assert: asserts vanish under `python -O`, which
    # would silently corrupt the file by overlapping header and data).
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"layer {layer_idx} header is {len(header_bytes)} bytes, "
            f"exceeds the {PAGE_SIZE}-byte header page")
    header_bytes += b"\x00" * (PAGE_SIZE - len(header_bytes))
    out_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
    # Be robust when called outside main(), which normally creates bin/.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "wb") as f:
        f.write(header_bytes)
        for expert_id in range(num_experts):
            for name in TENSOR_NAMES:
                if name not in layer_data:
                    continue
                t = layer_data[name][expert_id]
                if t.dtype == mx.bfloat16:
                    # NumPy has no bfloat16: downcast via mlx to float16
                    # before materializing bytes.
                    raw = np.array(t.astype(mx.float16)).astype(np.float16).tobytes()
                elif t.dtype == mx.uint32:
                    raw = np.array(t).astype(np.uint32).tobytes()
                else:
                    raw = np.array(t).tobytes()
                f.write(raw)
    return os.path.getsize(out_path)
def main():
    """Stream-convert the HF repo one safetensors shard at a time.

    For each shard: download it, extract per-layer MoE expert tensors,
    write every layer that is complete to bin/moe_layer_NN.bin via
    convert_layer_to_bin, then delete the shard — so at most one shard
    occupies disk at a time. Layers whose experts span multiple shards
    are accumulated in `partial_layers` until complete. All non-expert
    tensors are collected and saved as pinned.safetensors at the end.
    Supports resuming: layers with an existing .bin are skipped.
    """
    from huggingface_hub import hf_hub_download
    print("=" * 55)
    print(" Stream Preprocess Qwen3.5-35B-A3B-4bit")
    print(f" Output: {OUTPUT_DIR}")
    print("=" * 55)
    os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)
    # Download config + tokenizer (best-effort: not all repos have every file)
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir="/tmp/sniper_dl_35b")
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f" Downloaded {fname}")
        except Exception as e:
            print(f" Skipped {fname}: {e}")
    with open(os.path.join(OUTPUT_DIR, "config.json")) as f:
        config = json.load(f)
    # Multimodal configs nest LM settings under "text_config"; fall back to
    # the top level for text-only configs.
    text_cfg = config.get("text_config", config)
    num_layers = text_cfg.get("num_hidden_layers", 40)
    print(f" Layers: {num_layers}, Experts: {text_cfg.get('num_experts', 0)}")
    idx_path = hf_hub_download(REPO, "model.safetensors.index.json",
                               local_dir="/tmp/sniper_dl_35b")
    with open(idx_path) as f:
        idx = json.load(f)
    shards = sorted(set(idx["weight_map"].values()))
    print(f" {len(shards)} shards")
    # Resume support: collect layer indices that already have a .bin file.
    existing = set()
    for entry in os.listdir(os.path.join(OUTPUT_DIR, "bin")):
        if entry.startswith("moe_layer_") and entry.endswith(".bin"):
            existing.add(int(entry.split("_")[2].split(".")[0]))
    if existing:
        print(f" Already done: {sorted(existing)}")
    pinned = {}          # non-expert tensors, saved to pinned.safetensors
    layers_done = set(existing)
    partial_layers = {}  # layer_idx -> tensors seen so far (spans shards)
    for si, shard_name in enumerate(shards):
        print(f"\n [{si+1}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()
        shard_path = hf_hub_download(REPO, shard_name, local_dir="/tmp/sniper_dl_35b")
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f" {shard_size:.1f} GB in {dl_time:.0f}s")
        data = mx.load(shard_path)
        print(f" {len(data)} tensors")
        # Group expert tensors by layer; everything else goes to `pinned`.
        layer_experts = {}
        for key, tensor in data.items():
            # Skip vision tower
            if "vision_tower" in key or "model.visual" in key:
                continue
            if "switch_mlp" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                short = key.split(".switch_mlp.")[1]
                layer_experts.setdefault(layer, {})[short] = tensor
            elif "experts.gate_up_proj" in key and ".layers." in key:
                # Fused gate+up — split along the second-to-last axis.
                layer = int(key.split(".layers.")[1].split(".")[0])
                gate_up = tensor
                mid = gate_up.shape[-2] // 2
                ld = layer_experts.setdefault(layer, {})
                ld["gate_proj.weight"] = gate_up[..., :mid, :]
                ld["up_proj.weight"] = gate_up[..., mid:, :]
            elif "experts.down_proj" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
            else:
                pinned[key] = tensor
        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue
            # Merge with tensors carried over from earlier shards.
            if layer_idx in partial_layers:
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers[layer_idx]
            # Check how many tensor groups we have
            # For quantized: need weight + scales + biases for each of gate/up/down = 9
            # For non-quantized: just weight for gate/up/down = 3
            n_keys = len(tensors)
            has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
            if not has_all:
                # Not complete yet — stash and wait for the next shard.
                partial_layers[layer_idx] = tensors
                print(f" Layer {layer_idx}: partial ({n_keys} tensors)")
                continue
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            if layer_idx in partial_layers:
                del partial_layers[layer_idx]
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB ({num_experts} experts)")
        # Drop shard references and reclaim memory before the next download.
        del data, layer_experts
        gc.collect()
        mx.clear_cache()
        try:
            os.remove(shard_path)
            print(f" Deleted shard ({shard_size:.1f} GB freed)")
        except OSError:
            # Best-effort cleanup; a leftover shard only wastes disk space.
            pass
    # Handle remaining partials (layers completed only by the final shard).
    for layer_idx, tensors in partial_layers.items():
        if layer_idx in layers_done:
            continue
        has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
        if has_all:
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB (merged)")
    # Save pinned (non-expert) weights as a single safetensors file.
    if pinned:
        print(f"\n Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(os.path.join(OUTPUT_DIR, "pinned.safetensors"), pinned)
        psz = os.path.getsize(os.path.join(OUTPUT_DIR, "pinned.safetensors")) / 1e9
        print(f" Pinned: {psz:.2f} GB")
    else:
        psz = 0
    shutil.rmtree("/tmp/sniper_dl_35b", ignore_errors=True)
    # Final summary and completeness check.
    bin_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "bin", "moe_layer_*.bin")))
    total = sum(os.path.getsize(f) for f in bin_files)
    missing = set(range(num_layers)) - layers_done
    print(f"\n Expert layers: {len(bin_files)}/{num_layers}")
    print(f" Expert total: {total/1e9:.2f} GB")
    print(f" Pinned: {psz:.2f} GB")
    if missing:
        print(f" WARNING: Missing layers: {sorted(missing)}")
    else:
        print(f"\n All {num_layers} layers converted!")
        print(f" Test: mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")
# Script entry point: run the streaming conversion when executed directly.
if __name__ == "__main__":
    main()