# mamba2-130m-opencl / export_mamba2_130m_weights.py
# Uploaded by a8nova (commit 64bad2b, verified)
#!/usr/bin/env python3
"""
Export Mamba2-130M Model Weights for OpenCL
This script loads the pretrained mamba2-130m model and exports all weights
to binary files that OpenCL can load.
Usage:
    python export_mamba2_130m_weights.py [output_dir] [--model MODEL]
    output_dir: Directory to save exported weights (default: ./mamba2_130m_weights)
    --model MODEL: HuggingFace model name or path (default: cartesia-ai/mamba2-130m-mlx)
"""
import sys
import os
import json
import numpy as np
import mlx.core as mx
from pathlib import Path
# Add path to cartesia_mlx
sys.path.insert(0, '../cartesia-mlx')
import cartesia_mlx as cmx
def _save(arr, path, label):
    """Write *arr* to *path* as raw bytes (float32 layout) and log its shape."""
    arr.tofile(path)
    print(f" {label}: {arr.shape}")


def _get_weight(layer, attr_path):
    """Navigate dotted *attr_path* on *layer* and return the weight as float32.

    Handles MLX QuantizedLinear modules (detected via their `scales`/`biases`
    attributes) by reconstructing the full-precision matrix with
    mx.dequantize before converting to a NumPy array.
    """
    obj = layer
    for attr in attr_path.split('.'):
        if attr == 'weight' and hasattr(obj, 'weight'):
            if hasattr(obj, 'scales') and hasattr(obj, 'biases'):
                # QuantizedLinear: use MLX's built-in dequantize.
                dequantized = mx.dequantize(obj.weight, obj.scales, obj.biases,
                                            obj.group_size, obj.bits)
                mx.eval(dequantized)
                return np.array(dequantized, dtype=np.float32)
            return np.array(obj.weight, dtype=np.float32)
        obj = getattr(obj, attr)
    return np.array(obj, dtype=np.float32)


def _proj_weight(module):
    """Weight of a projection module, unwrapping an optional `.linear` wrapper.

    Non-quantized models wrap the Linear in a `.linear` attribute; quantized
    ones expose the weight directly.
    """
    if hasattr(module, 'linear'):
        return _get_weight(module.linear, 'weight')
    return _get_weight(module, 'weight')


def _export_ssd_layer(layer_obj, layer_dir):
    """Export the parameters of a Mamba2 SSD (state-space) layer."""
    _save(_proj_weight(layer_obj.in_proj), f"{layer_dir}/in_proj_weight.bin", "in_proj")
    # conv_weight/conv_bias are direct attributes, not wrapped in a conv module.
    if hasattr(layer_obj, 'conv_weight'):
        _save(np.array(layer_obj.conv_weight, dtype=np.float32),
              f"{layer_dir}/conv_weight.bin", "conv_weight")
    if hasattr(layer_obj, 'conv_bias'):
        _save(np.array(layer_obj.conv_bias, dtype=np.float32),
              f"{layer_dir}/conv_bias.bin", "conv_bias")
    # SSM parameters.
    if hasattr(layer_obj, 'A'):
        _save(np.array(layer_obj.A, dtype=np.float32),
              f"{layer_dir}/A_log.bin", "A_log")
    if hasattr(layer_obj, 'dt_bias'):
        _save(np.array(layer_obj.dt_bias, dtype=np.float32),
              f"{layer_dir}/dt_bias.bin", "dt_bias")
    if hasattr(layer_obj, 'D'):
        _save(np.array(layer_obj.D, dtype=np.float32),
              f"{layer_dir}/D.bin", "D")
    _save(_proj_weight(layer_obj.out_proj), f"{layer_dir}/out_proj_weight.bin", "out_proj")
    # SSD's internal rms_norm weights (different from ResidualBlock's norm!).
    if hasattr(layer_obj, 'rms_norm'):
        _save(np.array(layer_obj.rms_norm.weight, dtype=np.float32),
              f"{layer_dir}/rms_norm_weight.bin", "rms_norm (internal)")


def _export_swiglu_layer(layer_obj, layer_dir):
    """Export the three projection matrices of a SwiGLU MLP layer."""
    _save(_proj_weight(layer_obj.gate_proj), f"{layer_dir}/gate_weight.bin", "gate_proj")
    _save(_proj_weight(layer_obj.in_proj), f"{layer_dir}/in_proj_weight.bin", "in_proj")
    _save(_proj_weight(layer_obj.out_proj), f"{layer_dir}/out_proj_weight.bin", "out_proj")


def _export_attention_layer(layer_obj, layer_dir):
    """Export the fused QKV projection and output projection of an attention layer."""
    _save(_proj_weight(layer_obj.qkv_proj), f"{layer_dir}/qkv_weight.bin", "qkv_proj")
    _save(_proj_weight(layer_obj.out_proj), f"{layer_dir}/out_weight.bin", "out_proj")


