"""
Convert chromadb/context-1 to MLX MXFP4 format.

chromadb/context-1 uses a different weight format than openai/gpt-oss-20b:
- Dense BF16 tensors (not quantized blocks)
- gate_up_proj shape: (experts, hidden, intermediate*2) - interleaved

This script handles the correct weight transformation for MLX.
"""
|
|
import os
import shutil
import subprocess
import sys

from huggingface_hub import snapshot_download
from safetensors import safe_open
from safetensors.torch import save_file
|
|
def convert_context1_to_mlx(output_dir: str, quantize: bool = True, q_bits: int = 4):
    """Convert chromadb/context-1 to MLX format.

    Downloads the checkpoint from the Hugging Face Hub, de-interleaves the
    fused ``gate_up_proj`` expert tensors into separate ``gate_proj`` /
    ``up_proj`` tensors, transposes the expert matrices, writes the
    transformed checkpoint to a temp directory, and finally runs
    ``mlx_lm convert`` on it.

    Args:
        output_dir: Directory where the converted MLX model is written.
        quantize: When True, pass ``-q`` to mlx_lm to quantize the weights.
        q_bits: Quantization bit width forwarded to ``--q-bits``.

    Raises:
        subprocess.CalledProcessError: If the ``mlx_lm convert`` step fails.
    """
    temp_dir = "/tmp/context1-mlx-converted"
    os.makedirs(temp_dir, exist_ok=True)

    print("=== Step 1: Download chromadb/context-1 ===")
    model_path = snapshot_download("chromadb/context-1")
    # NOTE(review): assumes a single-shard checkpoint named model.safetensors;
    # a sharded upload would need the index file handled as well — confirm.
    sf_path = os.path.join(model_path, "model.safetensors")

    print("\n=== Step 2: Transform weights ===")
    new_weights = {}
    with safe_open(sf_path, framework="pt") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)

            if "mlp.experts.gate_up_proj_bias" in key:
                # Fused bias is interleaved: even columns -> gate, odd -> up.
                gate_bias = tensor[:, ::2].contiguous()
                up_bias = tensor[:, 1::2].contiguous()
                new_weights[key.replace("gate_up_proj_bias", "gate_proj.bias")] = gate_bias
                new_weights[key.replace("gate_up_proj_bias", "up_proj.bias")] = up_bias

            elif "mlp.experts.gate_up_proj" in key and "bias" not in key:
                # (experts, hidden, intermediate*2) -> (experts, intermediate*2, hidden),
                # then split the interleaved rows into gate and up projections.
                t = tensor.transpose(1, 2)
                gate_weight = t[:, ::2, :].contiguous()
                up_weight = t[:, 1::2, :].contiguous()
                new_weights[key.replace("gate_up_proj", "gate_proj.weight")] = gate_weight
                new_weights[key.replace("gate_up_proj", "up_proj.weight")] = up_weight

            elif "mlp.experts.down_proj_bias" in key:
                new_weights[key.replace("down_proj_bias", "down_proj.bias")] = tensor

            elif "mlp.experts.down_proj" in key and "bias" not in key:
                # Transpose to (experts, out_features, in_features) layout.
                t = tensor.transpose(1, 2).contiguous()
                new_weights[key.replace("down_proj", "down_proj.weight")] = t

            else:
                new_weights[key] = tensor

    save_file(new_weights, os.path.join(temp_dir, "model.safetensors"))

    # Carry over the config/tokenizer files mlx_lm needs alongside the weights.
    for fname in ["config.json", "tokenizer.json", "tokenizer_config.json", "generation_config.json"]:
        src = os.path.join(model_path, fname)
        if os.path.exists(src):
            shutil.copy(src, temp_dir)

    print("\n=== Step 3: Convert to MLX ===")
    # Use the current interpreter instead of whatever "python" resolves to on
    # PATH, so mlx_lm runs in the same environment as this script.
    cmd = [sys.executable, "-m", "mlx_lm", "convert", "--hf-path", temp_dir, "--mlx-path", output_dir]
    if quantize:
        cmd.extend(["-q", "--q-bits", str(q_bits)])

    subprocess.run(cmd, check=True)

    print(f"\n=== Done! Model saved to {output_dir} ===")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: collect options, then run the conversion pipeline.
    cli = argparse.ArgumentParser(description="Convert chromadb/context-1 to MLX")
    cli.add_argument("--output", "-o", default="./context1-mlx-mxfp4", help="Output directory")
    cli.add_argument("--no-quantize", action="store_true", help="Skip quantization (save as FP16)")
    cli.add_argument("--q-bits", type=int, default=4, help="Quantization bits (default: 4)")
    opts = cli.parse_args()

    convert_context1_to_mlx(opts.output, quantize=not opts.no_quantize, q_bits=opts.q_bits)
|
|