#!/usr/bin/env python3
"""Stream-preprocess Qwen3.5-35B-A3B-4bit: download one shard, process, delete.

Each MoE layer's expert tensors are repacked into one binary file:
a PAGE_SIZE JSON header page followed by one contiguous block per expert.
Non-expert ("pinned") tensors are accumulated and saved as pinned.safetensors.
Shards are downloaded, converted, and deleted one at a time so peak disk
usage stays near a single shard's size.
"""
import os, sys, json, time, gc, shutil, glob

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

import numpy as np
import mlx.core as mx

REPO = "mlx-community/Qwen3.5-35B-A3B-4bit"
OUTPUT_DIR = os.path.expanduser("~/models/qwen35-35b")
# Size of the JSON header page at the start of each .bin file; expert data
# begins at exactly this offset.
PAGE_SIZE = 16384
# Order matters: it fixes each tensor's inner_offset within an expert block,
# and the write loop below emits tensors in this same order.
TENSOR_NAMES = [
    "gate_proj.weight", "gate_proj.scales", "gate_proj.biases",
    "up_proj.weight", "up_proj.scales", "up_proj.biases",
    "down_proj.weight", "down_proj.scales", "down_proj.biases",
]


def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Write one MoE layer's expert tensors to <output_dir>/bin/moe_layer_NN.bin.

    File layout: a PAGE_SIZE JSON header (null-padded) describing the
    per-expert tensor offsets, then `num_experts` contiguous expert blocks.

    Args:
        layer_data: dict mapping names from TENSOR_NAMES to stacked tensors
            whose leading axis is the expert index. Missing names are skipped
            (non-quantized layers carry only the .weight entries).
        layer_idx: layer number; used in the filename and header.
        num_experts: number of experts (size of each tensor's leading axis).
        output_dir: root output directory; the file goes in its "bin" subdir.

    Returns:
        Size in bytes of the written file.

    Raises:
        ValueError: if the JSON header does not fit in one PAGE_SIZE page.
    """
    # First pass: compute the per-expert layout (offset and size of each tensor).
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        if name not in layer_data:
            continue
        t = layer_data[name]
        per_expert_shape = list(t.shape[1:])
        if t.dtype == mx.uint32:
            elem_size = 4
        elif t.dtype in (mx.bfloat16, mx.float16):
            # bfloat16 is narrowed to float16 on write, so 2 bytes/elem on disk.
            elem_size = 2
        else:
            # NOTE(review): assumes any remaining dtype is 4 bytes
            # (float32/int32) — would be wrong for e.g. float64; confirm.
            elem_size = 4
        nbytes = elem_size
        for s in per_expert_shape:
            nbytes *= s
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            # NOTE(review): a bfloat16 tensor is recorded as "bfloat16" here
            # even though its bytes are written as float16 below — confirm
            # the consumer expects that before changing either side.
            "dtype": str(t.dtype).replace("mlx.core.", ""),
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes

    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        }
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # Explicit check rather than `assert`: an assert is stripped under
    # `python -O`, which would let an oversized header silently corrupt
    # the file by shifting the data region.
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"header for layer {layer_idx} is {len(header_bytes)} bytes; "
            f"must fit in one {PAGE_SIZE}-byte page"
        )
    header_bytes += b"\x00" * (PAGE_SIZE - len(header_bytes))

    out_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
    with open(out_path, "wb") as f:
        f.write(header_bytes)
        # Second pass: one contiguous block per expert, tensors in
        # TENSOR_NAMES order so bytes land at the inner_offsets above.
        for expert_id in range(num_experts):
            for name in TENSOR_NAMES:
                if name not in layer_data:
                    continue
                t = layer_data[name][expert_id]
                if t.dtype == mx.bfloat16:
                    # numpy has no bfloat16; narrow via mlx, then serialize.
                    raw = np.array(t.astype(mx.float16)).astype(np.float16).tobytes()
                elif t.dtype == mx.uint32:
                    raw = np.array(t).astype(np.uint32).tobytes()
                else:
                    raw = np.array(t).tobytes()
                f.write(raw)
    return os.path.getsize(out_path)


def main():
    """Download shards one at a time, convert MoE layers, delete each shard."""
    from huggingface_hub import hf_hub_download

    print("=" * 55)
    print(" Stream Preprocess Qwen3.5-35B-A3B-4bit")
    print(f" Output: {OUTPUT_DIR}")
    print("=" * 55)
    os.makedirs(os.path.join(OUTPUT_DIR, "bin"), exist_ok=True)

    # Download config + tokenizer (best-effort; a missing file is reported,
    # but a missing config.json will fail fast at the open() just below).
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir="/tmp/sniper_dl_35b")
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f" Downloaded {fname}")
        except Exception as e:
            print(f" Skipped {fname}: {e}")

    with open(os.path.join(OUTPUT_DIR, "config.json")) as f:
        config = json.load(f)
    # Multimodal checkpoints nest the text model under "text_config".
    text_cfg = config.get("text_config", config)
    num_layers = text_cfg.get("num_hidden_layers", 40)
    print(f" Layers: {num_layers}, Experts: {text_cfg.get('num_experts', 0)}")

    idx_path = hf_hub_download(REPO, "model.safetensors.index.json", local_dir="/tmp/sniper_dl_35b")
    with open(idx_path) as f:
        idx = json.load(f)
    shards = sorted(set(idx["weight_map"].values()))
    print(f" {len(shards)} shards")

    # Resume support: layers with an existing .bin are skipped entirely.
    existing = set()
    for entry in os.listdir(os.path.join(OUTPUT_DIR, "bin")):
        if entry.startswith("moe_layer_") and entry.endswith(".bin"):
            existing.add(int(entry.split("_")[2].split(".")[0]))
    if existing:
        print(f" Already done: {sorted(existing)}")

    pinned = {}            # non-expert tensors, saved once at the end
    layers_done = set(existing)
    partial_layers = {}    # layers whose tensors span multiple shards

    for si, shard_name in enumerate(shards):
        print(f"\n [{si+1}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()
        shard_path = hf_hub_download(REPO, shard_name, local_dir="/tmp/sniper_dl_35b")
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f" {shard_size:.1f} GB in {dl_time:.0f}s")

        data = mx.load(shard_path)
        print(f" {len(data)} tensors")

        # Bucket expert tensors by layer index; everything else is pinned.
        layer_experts = {}
        for key, tensor in data.items():
            # Skip vision tower
            if "vision_tower" in key or "model.visual" in key:
                continue
            if "switch_mlp" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                short = key.split(".switch_mlp.")[1]
                layer_experts.setdefault(layer, {})[short] = tensor
            elif "experts.gate_up_proj" in key and ".layers." in key:
                # Fused gate+up — split in halves along the second-to-last axis.
                layer = int(key.split(".layers.")[1].split(".")[0])
                gate_up = tensor
                mid = gate_up.shape[-2] // 2
                ld = layer_experts.setdefault(layer, {})
                ld["gate_proj.weight"] = gate_up[..., :mid, :]
                ld["up_proj.weight"] = gate_up[..., mid:, :]
            elif "experts.down_proj" in key and ".layers." in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                layer_experts.setdefault(layer, {})["down_proj.weight"] = tensor
            else:
                pinned[key] = tensor

        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue
            # Merge with tensors already seen for this layer in earlier shards.
            if layer_idx in partial_layers:
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers[layer_idx]
            # Check how many tensor groups we have
            # For quantized: need weight + scales + biases for each of gate/up/down = 9
            # For non-quantized: just weight for gate/up/down = 3
            n_keys = len(tensors)
            has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
            if not has_all:
                partial_layers[layer_idx] = tensors
                print(f" Layer {layer_idx}: partial ({n_keys} tensors)")
                continue
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            if layer_idx in partial_layers:
                del partial_layers[layer_idx]
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB ({num_experts} experts)")

        # Release shard memory before the next multi-GB download.
        del data, layer_experts
        gc.collect()
        mx.clear_cache()
        try:
            os.remove(shard_path)
            print(f" Deleted shard ({shard_size:.1f} GB freed)")
        except OSError:
            # Best-effort cleanup: a leftover shard only costs disk space,
            # and /tmp/sniper_dl_35b is removed wholesale below anyway.
            pass

    # Handle remaining partials
    for layer_idx, tensors in partial_layers.items():
        if layer_idx in layers_done:
            continue
        has_all = all(f"{p}.weight" in tensors for p in ["gate_proj", "up_proj", "down_proj"])
        if has_all:
            num_experts = tensors["gate_proj.weight"].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB (merged)")

    # Save pinned
    if pinned:
        print(f"\n Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(os.path.join(OUTPUT_DIR, "pinned.safetensors"), pinned)
        psz = os.path.getsize(os.path.join(OUTPUT_DIR, "pinned.safetensors")) / 1e9
        print(f" Pinned: {psz:.2f} GB")
    else:
        psz = 0

    shutil.rmtree("/tmp/sniper_dl_35b", ignore_errors=True)

    # Final summary and completeness check against the configured layer count.
    bin_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "bin", "moe_layer_*.bin")))
    total = sum(os.path.getsize(f) for f in bin_files)
    missing = set(range(num_layers)) - layers_done
    print(f"\n Expert layers: {len(bin_files)}/{num_layers}")
    print(f" Expert total: {total/1e9:.2f} GB")
    print(f" Pinned: {psz:.2f} GB")
    if missing:
        print(f" WARNING: Missing layers: {sorted(missing)}")
    else:
        print(f"\n All {num_layers} layers converted!")
        print(f" Test: mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")


if __name__ == "__main__":
    main()