def export_weights(model, output_dir):
    """Export all model weights to binary files OpenCL can load.

    Args:
        model: A loaded cartesia_mlx Mamba2 model (exposes `embedding`,
            `model.layers`, `head`, `d_model`, `n_tokens`).
        output_dir: Directory to write the .bin weight files and metadata.json.

    Returns:
        The output directory path.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"Exporting Mamba2-130M model weights to: {output_dir}")
    print(f"Model config: d_model={model.d_model}, vocab={model.n_tokens}")
    print()
    # Token embedding table.
    print("Exporting embedding...")
    _save(np.array(model.embedding.encoder.weight, dtype=np.float32),
          f"{output_dir}/embedding_weight.bin", "embedding_weight.bin")
    # Sequence-model layers: each element is a ResidualBlock wrapping the
    # actual layer.
    print(f"\nExporting {len(model.model.layers)} layers...")
    for i, layer in enumerate(model.model.layers):
        layer_obj = layer.layer
        layer_type = type(layer_obj).__name__
        # SSD layers are identified by SSD-specific attributes (conv_weight, A)
        # to distinguish them from SwiGLU; their directories get an _ssd suffix.
        is_ssd = hasattr(layer_obj, 'conv_weight') or hasattr(layer_obj, 'A')
        suffix = "_ssd" if is_ssd else ""
        layer_dir = f"{output_dir}/layer_{i:03d}{suffix}"
        os.makedirs(layer_dir, exist_ok=True)
        print(f" Layer {i}: {layer_type} -> {os.path.basename(layer_dir)}")
        try:
            if is_ssd:
                _export_ssd_layer(layer_obj, layer_dir)
            elif (hasattr(layer_obj, 'in_proj') and hasattr(layer_obj, 'gate_proj')
                  and hasattr(layer_obj, 'out_proj')):
                _export_swiglu_layer(layer_obj, layer_dir)
            elif hasattr(layer_obj, 'qkv_proj') and hasattr(layer_obj, 'out_proj'):
                _export_attention_layer(layer_obj, layer_dir)
            # ResidualBlock's own norm weight, if present.
            if hasattr(layer, 'norm') and layer.norm is not None:
                _save(np.array(layer.norm.weight, dtype=np.float32),
                      f"{layer_dir}/norm_weight.bin", "norm")
        except Exception as e:
            # Best-effort export: report the failure and continue with the
            # remaining layers.
            print(f" ERROR exporting layer {i}: {e}")
            import traceback
            traceback.print_exc()
            continue
    # Final post-norm of the sequence model, if configured.
    if hasattr(model.model, 'post_norm') and model.model.post_norm:
        if hasattr(model.model, 'norm') and model.model.norm is not None:
            print("\nExporting post-norm weights...")
            _save(np.array(model.model.norm.weight, dtype=np.float32),
                  f"{output_dir}/post_norm_weight.bin", "post_norm_weight.bin")
    # LM head projection.
    print("\nExporting LM head...")
    _save(_proj_weight(model.head), f"{output_dir}/lm_head_weight.bin", "lm_head_weight.bin")
    # Metadata the OpenCL loader needs to interpret the binaries.
    metadata = {
        "model_name": "mamba2-130m-mlx",
        "d_model": int(model.d_model),
        "n_tokens": int(model.n_tokens),
        "n_layers": len(model.model.layers),
        "model_type": "mamba2"
    }
    with open(f"{output_dir}/metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"\n✓ Metadata saved to {output_dir}/metadata.json")
    print(f"\n✓ Export complete! Weights saved to: {output_dir}")
    return output_dir
def main():
    """CLI entry point: parse arguments, load the model, export its weights.

    Returns:
        0 on success, 1 if loading or exporting fails.
    """
    import argparse
    cli = argparse.ArgumentParser(description="Export Mamba2-130M weights for OpenCL")
    cli.add_argument("output_dir", nargs='?', default="./mamba2_130m_weights",
                     help="Output directory for weights (default: ./mamba2_130m_weights)")
    cli.add_argument("--model", default="cartesia-ai/mamba2-130m-mlx",
                     help="Model name or path (default: cartesia-ai/mamba2-130m-mlx)")
    opts = cli.parse_args()
    print(f"Loading model: {opts.model}")
    print("This may take a moment...")
    try:
        # Load pretrained weights and force full precision before export.
        loaded = cmx.from_pretrained(opts.model)
        loaded.set_dtype(mx.float32)
        print(f"✓ Model loaded: {type(loaded).__name__}")
        print(f" d_model={loaded.d_model}, vocab={loaded.n_tokens}, layers={len(loaded.model.layers)}")
        export_weights(loaded, opts.output_dir)
    except Exception as exc:
        print(f"Error: {exc}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Run as a script; propagate main()'s status code (0/1) to the shell.
    sys.exit(main())