# mlx-expert-sniper / split_gemma4.py
# Author: waltgrace
# Initial release: deploy code + split scripts (commit 0e41b61, verified)
#!/usr/bin/env python3
"""
Split Gemma 4 SwitchLinear stacked experts into per-expert bin files.
Gemma 4 stores experts as (128, out, in) stacked tensors.
This script unstacks them into the layer_XX.bin format that expert_io.py reads.
"""
import os, json, gc, time, glob, argparse
import numpy as np
import mlx.core as mx
PAGE_SIZE = 16384
def _dtype_info(arr):
    """Return (dtype_name, element_size_bytes) for an mx array.

    Unknown dtypes fall back to a 2-byte element size — this mirrors the
    original heuristic; NOTE(review): verify if new dtypes ever appear.
    """
    if arr.dtype == mx.uint32:
        return "uint32", 4
    if arr.dtype == mx.bfloat16:
        return "bfloat16", 2
    if arr.dtype == mx.float16:
        return "float16", 2
    if arr.dtype == mx.float32:
        return "float32", 4
    return str(arr.dtype).replace("mlx.core.", ""), 2


def _expert_to_bytes(single):
    """Serialize one expert's tensor to raw bytes.

    bfloat16 has no numpy equivalent, so it is reinterpreted as uint16
    before conversion; the other dtypes round-trip through numpy views.
    """
    mx.eval(single)
    if single.dtype == mx.uint32:
        return np.array(single).view(np.uint32).tobytes()
    if single.dtype == mx.bfloat16:
        return np.array(single.view(mx.uint16)).tobytes()
    if single.dtype == mx.float32:
        return np.array(single).view(np.float32).tobytes()
    if single.dtype == mx.float16:
        return np.array(single).view(np.uint16).tobytes()
    return np.array(single).tobytes()


