| |
| """ |
| Export Mamba2-130M Model Weights for OpenCL |
| |
| This script loads the pretrained mamba2-130m model and exports all weights |
| to binary files that OpenCL can load. |
| |
| Usage: |
| python export_mamba2_130m_weights.py [output_dir] [model_name] |
| |
| output_dir: Directory to save exported weights (default: ./mamba2_130m_weights) |
| model_name: HuggingFace model name (default: cartesia-ai/mamba2-130m-mlx) |
| """ |
|
|
| import sys |
| import os |
| import json |
| import numpy as np |
| import mlx.core as mx |
| from pathlib import Path |
|
|
| |
# Make a sibling checkout of cartesia-mlx importable without pip-installing it.
# NOTE(review): assumes the script is run from a directory next to ../cartesia-mlx
# — a relative path, so the working directory matters; confirm in deployment docs.
sys.path.insert(0, '../cartesia-mlx')
import cartesia_mlx as cmx
|
|
|
|
def export_weights(model, output_dir, model_name="mamba2-130m-mlx"):
    """Export all model weights as raw float32 ``.bin`` files plus metadata.json.

    Args:
        model: Loaded cartesia-mlx model. Must expose ``d_model``, ``n_tokens``,
            ``embedding.encoder.weight``, ``model.layers`` and ``head``; the
            per-layer attributes actually present are probed with ``hasattr``.
        output_dir: Directory to create/populate with exported weights.
        model_name: Name recorded in metadata.json. Defaults to the previously
            hard-coded value so existing callers are unaffected.

    Returns:
        output_dir, for chaining.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Exporting Mamba2-130M model weights to: {output_dir}")
    print(f"Model config: d_model={model.d_model}, vocab={model.n_tokens}")
    print()

    def get_weight(layer, attr_path):
        """Navigate nested attributes to get a weight, handling quantized layers."""
        obj = layer
        for attr in attr_path.split('.'):
            if attr == 'weight' and hasattr(obj, 'weight'):
                if hasattr(obj, 'scales') and hasattr(obj, 'biases'):
                    # Quantized linear layer: dequantize back to float32
                    # so the OpenCL side never has to understand MLX packing.
                    dequantized = mx.dequantize(
                        obj.weight, obj.scales, obj.biases,
                        obj.group_size, obj.bits,
                    )
                    mx.eval(dequantized)
                    return np.array(dequantized, dtype=np.float32)
                return np.array(obj.weight, dtype=np.float32)
            obj = getattr(obj, attr)
        return np.array(obj, dtype=np.float32)

    def proj_weight(proj):
        """Weight of a projection that may wrap its Linear in a `.linear` attribute."""
        if hasattr(proj, 'linear'):
            return get_weight(proj.linear, 'weight')
        return get_weight(proj, 'weight')

    def dump(arr, path, label):
        """Write `arr` as raw float32 bytes and log its shape."""
        arr.tofile(path)
        print(f"      {label}: {arr.shape}")

    # --- Embedding ---
    print("Exporting embedding...")
    emb_weight = np.array(model.embedding.encoder.weight, dtype=np.float32)
    emb_weight.tofile(f"{output_dir}/embedding_weight.bin")
    print(f"  embedding_weight.bin: {emb_weight.shape}")

    # --- Layers ---
    print(f"\nExporting {len(model.model.layers)} layers...")
    for i, layer in enumerate(model.model.layers):
        layer_obj = layer.layer
        layer_type = type(layer_obj).__name__

        # Compute the SSD test once: it decides both the directory suffix
        # and which export branch runs (previously evaluated twice).
        is_ssd = hasattr(layer_obj, 'conv_weight') or hasattr(layer_obj, 'A')
        layer_dir = f"{output_dir}/layer_{i:03d}" + ("_ssd" if is_ssd else "")
        os.makedirs(layer_dir, exist_ok=True)

        print(f"  Layer {i}: {layer_type} -> {os.path.basename(layer_dir)}")

        try:
            if is_ssd:
                # Mamba2 / SSD mixer layer.
                dump(proj_weight(layer_obj.in_proj),
                     f"{layer_dir}/in_proj_weight.bin", "in_proj")

                # Optional per-layer parameters, exported only when present.
                for attr, fname, label in (
                    ('conv_weight', 'conv_weight.bin', 'conv_weight'),
                    ('conv_bias', 'conv_bias.bin', 'conv_bias'),
                    ('A', 'A_log.bin', 'A_log'),
                    ('dt_bias', 'dt_bias.bin', 'dt_bias'),
                    ('D', 'D.bin', 'D'),
                ):
                    if hasattr(layer_obj, attr):
                        arr = np.array(getattr(layer_obj, attr), dtype=np.float32)
                        dump(arr, f"{layer_dir}/{fname}", label)

                dump(proj_weight(layer_obj.out_proj),
                     f"{layer_dir}/out_proj_weight.bin", "out_proj")

                if hasattr(layer_obj, 'rms_norm'):
                    arr = np.array(layer_obj.rms_norm.weight, dtype=np.float32)
                    dump(arr, f"{layer_dir}/rms_norm_weight.bin", "rms_norm (internal)")

            elif (hasattr(layer_obj, 'in_proj') and hasattr(layer_obj, 'gate_proj')
                  and hasattr(layer_obj, 'out_proj')):
                # Gated MLP-style layer.
                dump(proj_weight(layer_obj.gate_proj),
                     f"{layer_dir}/gate_weight.bin", "gate_proj")
                dump(proj_weight(layer_obj.in_proj),
                     f"{layer_dir}/in_proj_weight.bin", "in_proj")
                dump(proj_weight(layer_obj.out_proj),
                     f"{layer_dir}/out_proj_weight.bin", "out_proj")

            elif hasattr(layer_obj, 'qkv_proj') and hasattr(layer_obj, 'out_proj'):
                # Attention-style layer.
                dump(proj_weight(layer_obj.qkv_proj),
                     f"{layer_dir}/qkv_weight.bin", "qkv_proj")
                dump(proj_weight(layer_obj.out_proj),
                     f"{layer_dir}/out_weight.bin", "out_proj")

            # Pre-layer norm, common to all layer types.
            if hasattr(layer, 'norm') and layer.norm is not None:
                arr = np.array(layer.norm.weight, dtype=np.float32)
                dump(arr, f"{layer_dir}/norm_weight.bin", "norm")

        except Exception as e:
            # Best-effort export: report the failure and keep going so one
            # bad layer does not abort the whole dump.
            print(f"    ERROR exporting layer {i}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # --- Final (post) norm ---
    if hasattr(model.model, 'post_norm') and model.model.post_norm:
        if hasattr(model.model, 'norm') and model.model.norm is not None:
            print("\nExporting post-norm weights...")
            post_norm_w = np.array(model.model.norm.weight, dtype=np.float32)
            post_norm_w.tofile(f"{output_dir}/post_norm_weight.bin")
            print(f"  post_norm_weight.bin: {post_norm_w.shape}")

    # --- LM head ---
    print("\nExporting LM head...")
    lm_head_w = proj_weight(model.head)
    lm_head_w.tofile(f"{output_dir}/lm_head_weight.bin")
    print(f"  lm_head_weight.bin: {lm_head_w.shape}")

    # --- Metadata (previously hard-coded the model name regardless of input) ---
    metadata = {
        "model_name": model_name,
        "d_model": int(model.d_model),
        "n_tokens": int(model.n_tokens),
        "n_layers": len(model.model.layers),
        "model_type": "mamba2",
    }
    with open(f"{output_dir}/metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"\n✓ Metadata saved to {output_dir}/metadata.json")

    print(f"\n✓ Export complete! Weights saved to: {output_dir}")
    return output_dir
|
|
|
|
def main():
    """Parse CLI arguments, load the pretrained model, and export its weights.

    Returns:
        Process exit status: 0 on success, 1 on any failure.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Export Mamba2-130M weights for OpenCL")
    cli.add_argument("output_dir", nargs='?', default="./mamba2_130m_weights",
                     help="Output directory for weights (default: ./mamba2_130m_weights)")
    cli.add_argument("--model", default="cartesia-ai/mamba2-130m-mlx",
                     help="Model name or path (default: cartesia-ai/mamba2-130m-mlx)")
    opts = cli.parse_args()

    print(f"Loading model: {opts.model}")
    print("This may take a moment...")

    try:
        net = cmx.from_pretrained(opts.model)
        # Export wants plain float32 weights, so force the dtype up front.
        net.set_dtype(mx.float32)
        print(f"✓ Model loaded: {type(net).__name__}")
        print(f"  d_model={net.d_model}, vocab={net.n_tokens}, layers={len(net.model.layers)}")
        export_weights(net, opts.output_dir)
    except Exception as err:
        print(f"Error: {err}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
if __name__ == "__main__":
    # Script entry point: propagate main()'s status code to the shell.
    raise SystemExit(main())
|
|
|
|
|
|