"""One-time weight conversion: rename checkpoint keys to match the Fast_dVLM model layout.

Checkpoint naming              Model naming
─────────────────────────      ───────────────────────────────────
model.layers.*             →   model.language_model.layers.*
model.embed_tokens.*       →   model.language_model.embed_tokens.*
model.norm.*               →   model.language_model.norm.*
visual.*                   →   model.visual.*
lm_head.*                  →   lm_head.*  (unchanged)

More generally, every ``model.*`` key that is not already under
``model.language_model.`` or ``model.visual.`` is moved under
``model.language_model.``.

The shards and ``model.safetensors.index.json`` are rewritten in place, in the
directory that contains this script.

Usage:
    python convert_weights.py
"""
|
|
| import json |
| import os |
| import re |
| from pathlib import Path |
| from safetensors.torch import load_file, save_file |
|
|
|
|
def rename_key(key: str) -> str:
    """Map one checkpoint tensor name onto the Fast_dVLM parameter layout.

    * ``visual.<rest>`` becomes ``model.visual.<rest>``.
    * ``model.<rest>`` becomes ``model.language_model.<rest>``, unless the name
      is already under ``model.language_model.`` or ``model.visual.``.
    * Any other name (e.g. ``lm_head.*``) is returned unchanged.

    The mapping is idempotent: passing an already-converted name returns it as-is.
    """
    if key.startswith("visual."):
        return f"model.{key}"

    prefix = "model."
    # Names already in the target layout must not be nested a second time.
    already_nested = key.startswith(("model.language_model.", "model.visual."))
    if key.startswith(prefix) and not already_nested:
        return "model.language_model." + key[len(prefix):]

    return key
|
|
|
|
def _replace_via_temp(target: Path, write) -> None:
    """Produce *target* by writing a sibling temp file, then renaming it into place.

    Args:
        target: Final path of the file to (re)create.
        write: Callable that receives the temp path and writes the complete
            contents there.

    ``os.replace`` swaps the directory entry. A crash mid-write therefore never
    leaves a truncated *target*. If *target* is a symlink, the link itself is
    replaced and the file it pointed to is left untouched.
    """
    # Same directory as the target, so os.replace stays on one filesystem.
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        write(tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Remove the partial temp file, then let the original error propagate.
        tmp_path.unlink(missing_ok=True)
        raise


def main():
    """Rename the keys in every shard listed in the index, then rewrite the index.

    Operates on ``model.safetensors.index.json`` and the shard files it
    references. All paths are resolved relative to the directory containing this
    script. Shards are converted one at a time, and the index is written last.
    Because ``rename_key`` is idempotent, re-running after a partial failure is
    safe.

    Raises:
        ValueError: if renaming would give two tensors the same name (the later
            one would otherwise silently overwrite the earlier one).
    """
    from safetensors import safe_open  # used only to read shard header metadata

    model_dir = Path(__file__).parent
    index_path = model_dir / "model.safetensors.index.json"

    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    old_weight_map = index["weight_map"]

    # weight_map lists a shard once per tensor it holds; visit each file once.
    shard_files = sorted(set(old_weight_map.values()))
    print(f"Found {len(shard_files)} shard file(s): {shard_files}")

    new_weight_map = {}
    for shard_file in shard_files:
        shard_path = model_dir / shard_file

        # The shard may be a symlink; read from the file it points to.
        real_path = shard_path.resolve()
        print(f"\nProcessing {shard_file} (reading from {real_path}) ...")
        with safe_open(str(real_path), framework="pt") as reader:
            # save_file writes no metadata unless given some, so carry the
            # original header over (some loaders expect {"format": "pt"}).
            metadata = reader.metadata() or {"format": "pt"}
        tensors = load_file(str(real_path))

        renamed = {}
        changes = 0
        for old_key, tensor in tensors.items():
            new_key = rename_key(old_key)
            # new_weight_map holds every name emitted so far, in any shard.
            if new_key in new_weight_map:
                raise ValueError(
                    f"Renaming {old_key!r} in {shard_file} to {new_key!r} collides "
                    f"with a tensor already mapped to that name in "
                    f"{new_weight_map[new_key]}"
                )
            if old_key != new_key:
                changes += 1
                # Show a few sample renames per shard, not all of them.
                if changes <= 5:
                    print(f" {old_key} -> {new_key}")
                elif changes == 6:
                    print(" ... (more renames omitted)")
            renamed[new_key] = tensor
            new_weight_map[new_key] = shard_file

        _replace_via_temp(
            shard_path,
            lambda tmp: save_file(renamed, str(tmp), metadata=metadata),
        )
        print(f" Renamed {changes} keys, saved {len(renamed)} tensors to {shard_file}")

        # Drop this shard's tensors before loading the next shard, so only one
        # shard is held in memory at a time.
        del tensors, renamed

    index["weight_map"] = new_weight_map

    def _write_index(tmp):
        with open(tmp, "w", encoding="utf-8") as out:
            json.dump(index, out, indent=2)

    _replace_via_temp(index_path, _write_index)
    print("\nUpdated model.safetensors.index.json")
    print("Done!")
|
|
|
|
# Run the conversion only when executed as a script, never on import.
if __name__ == "__main__":
    main()
|
|