def main():
    """Split Gemma 4 stacked SwitchLinear experts into per-expert bin files.

    Reads the safetensors shards from --input, separates expert-stacked
    tensors (leading axis = expert count) from everything else ("pinned"),
    then writes:
      * bin/layer_XX.bin  — a PAGE_SIZE JSON header + one fixed-size block
        per expert, in the layout expert_io.py expects
      * pinned.safetensors — all non-expert weights
      * config.json + tokenizer files for the streaming runtime
    """
    parser = argparse.ArgumentParser(description="Split Gemma 4 for Expert Sniper")
    parser.add_argument("--input", "-i", default="~/models/gemma4-26b-4bit")
    parser.add_argument("--output", "-o", default="~/models/gemma4-stream")
    args = parser.parse_args()

    INPUT_DIR = os.path.expanduser(args.input)
    OUTPUT_DIR = os.path.expanduser(args.output)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(f"{OUTPUT_DIR}/bin", exist_ok=True)

    # Use `with` so the config file handle is closed (the original leaked it).
    with open(f"{INPUT_DIR}/config.json") as cf:
        config = json.load(cf)
    tc = config.get("text_config", config)
    NUM_LAYERS = tc["num_hidden_layers"]
    NUM_EXPERTS = tc["num_experts"]

    print("Gemma 4 Split (SwitchLinear unstack)")
    print(f"  Input: {INPUT_DIR}")
    print(f"  Output: {OUTPUT_DIR}")
    print(f"  Layers: {NUM_LAYERS}, Experts: {NUM_EXPERTS}")
    print()

    # Load all weights from every shard.
    print("Loading safetensors...")
    t0 = time.time()
    all_weights = {}
    for sf in sorted(glob.glob(f"{INPUT_DIR}/model-*.safetensors")):
        print(f"  {os.path.basename(sf)}")
        all_weights.update(mx.load(sf))

    EXPERT_PREFIX = "language_model.model.layers.{}.experts.switch_glu.{}.{}"
    PROJ_NAMES = ["gate_proj", "up_proj", "down_proj"]
    COMP_NAMES = ["weight", "scales", "biases"]

    # Precompute expected expert key -> (layer_idx, expert_io tensor name)
    # once, instead of the original O(keys * layers * 9) nested scan.
    expert_key_map = {
        EXPERT_PREFIX.format(li, proj, comp): (li, f"switch_mlp.{proj}.{comp}")
        for li in range(NUM_LAYERS)
        for proj in PROJ_NAMES
        for comp in COMP_NAMES
    }

    # Partition weights into expert tensors (per layer) and pinned keys.
    pinned = {}
    expert_tensors = {}  # layer_idx -> {tensor_name: stacked (num_experts, ...)}
    for key, val in all_weights.items():
        hit = expert_key_map.get(key)
        if hit is None:
            pinned[key] = val
        else:
            li, tensor_name = hit
            expert_tensors.setdefault(li, {})[tensor_name] = val
    print(f"\n  Expert layers: {len(expert_tensors)}")
    print(f"  Pinned keys: {len(pinned)}")

    # Determine the per-expert block layout from the first layer; sorted key
    # order fixes the packing order and must match the write loop below.
    first_layer = expert_tensors[0]
    tensor_layout = {}
    inner_offset = 0
    for tname in sorted(first_layer):
        arr = first_layer[tname]
        # Leading axis is the stacked expert count; per-expert shape follows.
        per_expert_shape = list(arr.shape[1:])
        dtype_str, elem_size = _dtype_info(arr)
        nbytes = elem_size
        for d in per_expert_shape:
            nbytes *= d
        tensor_layout[tname] = {
            "inner_offset": inner_offset,
            "nbytes": nbytes,
            "shape_per_expert": per_expert_shape,
            "dtype": dtype_str,
        }
        inner_offset += nbytes
    expert_block_size = inner_offset
    data_start = PAGE_SIZE
    print(f"  Expert block: {expert_block_size} bytes ({expert_block_size/1024:.1f} KB)")
    print()

    # Write one bin file per layer: page-aligned JSON header, then
    # NUM_EXPERTS fixed-size expert blocks.
    total_expert_bytes = 0
    for layer_idx in range(NUM_LAYERS):
        lt = time.time()
        layer_data = expert_tensors[layer_idx]
        header = {
            "format": "expert_sniper_v1",
            "model": "gemma4-26b-a4b",
            "layer_idx": layer_idx,
            "num_experts": NUM_EXPERTS,
            "layout": {
                "expert_block_size": expert_block_size,
                "data_start": data_start,
                "tensors": tensor_layout,
            }
        }
        header_bytes = json.dumps(header, indent=2).encode("utf-8")
        assert len(header_bytes) < PAGE_SIZE
        header_padded = header_bytes + b"\x00" * (PAGE_SIZE - len(header_bytes))
        layer_path = f"{OUTPUT_DIR}/bin/layer_{layer_idx:02d}.bin"
        with open(layer_path, "wb") as f:
            f.write(header_padded)
            for eid in range(NUM_EXPERTS):
                expert_data = bytearray()
                # Same sorted order as the layout pass above.
                for tname in sorted(tensor_layout):
                    expert_data.extend(_expert_to_bytes(layer_data[tname][eid]))
                # Pad (or truncate) to the exact fixed block size.
                if len(expert_data) < expert_block_size:
                    expert_data.extend(b"\x00" * (expert_block_size - len(expert_data)))
                f.write(bytes(expert_data[:expert_block_size]))
        file_size = os.path.getsize(layer_path)
        total_expert_bytes += file_size
        elapsed = time.time() - lt
        print(f"  Layer {layer_idx:2d}/{NUM_LAYERS}: {file_size/1e6:.1f} MB ({elapsed:.0f}s)")
        # Free this layer's expert data before moving on to cap peak memory.
        del expert_tensors[layer_idx]
        gc.collect()

    # Save all non-expert weights in a single safetensors file.
    pinned_path = f"{OUTPUT_DIR}/pinned.safetensors"
    mx.save_safetensors(pinned_path, pinned)
    pinned_bytes = sum(v.nbytes for v in pinned.values())
    print(f"\nSaved pinned.safetensors: {pinned_bytes/1e9:.2f} GB ({len(pinned)} keys)")

    # Emit a config for the streaming runtime.
    stream_config = dict(tc)
    stream_config["quantization"] = config.get("quantization", {"bits": 4, "group_size": 64})
    stream_config["streaming"] = {"pinned_file": "pinned.safetensors", "expert_dir": "bin"}
    with open(f"{OUTPUT_DIR}/config.json", "w") as f:
        json.dump(stream_config, f, indent=2)

    # Copy tokenizer/processor files if present (best-effort).
    import shutil
    for tf in ["tokenizer.json", "tokenizer_config.json", "chat_template.jinja",
               "generation_config.json", "processor_config.json"]:
        src = f"{INPUT_DIR}/{tf}"
        if os.path.exists(src):
            shutil.copy(src, f"{OUTPUT_DIR}/{tf}")

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s!")
    print(f"Pinned: {pinned_bytes/1e9:.2f} GB, Experts: {total_expert_bytes/1e9:.2f} GB")
# Entry-point guard: run the split only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()