File size: 8,174 Bytes
2d8a9bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#!/usr/bin/env python3
"""
Stream-preprocess: download one shard at a time, process it, delete it.
Avoids needing 17 GB free for the full model download.

Usage:
    python3 stream_preprocess.py
"""

import os
import sys
import json
import time
import gc
import shutil

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

import numpy as np
import mlx.core as mx

# Hugging Face repository the shards are streamed from.
REPO = "mlx-community/Qwen3-30B-A3B-4bit"
# Destination for config/tokenizer files, bin/moe_layer_XX.bin and pinned.safetensors.
OUTPUT_DIR = os.path.expanduser("~/models/qwen3-30b")
# Size in bytes of the zero-padded JSON header at the start of every .bin file;
# expert data begins at this offset.
PAGE_SIZE = 16384

# Per-expert tensors stored in each layer file, in on-disk order: for every
# projection, its weight followed by its scales and biases (presumably the
# quantization parameters of the 4-bit checkpoint -- confirm against the repo).
TENSOR_NAMES = [
    f"{proj}.{part}"
    for proj in ("gate_proj", "up_proj", "down_proj")
    for part in ("weight", "scales", "biases")
]


def _expert_tensor_bytes(t):
    """Return the raw C-order bytes of one expert's slice of a tensor.

    numpy has no bfloat16 type, so bfloat16 data is cast to float16 before
    export (both are 2 bytes per element; values beyond the float16 range
    would overflow to inf). Every other dtype is exported unchanged.
    """
    if t.dtype == mx.bfloat16:
        return np.array(t.astype(mx.float16)).tobytes()
    return np.array(t).tobytes()


def convert_layer_to_bin(layer_data, layer_idx, num_experts, output_dir):
    """Write one layer's stacked expert tensors to ``<output_dir>/bin/moe_layer_XX.bin``.

    File layout: a JSON header zero-padded to PAGE_SIZE bytes, followed by
    ``num_experts`` contiguous expert blocks. Each block holds one expert's
    slice of every tensor in TENSOR_NAMES, in that order. The header records,
    per tensor, its per-expert shape, source dtype, on-disk dtype, byte size
    and offset inside an expert block.

    Args:
        layer_data: mapping from every name in TENSOR_NAMES to an mx.array
            whose first axis indexes experts.
        layer_idx: layer number, used in the header and the file name.
        num_experts: number of leading slices written from each tensor.
        output_dir: model output directory; its ``bin`` subdirectory must exist.

    Returns:
        Size of the written file in bytes.

    Raises:
        KeyError: if a name from TENSOR_NAMES is missing from ``layer_data``.
        ValueError: if the JSON header does not fit below PAGE_SIZE bytes.
    """
    tensor_info = {}
    expert_block_size = 0
    for name in TENSOR_NAMES:
        t = layer_data[name]
        per_expert_shape = list(t.shape[1:])
        # Size from the dtype itself so the header always matches what
        # _expert_tensor_bytes writes (bfloat16 -> float16 keeps 2 bytes/elem).
        nbytes = t.dtype.size
        for dim in per_expert_shape:
            nbytes *= dim
        dtype_name = str(t.dtype).replace("mlx.core.", "")
        tensor_info[name] = {
            "shape_per_expert": per_expert_shape,
            "dtype": dtype_name,
            # Element type actually on disk; differs from "dtype" for bfloat16.
            "stored_dtype": "float16" if t.dtype == mx.bfloat16 else dtype_name,
            "nbytes": nbytes,
            "inner_offset": expert_block_size,
        }
        expert_block_size += nbytes

    header = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "layout": {
            "expert_block_size": expert_block_size,
            "data_start": PAGE_SIZE,
            "tensors": tensor_info,
        }
    }
    header_bytes = json.dumps(header, indent=2).encode()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(header_bytes) >= PAGE_SIZE:
        raise ValueError(
            f"layer {layer_idx}: header is {len(header_bytes)} bytes, "
            f"must be smaller than PAGE_SIZE={PAGE_SIZE}"
        )
    header_bytes = header_bytes.ljust(PAGE_SIZE, b"\x00")

    out_path = os.path.join(output_dir, "bin", f"moe_layer_{layer_idx:02d}.bin")
    # Write under a temporary name and rename only once complete: the resume
    # logic treats any existing moe_layer_XX.bin as finished, so an interrupted
    # write must never leave a truncated file under the final name.
    tmp_path = out_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(header_bytes)
            for expert_id in range(num_experts):
                for name in TENSOR_NAMES:
                    f.write(_expert_tensor_bytes(layer_data[name][expert_id]))
        os.replace(tmp_path, out_path)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    return os.path.getsize(out_path)


