File size: 2,603 Bytes
c79aa4d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | """One-time weight conversion: rename checkpoint keys to match Fast_dVLM model layout.
Checkpoint naming         →  Model naming
─────────────────────────    ──────────────────────────
model.layers.*            →  model.language_model.layers.*
model.embed_tokens.*      →  model.language_model.embed_tokens.*
model.norm.*              →  model.language_model.norm.*
visual.*                  →  model.visual.*
lm_head.*                 →  lm_head.* (unchanged)
Usage:
python convert_weights.py
"""
import json
import os
import re
from pathlib import Path
from safetensors.torch import load_file, save_file
def rename_key(key: str) -> str:
    """Translate one checkpoint tensor name into the model's naming scheme.

    * ``visual.*`` gains a ``model.`` prefix.
    * ``model.*`` is moved under ``model.language_model.``, unless it already
      starts with ``model.language_model.`` or ``model.visual.``.
    * Any other name (for example ``lm_head.*``) is returned unchanged.

    Args:
        key: Tensor name as stored in the checkpoint.

    Returns:
        The renamed key, or ``key`` itself when no rule applies.
    """
    if key.startswith("visual."):
        return "model." + key

    # The negative lookahead skips names that are already in the target
    # layout, so feeding an already-converted key back in is a no-op.
    needs_nesting = re.match(r"^model\.(?!language_model\.|visual\.)", key) is not None
    if not needs_nesting:
        return key
    return re.sub(r"^model\.", "model.language_model.", key)
def main():
    """Rename the tensor keys of every checkpoint shard and refresh the index.

    Works on the directory that contains this script: reads
    ``model.safetensors.index.json``, loads every shard listed in its
    ``weight_map``, renames each tensor key with ``rename_key``, writes the
    shard back under the same file name and finally stores the rebuilt
    ``weight_map`` in the index file.

    Each shard is written to a temporary sibling file and then moved over the
    original name, so a failed write never leaves a truncated shard behind and
    the file the tensors were loaded from is never overwritten in place. If
    the shard path is a symlink, the link itself is replaced by a regular file
    and the link target is left untouched.

    Raises:
        ValueError: If two tensors would end up with the same name after
            renaming (one would silently overwrite the other).
    """
    # Local import: only needed here, to carry the shard header metadata over.
    from safetensors import safe_open

    model_dir = Path(__file__).parent
    index_path = model_dir / "model.safetensors.index.json"
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    old_weight_map = index["weight_map"]

    # Fail before touching any file if renaming would merge two tensors.
    rename_sources = {}
    for old_key in old_weight_map:
        new_key = rename_key(old_key)
        if new_key in rename_sources:
            raise ValueError(
                f"Key collision: {rename_sources[new_key]!r} and {old_key!r} "
                f"both map to {new_key!r}"
            )
        rename_sources[new_key] = old_key

    shard_files = sorted(set(old_weight_map.values()))
    print(f"Found {len(shard_files)} shard file(s): {shard_files}")

    new_weight_map = {}
    for shard_file in shard_files:
        shard_path = model_dir / shard_file
        # Resolve symlinks to read from the real file
        real_path = shard_path.resolve()
        print(f"\nProcessing {shard_file} (reading from {real_path}) ...")

        # load_file() returns tensors only; read the header metadata
        # separately so it is not lost when the shard is written back.
        with safe_open(str(real_path), framework="pt") as shard:
            metadata = shard.metadata()
        tensors = load_file(str(real_path))

        renamed = {}
        changes = 0
        for old_key, tensor in tensors.items():
            new_key = rename_key(old_key)
            # Safety net for tensors present in a shard but not in the index.
            if new_key in new_weight_map:
                raise ValueError(
                    f"Key collision in {shard_file}: {old_key!r} maps to "
                    f"{new_key!r}, which already exists in {new_weight_map[new_key]}"
                )
            if old_key != new_key:
                changes += 1
                # Show only the first few renames to keep the log readable.
                if changes <= 5:
                    print(f" {old_key} → {new_key}")
                elif changes == 6:
                    print(" ... (more renames omitted)")
            renamed[new_key] = tensor
            new_weight_map[new_key] = shard_file

        # Write next to the destination, then swap it in. os.replace() swaps
        # the directory entry, so a symlink at shard_path is replaced by the
        # new regular file rather than written through.
        tmp_path = shard_path.with_name(shard_path.name + ".tmp")
        try:
            save_file(renamed, str(tmp_path), metadata=metadata)
            os.replace(tmp_path, shard_path)
        finally:
            # Only still present if saving or replacing failed.
            if tmp_path.exists():
                tmp_path.unlink()
        print(f" Renamed {changes} keys, saved {len(renamed)} tensors to {shard_file}")

    index["weight_map"] = new_weight_map
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    print("\nUpdated model.safetensors.index.json")
    print("Done!")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|