File size: 2,579 Bytes
114c561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
"""Rename mis-located visual encoder keys in Qwen3.5 SFT checkpoints.

Symptom: After full-parameter SFT with `freeze_vision_tower: true`, the
saved checkpoint nests visual weights under `model.language_model.visual.*`
instead of the expected top-level `visual.*` (which vLLM looks for).

This script renames every key matching `model.language_model.visual.*`
to `visual.*` and writes a new model.safetensors in place.

Usage:
  python3 patch_qwen35_visual_keys.py --sft /path/to/sft/output_dir
"""
import argparse
import json
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file


# Key prefix produced by the faulty SFT save, and the prefix it should have.
_OLD_PREFIX = "model.language_model.visual."
_NEW_PREFIX = "visual."


def _load_tensors(sft_dir, sft_st, sft_index):
    """Load every tensor of the checkpoint in ``sft_dir``.

    Prefers the single-file ``sft_st``; falls back to the shards listed in the
    ``weight_map`` of ``sft_index``.

    Returns:
        ``(tensors, shard_paths)`` where ``tensors`` maps key -> tensor and
        ``shard_paths`` lists the shard files that were read (empty when the
        single-file checkpoint was used).

    Raises:
        FileNotFoundError: neither ``sft_st`` nor ``sft_index`` exists.
    """
    tensors = {}
    shard_paths = []
    if sft_st.exists():
        with safe_open(str(sft_st), framework="pt") as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
    elif sft_index.exists():
        idx = json.loads(sft_index.read_text())
        # sorted() makes the shard load order (and thus log output) deterministic.
        for shard in sorted(set(idx["weight_map"].values())):
            shard_path = sft_dir / shard
            with safe_open(str(shard_path), framework="pt") as f:
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
            shard_paths.append(shard_path)
    else:
        raise FileNotFoundError(f"no model.safetensors or index in {sft_dir}")
    return tensors, shard_paths


def _rename_visual_keys(tensors):
    """Return ``(renamed_tensors, count)`` with ``_OLD_PREFIX`` keys moved under ``_NEW_PREFIX``.

    Keys without the prefix are copied unchanged. The input dict is not modified.

    Raises:
        ValueError: a renamed key already exists in ``tensors``; keeping either
            tensor silently would lose data, so the conflict is surfaced instead.
    """
    renamed_tensors = {}
    count = 0
    for k, v in tensors.items():
        if k.startswith(_OLD_PREFIX):
            new_key = _NEW_PREFIX + k[len(_OLD_PREFIX):]
            if new_key in tensors:
                raise ValueError(
                    f"cannot rename {k!r}: target key {new_key!r} already exists"
                )
            count += 1
        else:
            new_key = k
        renamed_tensors[new_key] = v
    return renamed_tensors, count


def _atomic_save(tensors, out_path):
    """Write ``tensors`` to a sibling temp file, then atomically move it onto ``out_path``.

    Writing beside the target means an interrupted save never leaves a
    truncated ``out_path``. It also avoids rewriting the source file in place
    while the loaded tensors may still be backed by it. The temp name ends in
    ``.tmp``, so a leftover file does not match ``*.safetensors``.
    """
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        save_file(tensors, str(tmp_path), metadata={"format": "pt"})
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful replace; removes a partial temp file otherwise.
        if tmp_path.exists():
            tmp_path.unlink()


def main():
    """CLI entry point: move ``model.language_model.visual.*`` keys to ``visual.*``.

    Reads ``<sft>/model.safetensors`` if present, otherwise every shard listed
    in ``<sft>/model.safetensors.index.json``. If no key needs renaming,
    nothing is written.

    When keys are renamed, the result is saved as a single
    ``model.safetensors``. Only after that write succeeds are the index file
    and the shard files that were consolidated into it removed.

    Raises:
        FileNotFoundError: neither ``model.safetensors`` nor the index exists.
        ValueError: a renamed key would collide with an existing key.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--sft", required=True, help="path to sft model dir (contains model.safetensors)")
    args = ap.parse_args()

    sft_dir = Path(args.sft)
    sft_st = sft_dir / "model.safetensors"
    sft_index = sft_dir / "model.safetensors.index.json"

    print(f"[1/3] loading SFT weights from {sft_dir}...")
    sft_tensors, shard_paths = _load_tensors(sft_dir, sft_st, sft_index)
    print(f"  loaded {len(sft_tensors)} tensors")

    print("[2/3] renaming model.language_model.visual.* -> visual.* ...")
    new_tensors, renamed = _rename_visual_keys(sft_tensors)
    print(f"  renamed {renamed} keys")
    if renamed == 0:
        print("[OK] nothing to rename; skipping write")
        return

    out_path = sft_st
    print(f"[3/3] writing patched checkpoint -> {out_path}")
    _atomic_save(new_tensors, out_path)

    # Clean up only after the new file is safely in place. A failed write must
    # not leave the directory without both the index and a complete checkpoint.
    if sft_index.exists():
        print(f"  removing stale index file {sft_index}")
        sft_index.unlink()
    # The old shards still carry the un-renamed keys. Once the index is gone, a
    # loader that reads every *.safetensors file in the directory would pick
    # them up next to the consolidated file.
    for shard_path in shard_paths:
        if shard_path != out_path and shard_path.exists():
            print(f"  removing consolidated shard {shard_path}")
            shard_path.unlink()

    sz_gb = out_path.stat().st_size / 1e9
    print(f"[OK] wrote {len(new_tensors)} tensors, size {sz_gb:.2f} GB")


if __name__ == "__main__":
    main()