Visual Document Retrieval
ColPali
Safetensors
English
modernvbert
vidore-experimental
vidore
modernvbert / convert_model_weights.py
paultltc's picture
Add model migration script and README guide
29f3700 verified
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
# Tensor-name constants describing the legacy -> transformers key mapping.
# Main token-embedding table in the legacy checkpoint.
TEXT_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.weight"
# Extra embedding rows that get concatenated onto TEXT_EMBED_KEY (see
# convert_model_weights); this tensor is dropped from the output.
TEXT_EXTRA_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
# The connector projection loses its ".proj" segment in the new layout.
CONNECTOR_IN = "model.connector.modality_projection.proj.weight"
CONNECTOR_OUT = "model.connector.modality_projection.weight"
# Vision tensors gain a doubled "vision_model." prefix in the new layout.
VISION_PREFIX_IN = "model.vision_model."
VISION_PREFIX_OUT = "model.vision_model.vision_model."
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the conversion script.

    Positional arguments are the legacy input directory and the (new)
    output directory; ``--config-template`` defaults to a config.json
    sitting next to this script.
    """
    cli = argparse.ArgumentParser(
        description="Convert a legacy ModernVBERT checkpoint into the official transformers format."
    )
    cli.add_argument("input_dir", type=Path, help="Legacy model directory")
    cli.add_argument("output_dir", type=Path, help="Converted model directory")
    default_template = Path(__file__).resolve().with_name("config.json")
    cli.add_argument(
        "--config-template",
        type=Path,
        default=default_template,
        help="Transformers-format config.json template to copy into the converted output",
    )
    return cli.parse_args()
def ensure_new_dir(path: Path) -> None:
    """Create *path* (and parents), refusing to touch an existing directory.

    Raises:
        FileExistsError: if *path* already exists in any form, so a prior
            conversion output is never silently overwritten.
    """
    if path.exists():
        raise FileExistsError(f"{path} already exists; refusing to overwrite it")
    path.mkdir(parents=True)
def copy_support_files(src: Path, dst: Path) -> None:
    """Copy tokenizer/auxiliary files from *src* into *dst*.

    Skips the three files the conversion regenerates itself
    (model.safetensors, config.json, BUILD_INFO.json); everything else is
    copied recursively, preserving metadata.
    """
    skip = {"model.safetensors", "config.json", "BUILD_INFO.json"}
    for entry in src.iterdir():
        if entry.name in skip:
            continue
        destination = dst / entry.name
        # copytree for directories, copy2 to keep file metadata otherwise.
        copier = shutil.copytree if entry.is_dir() else shutil.copy2
        copier(entry, destination)
def convert_model_weights(src_dir: Path, dst_dir: Path) -> dict[str, int]:
    """Rewrite the legacy safetensors checkpoint into the transformers layout.

    Three transformations are applied while streaming over the tensors:
    the additional-embedding table is concatenated onto the main token
    embedding table (and dropped as a standalone tensor), the connector
    projection key loses its ``.proj`` segment, and vision tensors gain the
    doubled ``vision_model.`` prefix.

    Returns a summary dict with source/output tensor counts and a tally of
    each rename rule applied.
    """
    legacy = load_file(str(src_dir / "model.safetensors"))
    converted: dict[str, torch.Tensor] = {}
    merged = connector_renames = vision_renames = 0

    for name, tensor in legacy.items():
        # The extra rows are folded into the main table below, so the
        # standalone tensor must not appear in the output at all.
        if name == TEXT_EXTRA_EMBED_KEY:
            continue
        if name == TEXT_EMBED_KEY and TEXT_EXTRA_EMBED_KEY in legacy:
            tensor = torch.cat((tensor, legacy[TEXT_EXTRA_EMBED_KEY]), dim=0)
            merged += 1
        if name == CONNECTOR_IN:
            name = CONNECTOR_OUT
            connector_renames += 1
        elif name.startswith(VISION_PREFIX_IN) and not name.startswith(VISION_PREFIX_OUT):
            name = VISION_PREFIX_OUT + name[len(VISION_PREFIX_IN):]
            vision_renames += 1
        converted[name] = tensor

    save_file(converted, str(dst_dir / "model.safetensors"))
    return {
        "source_tensor_count": len(legacy),
        "output_tensor_count": len(converted),
        "merged_token_embedding_tables": merged,
        "renamed_connector_tensors": connector_renames,
        "renamed_vision_tensors": vision_renames,
    }
def write_config(template_path: Path, dst_dir: Path) -> dict[str, str]:
    """Validate the config template and write a normalized copy into *dst_dir*.

    The template is round-tripped through ``json`` so an invalid template
    fails loudly here rather than at model-load time, and the output is
    re-serialized with a uniform 2-space indent and trailing newline.

    Fix: read and write explicitly as UTF-8. JSON files are UTF-8 by spec,
    but ``Path.read_text``/``write_text`` default to the locale encoding,
    which corrupts non-ASCII template content on non-UTF-8 platforms.

    Args:
        template_path: Transformers-format config.json to copy.
        dst_dir: Existing output directory receiving ``config.json``.

    Returns:
        A one-entry dict recording which template was used (merged into
        BUILD_INFO.json by the caller).

    Raises:
        FileNotFoundError: if the template does not exist.
        json.JSONDecodeError: if the template is not valid JSON.
    """
    if not template_path.exists():
        raise FileNotFoundError(f"Config template not found: {template_path}")
    config = json.loads(template_path.read_text(encoding="utf-8"))
    (dst_dir / "config.json").write_text(
        json.dumps(config, indent=2) + "\n", encoding="utf-8"
    )
    return {"config_template": str(template_path)}
def main() -> None:
    """CLI entry point: run the full conversion and emit BUILD_INFO.json.

    Steps: parse arguments, create a fresh output directory (refusing to
    overwrite), copy support files, convert the weights, copy the config
    template, then record a provenance report alongside the output.
    """
    args = parse_args()
    ensure_new_dir(args.output_dir)
    copy_support_files(args.input_dir, args.output_dir)
    weight_summary = convert_model_weights(args.input_dir, args.output_dir)
    config_summary = write_config(args.config_template, args.output_dir)

    # Assemble the provenance report; insertion order matters because it is
    # serialized as-is into BUILD_INFO.json.
    report: dict[str, object] = {
        "description": "Legacy ModernVBERT checkpoint converted to the official transformers format.",
        "input_dir": str(args.input_dir),
        "output_dir": str(args.output_dir),
    }
    report.update(weight_summary)
    report.update(config_summary)
    report["key_mapping"] = {
        TEXT_EXTRA_EMBED_KEY: f"merged into {TEXT_EMBED_KEY}",
        CONNECTOR_IN: CONNECTOR_OUT,
        VISION_PREFIX_IN: VISION_PREFIX_OUT,
    }
    (args.output_dir / "BUILD_INFO.json").write_text(json.dumps(report, indent=2) + "\n")

    print(f"Wrote {args.output_dir}")
    print(f"Converted {report['output_tensor_count']} model tensors")


if __name__ == "__main__":
    main()