File size: 2,603 Bytes
c79aa4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""One-time weight conversion: rename checkpoint keys to match Fast_dVLM model layout.

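Checkpoint naming          →  Model naming
─────────────────────────     ──────────────────────────
model.layers.*             →  model.language_model.layers.*
model.embed_tokens.*       →  model.language_model.embed_tokens.*
model.norm.*               →  model.language_model.norm.*
visual.*                   →  model.visual.*
lm_head.*                  →  lm_head.*  (unchanged)

More generally, every ``model.*`` key that is not already under
``model.language_model.`` or ``model.visual.`` gets the ``language_model.``
prefix. Each shard listed in model.safetensors.index.json is rewritten in
place, and the index's weight_map is updated to the new key names.

Usage (the script converts the checkpoint in the directory that contains it):
    python convert_weights.py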
"""

import json
import os
import re
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import load_file, save_file


def rename_key(key: str) -> str:
    """Map a checkpoint tensor name to the Fast_dVLM model's naming scheme.

    * ``visual.<rest>`` becomes ``model.visual.<rest>``.
    * ``model.<rest>`` becomes ``model.language_model.<rest>``, unless
      ``<rest>`` already starts with ``language_model.`` or ``visual.``.
    * Every other key (e.g. ``lm_head.*``) is returned unchanged.

    Already-converted keys map to themselves, so applying this twice is a no-op.
    """
    model_prefix = "model."

    if key.startswith("visual."):
        return model_prefix + key

    already_nested = key.startswith(("model.language_model.", "model.visual."))
    if already_nested or not key.startswith(model_prefix):
        return key

    # Swap the bare "model." prefix for "model.language_model.".
    return "model.language_model." + key[len(model_prefix):]


def _check_no_collisions(keys) -> None:
    """Raise ValueError if two of *keys* would be renamed to the same name.

    Intended to run against the index before any shard is touched, so a
    collision aborts the conversion without rewriting anything.
    """
    origin = {}
    for old_key in keys:
        new_key = rename_key(old_key)
        if new_key in origin:
            raise ValueError(
                f"Keys {origin[new_key]!r} and {old_key!r} would both be "
                f"renamed to {new_key!r}"
            )
        origin[new_key] = old_key


def _save_file_atomically(tensors, metadata, dest: Path) -> None:
    """Write a safetensors file to *dest* via a sibling temp file + os.replace.

    If writing fails, the existing *dest* is left untouched and the temp file
    is removed. ``os.replace`` swaps the directory entry itself, so when *dest*
    is a symlink, the link is replaced by a regular file and the link's target
    is not modified.
    """
    tmp_path = dest.with_name(dest.name + ".tmp")
    try:
        save_file(tensors, str(tmp_path), metadata=metadata)
        os.replace(tmp_path, dest)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def _convert_shard(model_dir: Path, shard_file: str, new_weight_map: dict) -> None:
    """Rename the tensor keys of one shard and rewrite the shard in place.

    Every resulting key is recorded in *new_weight_map* as ``key -> shard_file``.
    Keeping the shard's tensors local to this function lets them be released
    before the next shard is loaded.

    Raises:
        ValueError: if a renamed key was already produced (in this shard or
            an earlier one).
    """
    max_renames_shown = 5  # only echo the first few renames per shard

    shard_path = model_dir / shard_file

    # Resolve symlinks to read from the real file
    real_path = shard_path.resolve()
    print(f"\nProcessing {shard_file} (reading from {real_path}) ...")

    # load_file() returns only the tensors, not the header's __metadata__
    # (e.g. {"format": "pt"}), so read that separately and write it back.
    with safe_open(str(real_path), framework="pt") as f:
        metadata = f.metadata()
    tensors = load_file(str(real_path))

    renamed = {}
    changes = 0
    for old_key, tensor in tensors.items():
        new_key = rename_key(old_key)
        if new_key in new_weight_map:
            raise ValueError(
                f"Renamed key {new_key!r} (from {old_key!r}) already exists "
                f"in {new_weight_map[new_key]}"
            )
        if new_key != old_key:
            changes += 1
            if changes <= max_renames_shown:
                print(f"  {old_key}  ->  {new_key}")
            elif changes == max_renames_shown + 1:
                print("  ... (more renames omitted)")
        renamed[new_key] = tensor
        new_weight_map[new_key] = shard_file

    _save_file_atomically(renamed, metadata, shard_path)
    print(f"  Renamed {changes} keys, saved {len(renamed)} tensors to {shard_file}")


def main(model_dir=None) -> None:
    """Convert every shard listed in the index, then rewrite the index.

    Args:
        model_dir: Directory holding ``model.safetensors.index.json`` and the
            shard files it lists. Defaults to the directory containing this
            script.

    Raises:
        ValueError: if two checkpoint keys would map to the same new name.
    """
    model_dir = Path(__file__).parent if model_dir is None else Path(model_dir)
    index_path = model_dir / "model.safetensors.index.json"

    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    old_weight_map = index["weight_map"]
    # Validate every rename up front so a collision aborts before any shard
    # has been rewritten.
    _check_no_collisions(old_weight_map)

    shard_files = sorted(set(old_weight_map.values()))
    print(f"Found {len(shard_files)} shard file(s): {shard_files}")

    # Built from the keys actually found in each shard, not copied from the
    # old index.
    new_weight_map = {}
    for shard_file in shard_files:
        _convert_shard(model_dir, shard_file, new_weight_map)

    index["weight_map"] = new_weight_map
    tmp_index = index_path.with_name(index_path.name + ".tmp")
    with open(tmp_index, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
        f.write("\n")
    os.replace(tmp_index, index_path)
    print("\nUpdated model.safetensors.index.json")
    print("Done!")


if __name__ == "__main__":
    main()