#!/usr/bin/env python3
"""Rename mis-located visual encoder keys in Qwen3.5 SFT checkpoints.
Symptom: After full-parameter SFT with `freeze_vision_tower: true`, the
saved checkpoint nests visual weights under `model.language_model.visual.*`
instead of the expected top-level `visual.*` (which vLLM looks for).
This script renames every key matching `model.language_model.visual.*`
to `visual.*` and writes a new model.safetensors in place.
Usage:
python3 patch_qwen35_visual_keys.py --sft /path/to/sft/output_dir
"""
import argparse
import json
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
def main():
    """Rename mis-located visual encoder keys in an SFT checkpoint, in place.

    Loads every tensor from ``model.safetensors`` (or, for a sharded
    checkpoint, from each shard listed in ``model.safetensors.index.json``),
    renames keys under ``model.language_model.visual.`` to ``visual.``, and
    writes one consolidated ``model.safetensors``. No-op (no write) when no
    key matches the prefix.

    Raises:
        FileNotFoundError: if the directory contains neither a
            ``model.safetensors`` file nor an index file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--sft", required=True, help="path to sft model dir (contains model.safetensors)")
    args = ap.parse_args()

    sft_dir = Path(args.sft)
    sft_st = sft_dir / "model.safetensors"
    sft_index = sft_dir / "model.safetensors.index.json"

    print(f"[1/3] loading SFT weights from {sft_dir}...")
    sft_tensors = {}
    if sft_st.exists():
        # Single-file checkpoint.
        with safe_open(sft_st, framework="pt") as f:
            for k in f.keys():
                sft_tensors[k] = f.get_tensor(k)
    elif sft_index.exists():
        # Sharded checkpoint: visit each unique shard referenced by the index.
        idx = json.loads(sft_index.read_text())
        for shard in set(idx["weight_map"].values()):
            with safe_open(sft_dir / shard, framework="pt") as f:
                for k in f.keys():
                    sft_tensors[k] = f.get_tensor(k)
    else:
        raise FileNotFoundError(f"no model.safetensors or index in {sft_dir}")
    print(f" loaded {len(sft_tensors)} tensors")

    print("[2/3] renaming model.language_model.visual.* -> visual.* ...")
    PREFIX = "model.language_model.visual."
    new_tensors = {}
    renamed = 0
    for k, v in sft_tensors.items():
        if k.startswith(PREFIX):
            new_key = "visual." + k[len(PREFIX):]
            new_tensors[new_key] = v
            renamed += 1
        else:
            new_tensors[k] = v
    print(f" renamed {renamed} keys")

    if renamed == 0:
        # Keys already have the expected layout; do not rewrite the checkpoint.
        print("[OK] nothing to rename; skipping write")
        return

    out_path = sft_dir / "model.safetensors"
    print(f"[3/3] writing patched checkpoint -> {out_path}")
    # Write the consolidated file BEFORE removing the index: if save_file
    # fails partway (disk full, interrupt), deleting the index first would
    # leave the directory with neither a loadable sharded checkpoint nor a
    # patched single-file one.
    save_file(new_tensors, str(out_path), metadata={"format": "pt"})
    if sft_index.exists():
        # The new single-file checkpoint supersedes the sharded layout.
        # NOTE(review): the shard files themselves are deliberately left on
        # disk; loaders ignore them once the index is gone.
        print(f" removing stale index file {sft_index}")
        sft_index.unlink()
    sz_gb = out_path.stat().st_size / 1e9
    print(f"[OK] wrote {len(new_tensors)} tensors, size {sz_gb:.2f} GB")
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|