| | import argparse |
| | import os |
| | from pathlib import Path |
| |
|
| | import numpy |
| | from transformers import AutoModel, AutoConfig |
| |
|
| |
|
def replace_key(key: str) -> str:
    """Rename a BERT state-dict key to the layout used by the converted weights.

    Each ``(old, new)`` pair below is applied to the whole key with
    ``str.replace``, strictly in the listed order.

    Args:
        key: A parameter name taken from the source model's state dict.

    Returns:
        The key after all substitutions; keys that match no pattern are
        returned unchanged.
    """
    # Order matters: the specific patterns (e.g. ".attention.output.LayerNorm.")
    # must be rewritten before the general ones (".output.LayerNorm.",
    # ".LayerNorm.") that would otherwise match the same text.
    substitutions = (
        (".layer.", ".layers."),
        (".self.key.", ".key_proj."),
        (".self.query.", ".query_proj."),
        (".self.value.", ".value_proj."),
        (".attention.output.dense.", ".attention.out_proj."),
        (".attention.output.LayerNorm.", ".ln1."),
        (".output.LayerNorm.", ".ln2."),
        (".intermediate.dense.", ".linear1."),
        (".output.dense.", ".linear2."),
        (".LayerNorm.", ".norm."),
        ("pooler.dense.", "pooler."),
    )
    renamed = key
    for old, new in substitutions:
        renamed = renamed.replace(old, new)
    return renamed
| |
|
| |
|
def convert(bert_model: str, mlx_model: str) -> None:
    """Load a pretrained model and save its config and weights for MLX.

    The model configuration is written as ``config.json`` next to the weights
    file, and every state-dict tensor is stored in a NumPy ``.npz`` archive
    under the key produced by ``replace_key``.

    Args:
        bert_model: Name or path passed to ``AutoModel.from_pretrained`` and
            ``AutoConfig.from_pretrained``.
        mlx_model: Output path for the ``.npz`` weights archive. Its parent
            directory is created when missing.
    """
    model = AutoModel.from_pretrained(bert_model)
    config = AutoConfig.from_pretrained(bert_model)

    # An empty dirname means "current directory"; nothing to create then.
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    output_dir = os.path.dirname(mlx_model)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())

    print(f"Saved model config to {config_path}")

    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)

    # numpy.savez appends ".npz" when the given path lacks it, so report the
    # path that was actually written rather than the raw argument.
    weights_path = mlx_model if mlx_model.endswith(".npz") else f"{mlx_model}.npz"
    print(f"Saved model weights to {weights_path}")
    print(f"Model vocab size: {config.vocab_size}, hidden size: {config.hidden_size}")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: read the source model name and the output
    # path, then run the conversion.
    cli = argparse.ArgumentParser(description="Convert BERT weights to MLX.")

    # Source model identifier.
    cli.add_argument(
        "--bert-model",
        default="bert-base-uncased",
        type=str,
        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
    )

    # Destination of the converted weights archive.
    cli.add_argument(
        "--mlx-model",
        default="weights/bert-base-uncased.npz",
        type=str,
        help="The output path for the MLX BERT weights.",
    )

    options = cli.parse_args()
    convert(options.bert_model, options.mlx_model)
| |
|