| |
| """ |
| Stream-preprocess: download one shard at a time, process it, delete it. |
| Avoids needing 17 GB free for the full model download. |
| |
| Usage: |
| python3 stream_preprocess.py |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import gc |
| import shutil |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) |
|
|
| import numpy as np |
| import mlx.core as mx |
|
|
# Hugging Face repository that holds the checkpoint to convert.
REPO = "mlx-community/Qwen3-30B-A3B-4bit"

# Directory that receives the converted files ("~" is expanded).
OUTPUT_DIR = os.path.expanduser("~/models/qwen3-30b")

# Bytes reserved at the start of every layer file for the padded JSON header;
# expert data begins at this offset.
PAGE_SIZE = 16 * 1024

# Expert tensors of one layer, in the order they are laid out inside each
# expert block: every projection contributes its weight, scales and biases.
TENSOR_NAMES = [
    f"{projection}.{component}"
    for projection in ("gate_proj", "up_proj", "down_proj")
    for component in ("weight", "scales", "biases")
]
|
|
|
|
def _expert_tensor_bytes(t):
    """Serialize one expert's slice of a tensor to raw bytes."""
    if t.dtype == mx.bfloat16:
        # NumPy has no native bfloat16, so the values go through float16.
        # NOTE(review): the header still records the dtype as "bfloat16"
        # while the payload is float16 -- confirm the reader expects this.
        return np.array(t.astype(mx.float16)).astype(np.float16).tobytes()
    if t.dtype == mx.uint32:
        return np.array(t).astype(np.uint32).tobytes()
    return np.array(t).tobytes()