def _existing_layer_indices(bin_dir):
    """Return the layer indices that already have a ``moe_layer_XX.bin`` in *bin_dir*.

    Files matching the name pattern but with a non-numeric index are ignored
    rather than aborting the run.
    """
    prefix, suffix = "moe_layer_", ".bin"
    found = set()
    for fname in os.listdir(bin_dir):
        if fname.startswith(prefix) and fname.endswith(suffix):
            index_text = fname[len(prefix):-len(suffix)]
            if index_text.isdigit():
                found.add(int(index_text))
    return found


def _missing_tensors(tensors):
    """Return the names from TENSOR_NAMES absent from *tensors*, in TENSOR_NAMES order."""
    return [name for name in TENSOR_NAMES if name not in tensors]


def _split_shard(data):
    """Split a loaded shard into per-layer expert tensors and all other ("pinned") tensors.

    Keys containing ``switch_mlp`` are grouped as ``{layer_idx: {short_name: tensor}}``,
    where ``layer_idx`` is parsed from the ``.layers.<i>.`` segment and
    ``short_name`` is the text after ``.switch_mlp.``. Every other key is
    returned unchanged in the second dict.
    """
    layer_experts = {}
    shard_pinned = {}
    for key, tensor in data.items():
        if "switch_mlp" in key:
            layer = int(key.split(".layers.")[1].split(".")[0])
            short = key.split(".switch_mlp.")[1]
            layer_experts.setdefault(layer, {})[short] = tensor
        else:
            shard_pinned[key] = tensor
    return layer_experts, shard_pinned


def _delete_download(path):
    """Best-effort delete of a downloaded file; return True if it was removed.

    If *path* is a symlink, its target is removed too: older huggingface_hub
    releases could materialize large files in ``local_dir`` as symlinks into
    the shared cache, and removing only the link would free no disk space.
    """
    target = os.path.realpath(path) if os.path.islink(path) else None
    try:
        os.remove(path)
        if target is not None and os.path.exists(target):
            os.remove(target)
    except OSError as err:
        print(f"    Could not delete {path}: {err}")
        return False
    return True


