| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file, save_file |
|
|
|
|
# Key of the primary token-embedding table in the legacy checkpoint.
TEXT_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.weight"
# Extra embedding rows that the converter concatenates onto the primary table.
TEXT_EXTRA_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
# Legacy -> transformers rename for the connector's modality-projection tensor.
CONNECTOR_IN = "model.connector.modality_projection.proj.weight"
CONNECTOR_OUT = "model.connector.modality_projection.weight"
# Vision-tower keys gain a doubled "vision_model" segment in the transformers layout.
VISION_PREFIX_IN = "model.vision_model."
VISION_PREFIX_OUT = "model.vision_model.vision_model."
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the converter's command-line arguments.

    Positional arguments name the legacy input directory and the directory to
    create for the converted model; ``--config-template`` defaults to a
    ``config.json`` sitting next to this script.
    """
    cli = argparse.ArgumentParser(
        description="Convert a legacy ModernVBERT checkpoint into the official transformers format."
    )
    cli.add_argument("input_dir", type=Path, help="Legacy model directory")
    cli.add_argument("output_dir", type=Path, help="Converted model directory")
    cli.add_argument(
        "--config-template",
        type=Path,
        default=Path(__file__).resolve().with_name("config.json"),
        help="Transformers-format config.json template to copy into the converted output",
    )
    return cli.parse_args()
|
|
|
|
def ensure_new_dir(path: Path) -> None:
    """Create *path* (including parents), refusing to reuse an existing one.

    Raises:
        FileExistsError: if *path* already exists in any form.
    """
    if not path.exists():
        path.mkdir(parents=True)
        return
    raise FileExistsError(f"{path} already exists; refusing to overwrite it")
|
|
|
|
def copy_support_files(src: Path, dst: Path) -> None:
    """Copy tokenizer/auxiliary files from *src* into *dst*.

    Files the converter regenerates itself (weights, config, build report)
    are skipped; directories are copied recursively, files with metadata.
    """
    regenerated = {"model.safetensors", "config.json", "BUILD_INFO.json"}
    for entry in src.iterdir():
        if entry.name in regenerated:
            continue
        copier = shutil.copytree if entry.is_dir() else shutil.copy2
        copier(entry, dst / entry.name)
|
|
|
|
def convert_model_weights(src_dir: Path, dst_dir: Path) -> dict[str, int]:
    """Rewrite the legacy safetensors file into the transformers key layout.

    Three transformations are applied while streaming over the legacy tensors:
      * the additional token-embedding table, when present, is appended (row-wise
        ``torch.cat``) onto the main table and its own key is dropped,
      * the connector projection tensor is renamed to the new key,
      * vision-tower keys gain the doubled ``vision_model`` prefix (keys that
        already carry it are left untouched).

    Returns a statistics dict that the caller folds into BUILD_INFO.json.
    """
    legacy = load_file(str(src_dir / "model.safetensors"))
    converted: dict[str, torch.Tensor] = {}
    stats = {"merged": 0, "connector": 0, "vision": 0}

    for name, tensor in legacy.items():
        if name == TEXT_EXTRA_EMBED_KEY:
            # Consumed by the merge below; never emitted under its own key.
            continue
        if name == TEXT_EMBED_KEY and TEXT_EXTRA_EMBED_KEY in legacy:
            tensor = torch.cat([tensor, legacy[TEXT_EXTRA_EMBED_KEY]], dim=0)
            stats["merged"] += 1
        if name == CONNECTOR_IN:
            name = CONNECTOR_OUT
            stats["connector"] += 1
        elif name.startswith(VISION_PREFIX_IN) and not name.startswith(VISION_PREFIX_OUT):
            name = VISION_PREFIX_OUT + name[len(VISION_PREFIX_IN):]
            stats["vision"] += 1
        converted[name] = tensor

    save_file(converted, str(dst_dir / "model.safetensors"))
    return {
        "source_tensor_count": len(legacy),
        "output_tensor_count": len(converted),
        "merged_token_embedding_tables": stats["merged"],
        "renamed_connector_tensors": stats["connector"],
        "renamed_vision_tensors": stats["vision"],
    }
|
|
|
|
def write_config(template_path: Path, dst_dir: Path) -> dict[str, str]:
    """Emit ``config.json`` into *dst_dir* from a transformers-format template.

    The template is round-tripped through ``json`` so the output is validated
    and consistently indented. Returns provenance info for BUILD_INFO.json.

    Raises:
        FileNotFoundError: if *template_path* does not exist.
    """
    if not template_path.exists():
        raise FileNotFoundError(f"Config template not found: {template_path}")
    parsed = json.loads(template_path.read_text())
    rendered = json.dumps(parsed, indent=2) + "\n"
    (dst_dir / "config.json").write_text(rendered)
    return {"config_template": str(template_path)}
|
|
|
|
def main() -> None:
    """Entry point: convert the checkpoint and record provenance.

    Creates the output directory (failing if it exists), copies auxiliary
    files, converts the weights, writes the config, and finally drops a
    BUILD_INFO.json describing what was done.
    """
    args = parse_args()
    ensure_new_dir(args.output_dir)
    copy_support_files(args.input_dir, args.output_dir)

    weight_info = convert_model_weights(args.input_dir, args.output_dir)
    config_info = write_config(args.config_template, args.output_dir)

    # Provenance record: inputs, tensor statistics, and the key renames applied.
    report = {
        "description": "Legacy ModernVBERT checkpoint converted to the official transformers format.",
        "input_dir": str(args.input_dir),
        "output_dir": str(args.output_dir),
        **weight_info,
        **config_info,
        "key_mapping": {
            TEXT_EXTRA_EMBED_KEY: f"merged into {TEXT_EMBED_KEY}",
            CONNECTOR_IN: CONNECTOR_OUT,
            VISION_PREFIX_IN: VISION_PREFIX_OUT,
        },
    }
    (args.output_dir / "BUILD_INFO.json").write_text(json.dumps(report, indent=2) + "\n")

    print(f"Wrote {args.output_dir}")
    print(f"Converted {weight_info['output_tensor_count']} model tensors")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|