def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Write the expert tensors of one layer to ``bin/moe_layer_XX.bin``.

    File layout:
      * a JSON header, NUL-padded to exactly ``PAGE_SIZE`` bytes, that
        records the per-expert block size and, for every tensor, its
        per-expert shape, dtype string, byte count and offset in the block;
      * ``num_experts`` consecutive expert blocks, each containing that
        expert's slice of every tensor in ``TENSOR_NAMES`` order.

    The file is written to a temporary ``.tmp`` sibling and renamed into
    place only once complete, so an interrupted run never leaves a truncated
    ``moe_layer_XX.bin`` behind.

    Args:
        layer_data: Mapping from each name in ``TENSOR_NAMES`` to an MLX
            array whose first axis is indexed by expert id.
        layer_idx: Layer number; used in the header and the file name.
        num_experts: Number of expert blocks to write.
        output_dir: Directory containing (or to contain) the ``bin`` folder.

    Returns:
        Size in bytes of the written file.

    Raises:
        KeyError: If ``layer_data`` lacks one of ``TENSOR_NAMES``.
        ValueError: If the JSON header does not fit in ``PAGE_SIZE`` bytes,
            or a serialized tensor does not have the size announced in the
            header.
    """
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        t = layer_data[name]
        per_expert_shape = list(t.shape[1:])
        # The array's own item size is correct for every dtype (the
        # bfloat16 -> float16 conversion at write time keeps 2 bytes/elem).
        nbytes = t.itemsize
        for dim in per_expert_shape:
            nbytes *= dim
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            "dtype": str(t.dtype).replace("mlx.core.", ""),
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes

    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        },
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"layer {layer_idx}: header is {len(header_bytes)} bytes, "
            f"which does not fit in the {PAGE_SIZE}-byte header page"
        )
    header_bytes += b"\x00" * (PAGE_SIZE - len(header_bytes))

    bin_dir = os.path.join(output_dir, "bin")
    os.makedirs(bin_dir, exist_ok=True)
    out_path = os.path.join(bin_dir, f"moe_layer_{layer_idx:02d}.bin")
    # The ".tmp" suffix keeps an unfinished file from matching "*.bin".
    tmp_path = out_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(header_bytes)
            for expert_id in range(num_experts):
                for name in TENSOR_NAMES:
                    raw = _expert_tensor_bytes(layer_data[name][expert_id])
                    expected = tensor_info[name]["nbytes"]
                    if len(raw) != expected:
                        raise ValueError(
                            f"layer {layer_idx}, expert {expert_id}, {name}: "
                            f"serialized {len(raw)} bytes but the header "
                            f"announces {expected}"
                        )
                    f.write(raw)
        os.replace(tmp_path, out_path)
    finally:
        # After a successful rename the temp file is gone; after a failure,
        # remove the partial file. Best effort: a cleanup error must not
        # mask the original exception.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return os.path.getsize(out_path)
|
|
|
|
def _completed_layer_indices(bin_dir):
    """Return the layer indices that already have a ``moe_layer_<N>.bin`` in *bin_dir*."""
    prefix, suffix = "moe_layer_", ".bin"
    done = set()
    for entry in os.listdir(bin_dir):
        if entry.startswith(prefix) and entry.endswith(suffix):
            try:
                done.add(int(entry[len(prefix):-len(suffix)]))
            except ValueError:
                # Not one of our files (e.g. "moe_layer_backup.bin"); ignore.
                continue
    return done


def main():
    """Convert the model in ``REPO`` shard by shard into ``OUTPUT_DIR``.

    Steps:
      1. Download the tokenizer/config files (best effort, except that
         ``config.json`` must end up in ``OUTPUT_DIR``).
      2. Read the safetensors index to get the list of shards.
      3. For each shard: download it, split its tensors into per-layer
         expert tensors (keys containing ``switch_mlp``) and everything else
         ("pinned"), write every layer whose ``TENSOR_NAMES`` are all
         present with ``convert_layer_to_bin``, then delete the shard.
         Layers split across shards are accumulated until complete.
      4. Save the pinned tensors to ``pinned.safetensors`` and print a
         summary, including any layers that are still missing.

    Layers that already have a ``.bin`` file in ``OUTPUT_DIR/bin`` are not
    converted again, but every shard is still downloaded because the pinned
    tensors are collected from all of them.

    Raises:
        FileNotFoundError: If ``config.json`` is not available in
            ``OUTPUT_DIR`` after the download step.
    """
    import glob

    from huggingface_hub import hf_hub_download

    download_dir = "/tmp/sniper_dl"
    bin_dir = os.path.join(OUTPUT_DIR, "bin")

    print("=" * 55)
    print(" Stream Preprocess — one shard at a time")
    print(f" Model: {REPO}")
    print(f" Output: {OUTPUT_DIR}")
    print("=" * 55)

    os.makedirs(bin_dir, exist_ok=True)

    # Small metadata files. Best effort: not every repository ships all of
    # them, so a failure is reported and skipped.
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir=download_dir)
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f" Downloaded {fname}")
        except Exception as e:
            print(f" Skipped {fname}: {e}")

    # config.json is the one file we cannot continue without.
    config_path = os.path.join(OUTPUT_DIR, "config.json")
    if not os.path.exists(config_path):
        raise FileNotFoundError(
            f"{config_path} is missing: config.json could not be "
            f"downloaded from {REPO}"
        )
    with open(config_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    num_layers = config.get("num_hidden_layers", 48)

    # The index maps every tensor name to the shard file that contains it.
    idx_path = hf_hub_download(REPO, "model.safetensors.index.json",
                               local_dir=download_dir)
    with open(idx_path, encoding="utf-8") as index_file:
        idx = json.load(index_file)
    shards = sorted(set(idx["weight_map"].values()))
    print(f"\n {len(shards)} shards to process")

    # Resume support: layers that already have an output file are skipped.
    existing = _completed_layer_indices(bin_dir)
    if existing:
        print(f" Already done: layers {sorted(existing)}")

    pinned = {}                   # non-expert tensors, saved together at the end
    layers_done = set(existing)
    partial_layers = {}           # layer index -> expert tensors seen so far

    for si, shard_name in enumerate(shards, start=1):
        print(f"\n [{si}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()

        shard_path = hf_hub_download(REPO, shard_name, local_dir=download_dir)
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f" Downloaded {shard_size:.1f} GB in {dl_time:.0f}s")

        print(" Loading tensors...")
        data = mx.load(shard_path)
        print(f" {len(data)} tensors")

        # Split the shard into per-layer expert tensors and pinned tensors.
        layer_experts = {}
        new_pinned = []
        for key, tensor in data.items():
            if "switch_mlp" in key:
                layer = int(key.split(".layers.")[1].split(".")[0])
                short = key.split(".switch_mlp.")[1]
                layer_experts.setdefault(layer, {})[short] = tensor
            else:
                pinned[key] = tensor
                new_pinned.append(tensor)

        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue

            # Merge with tensors of the same layer found in earlier shards.
            if layer_idx in partial_layers:
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers[layer_idx]

            present = sum(name in tensors for name in TENSOR_NAMES)
            if present < len(TENSOR_NAMES):
                # The rest of this layer lives in another shard.
                partial_layers[layer_idx] = tensors
                print(f" Layer {layer_idx}: partial "
                      f"({present}/{len(TENSOR_NAMES)} tensors)")
                continue

            num_experts = tensors[TENSOR_NAMES[0]].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            partial_layers.pop(layer_idx, None)
            print(f" Layer {layer_idx}: {sz/1e6:.0f} MB")

        # Arrays returned by mx.load may still be lazy (unevaluated)
        # references into the shard file. Materialize everything that
        # outlives this iteration before the file is deleted.
        keep_alive = list(new_pinned)
        for tensors in partial_layers.values():
            keep_alive.extend(tensors.values())
        if keep_alive:
            mx.eval(keep_alive)

        del data, layer_experts, new_pinned, keep_alive
        gc.collect()
        mx.clear_cache()

        # Free the disk space before fetching the next shard.
        try:
            os.remove(shard_path)
        except OSError as e:
            print(f" Could not delete shard {shard_name}: {e}")
        else:
            print(f" Deleted shard ({shard_size:.1f} GB freed)")

    # Complete layers are always converted inside the loop above, so anything
    # left here never received all of its tensors from any shard.
    if partial_layers:
        print(f"\n WARNING: {len(partial_layers)} layers are still incomplete "
              f"after all shards:")
        for layer_idx in sorted(partial_layers):
            absent = [name for name in TENSOR_NAMES
                      if name not in partial_layers[layer_idx]]
            print(f" Layer {layer_idx}: missing {absent}")

    pinned_path = os.path.join(OUTPUT_DIR, "pinned.safetensors")
    if pinned:
        print(f"\n Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(pinned_path, pinned)
        psz = os.path.getsize(pinned_path) / 1e9
        print(f" Pinned: {psz:.2f} GB")
    else:
        psz = os.path.getsize(pinned_path) / 1e9 if os.path.exists(pinned_path) else 0

    # Remove the download scratch directory.
    shutil.rmtree(download_dir, ignore_errors=True)

    # Summary.
    bin_files = sorted(glob.glob(os.path.join(bin_dir, "moe_layer_*.bin")))
    total = sum(os.path.getsize(path) for path in bin_files)
    print(f"\n Expert layers: {len(bin_files)}/{num_layers}")
    print(f" Expert total: {total/1e9:.2f} GB")
    print(f" Pinned: {psz:.2f} GB")
    print(f" Total: {(total/1e9 + psz):.2f} GB")

    missing = set(range(num_layers)) - layers_done
    if missing:
        print(f"\n WARNING: Missing layers: {sorted(missing)}")
    else:
        print(f"\n All {num_layers} layers converted!")
        print("\n Test with:")
        print(f" mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")
|
|
|
|
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|