def main():
    """Convert REPO's MoE expert weights into per-layer .bin files under OUTPUT_DIR.

    Shards are downloaded one at a time into a temporary directory. Their
    expert tensors are written with convert_layer_to_bin, and each shard is
    deleted before the next one is fetched. All non-expert tensors are
    collected and saved to ``pinned.safetensors``. Layers whose .bin file
    already exists are skipped, so an interrupted run can be resumed. Every
    shard is still downloaded, because pinned tensors are gathered from all
    of them.
    """
    import glob
    from huggingface_hub import hf_hub_download

    download_dir = "/tmp/sniper_dl"
    bin_dir = os.path.join(OUTPUT_DIR, "bin")

    print("=" * 55)
    print("  Stream Preprocess — one shard at a time")
    print(f"  Model:  {REPO}")
    print(f"  Output: {OUTPUT_DIR}")
    print("=" * 55)

    os.makedirs(bin_dir, exist_ok=True)

    # Download config + tokenizer files (small). Failures are reported but not
    # fatal: a config.json left in OUTPUT_DIR by an earlier run is read below.
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json"]:
        try:
            path = hf_hub_download(REPO, fname, local_dir=download_dir)
            shutil.copy(path, os.path.join(OUTPUT_DIR, fname))
            print(f"  Downloaded {fname}")
        except Exception as e:
            print(f"  Skipped {fname}: {e}")

    with open(os.path.join(OUTPUT_DIR, "config.json")) as f:
        config = json.load(f)
    num_layers = config.get("num_hidden_layers", 48)

    # The index maps every tensor name to the shard file that contains it.
    idx_path = hf_hub_download(REPO, "model.safetensors.index.json",
                               local_dir=download_dir)
    with open(idx_path) as f:
        idx = json.load(f)
    shards = sorted(set(idx["weight_map"].values()))
    print(f"\n  {len(shards)} shards to process")

    layers_done = _existing_layer_indices(bin_dir)
    if layers_done:
        print(f"  Already done: layers {sorted(layers_done)}")

    pinned = {}
    # Expert tensors of layers whose TENSOR_NAMES are split across shards.
    partial_layers = {}

    for si, shard_name in enumerate(shards):
        print(f"\n  [{si+1}/{len(shards)}] Downloading {shard_name}...")
        t0 = time.time()

        shard_path = hf_hub_download(REPO, shard_name, local_dir=download_dir)
        dl_time = time.time() - t0
        shard_size = os.path.getsize(shard_path) / 1e9
        print(f"    Downloaded {shard_size:.1f} GB in {dl_time:.0f}s")

        print("    Loading tensors...")
        data = mx.load(shard_path)
        print(f"    {len(data)} tensors")

        layer_experts, shard_pinned = _split_shard(data)
        pinned.update(shard_pinned)

        # Convert every expert layer that is complete once merged with
        # tensors carried over from earlier shards.
        for layer_idx, tensors in layer_experts.items():
            if layer_idx in layers_done:
                continue
            if layer_idx in partial_layers:
                partial_layers[layer_idx].update(tensors)
                tensors = partial_layers.pop(layer_idx)

            missing = _missing_tensors(tensors)
            if missing:
                partial_layers[layer_idx] = tensors
                have = len(TENSOR_NAMES) - len(missing)
                print(f"    Layer {layer_idx}: partial ({have}/{len(TENSOR_NAMES)} tensors)")
                continue

            num_experts = tensors[TENSOR_NAMES[0]].shape[0]
            sz = convert_layer_to_bin(tensors, layer_idx, num_experts, OUTPUT_DIR)
            layers_done.add(layer_idx)
            print(f"    Layer {layer_idx}: {sz/1e6:.0f} MB")

        # NOTE(review): mx.load may return lazily loaded arrays that read from
        # the shard file on first use. Materialize every tensor that outlives
        # this shard (pinned + still-partial layers) before the file is deleted,
        # so they no longer depend on it and its disk space is really released.
        carried = list(shard_pinned.values())
        for partial in partial_layers.values():
            carried.extend(partial.values())
        if carried:
            mx.eval(carried)

        del data, layer_experts, shard_pinned, carried
        gc.collect()
        mx.clear_cache()

        # Delete the downloaded shard to free disk before the next download.
        if _delete_download(shard_path):
            print(f"    Deleted shard ({shard_size:.1f} GB freed)")

    # A layer still listed here never had all TENSOR_NAMES across all shards;
    # it will be reported as missing in the summary below.
    if partial_layers:
        print(f"\n  {len(partial_layers)} layer(s) incomplete after all shards:")
        for layer_idx in sorted(partial_layers):
            print(f"    Layer {layer_idx}: missing {_missing_tensors(partial_layers[layer_idx])}")

    # Rewrite pinned.safetensors whenever this run collected pinned tensors;
    # otherwise just report the size of any existing file.
    pinned_path = os.path.join(OUTPUT_DIR, "pinned.safetensors")
    if pinned:
        print(f"\n  Saving pinned ({len(pinned)} tensors)...")
        mx.save_safetensors(pinned_path, pinned)
        psz = os.path.getsize(pinned_path) / 1e9
        print(f"    Pinned: {psz:.2f} GB")
    else:
        psz = os.path.getsize(pinned_path) / 1e9 if os.path.exists(pinned_path) else 0

    # Clean up temp downloads
    shutil.rmtree(download_dir, ignore_errors=True)

    # Summary
    bin_files = sorted(glob.glob(os.path.join(bin_dir, "moe_layer_*.bin")))
    total = sum(os.path.getsize(f) for f in bin_files)
    print(f"\n  Expert layers: {len(bin_files)}/{num_layers}")
    print(f"  Expert total:  {total/1e9:.2f} GB")
    print(f"  Pinned:        {psz:.2f} GB")
    print(f"  Total:         {(total/1e9 + psz):.2f} GB")

    missing_layers = set(range(num_layers)) - layers_done
    if missing_layers:
        print(f"\n  WARNING: Missing layers: {sorted(missing_layers)}")
    else:
        print(f"\n  All {num_layers} layers converted!")
        print("\n  Test with:")
        print(f"    mlx-sniper run {OUTPUT_DIR} -p 'What is 2+2?' -v")


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()