# context-1-MLX-MXFP4 / convert_context1_to_mlx.py
# Uploaded by foadmk -- chromadb/context-1 converted to MLX MXFP4
# (commit a1ca0b7, verified)
#!/usr/bin/env python3
"""
Convert chromadb/context-1 to MLX MXFP4 format.
chromadb/context-1 uses a different weight format than openai/gpt-oss-20b:
- Dense BF16 tensors (not quantized blocks)
- gate_up_proj shape: (experts, hidden, intermediate*2) - interleaved
This script handles the correct weight transformation for MLX.
"""
import os
import shutil
import subprocess
import sys

from huggingface_hub import snapshot_download
from safetensors import safe_open
from safetensors.torch import save_file
def _transform_tensor(key: str, tensor) -> dict:
    """Map one chromadb/context-1 tensor to its MLX-ready equivalent(s).

    context-1 stores each expert MLP as a fused, interleaved gate/up
    projection (even indices = gate, odd indices = up along the fused
    dimension), while MLX expects separate gate_proj / up_proj tensors with
    the weight matrix transposed per expert.  Returns a dict of
    ``{new_key: tensor}``; keys that need no rewrite map to themselves.
    """
    if "mlp.experts.gate_up_proj_bias" in key:
        # Fused bias: even columns are gate, odd columns are up.
        return {
            key.replace("gate_up_proj_bias", "gate_proj.bias"): tensor[:, ::2].contiguous(),
            key.replace("gate_up_proj_bias", "up_proj.bias"): tensor[:, 1::2].contiguous(),
        }
    if "mlp.experts.gate_up_proj" in key and "bias" not in key:
        # (experts, hidden, intermediate*2) -> (experts, intermediate*2, hidden),
        # then de-interleave the middle dimension into gate and up halves.
        t = tensor.transpose(1, 2)
        return {
            key.replace("gate_up_proj", "gate_proj.weight"): t[:, ::2, :].contiguous(),
            key.replace("gate_up_proj", "up_proj.weight"): t[:, 1::2, :].contiguous(),
        }
    if "mlp.experts.down_proj_bias" in key:
        return {key.replace("down_proj_bias", "down_proj.bias"): tensor}
    if "mlp.experts.down_proj" in key and "bias" not in key:
        # down_proj is stored transposed relative to what MLX expects;
        # swap the two per-expert matrix dimensions.
        return {key.replace("down_proj", "down_proj.weight"): tensor.transpose(1, 2).contiguous()}
    # Everything else (attention, norms, embeddings, ...) passes through as-is.
    return {key: tensor}


def convert_context1_to_mlx(output_dir: str, quantize: bool = True, q_bits: int = 4):
    """Convert chromadb/context-1 to MLX format.

    Downloads the Hugging Face checkpoint, rewrites the fused/interleaved
    expert tensors into the split layout MLX expects, copies the tokenizer
    and config files alongside, then shells out to ``mlx_lm convert``.

    Args:
        output_dir: Directory the final MLX model is written to.
        quantize: If True, quantize during the MLX conversion step.
        q_bits: Quantization bit width passed to mlx_lm (default: 4).

    Raises:
        subprocess.CalledProcessError: If the mlx_lm conversion step fails.
    """
    temp_dir = "/tmp/context1-mlx-converted"
    os.makedirs(temp_dir, exist_ok=True)

    print("=== Step 1: Download chromadb/context-1 ===")
    model_path = snapshot_download("chromadb/context-1")
    # NOTE(review): assumes the checkpoint is a single, unsharded
    # model.safetensors file -- verify against the repo layout.
    sf_path = os.path.join(model_path, "model.safetensors")

    print("\n=== Step 2: Transform weights ===")
    new_weights = {}
    with safe_open(sf_path, framework="pt") as f:
        for key in f.keys():
            new_weights.update(_transform_tensor(key, f.get_tensor(key)))
    save_file(new_weights, os.path.join(temp_dir, "model.safetensors"))

    # Carry over the config/tokenizer files mlx_lm needs next to the weights.
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json", "generation_config.json"]:
        src = os.path.join(model_path, fname)
        if os.path.exists(src):
            shutil.copy(src, temp_dir)

    print("\n=== Step 3: Convert to MLX ===")
    # sys.executable guarantees mlx_lm runs under the same interpreter as this
    # script; a bare "python" may resolve to a different environment.
    cmd = [sys.executable, "-m", "mlx_lm", "convert", "--hf-path", temp_dir, "--mlx-path", output_dir]
    if quantize:
        cmd.extend(["-q", "--q-bits", str(q_bits)])
    subprocess.run(cmd, check=True)
    print(f"\n=== Done! Model saved to {output_dir} ===")
if __name__ == "__main__":
    import argparse

    # Thin CLI wrapper around convert_context1_to_mlx().
    cli = argparse.ArgumentParser(description="Convert chromadb/context-1 to MLX")
    cli.add_argument("--output", "-o", default="./context1-mlx-mxfp4", help="Output directory")
    cli.add_argument("--no-quantize", action="store_true", help="Skip quantization (save as FP16)")
    cli.add_argument("--q-bits", type=int, default=4, help="Quantization bits (default: 4)")
    opts = cli.parse_args()

    convert_context1_to_mlx(opts.output, quantize=not opts.no_quantize, q_bits=opts.q_bits)