CosFly-Track / scripts /patch_qwen35_visual_keys.py
Ys404's picture
Add scripts and checkpoints (CosFly-Track release)
114c561 verified
raw
history blame
2.58 kB
#!/usr/bin/env python3
"""Rename mis-located visual encoder keys in Qwen3.5 SFT checkpoints.
Symptom: After full-parameter SFT with `freeze_vision_tower: true`, the
saved checkpoint nests visual weights under `model.language_model.visual.*`
instead of the expected top-level `visual.*` (which vLLM looks for).
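This script renames every key matching `model.language_model.visual.*`
to `visual.*` and writes the result to model.safetensors in place. A sharded
checkpoint (model.safetensors.index.json + shard files) is merged into that
single file and the index file is removed. If no key matches, nothing is
written.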
Usage:
python3 patch_qwen35_visual_keys.py --sft /path/to/sft/output_dir
"""
import argparse
import json
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
# Keys under this prefix are the mis-located visual-encoder weights.
_VISUAL_PREFIX = "model.language_model.visual."
# Target namespace expected by the serving stack (see module docstring).
_VISUAL_TARGET = "visual."


def _load_tensors(sft_dir):
    """Load every tensor stored in ``sft_dir`` into memory.

    ``model.safetensors`` is preferred when present. Otherwise the shards
    listed in ``model.safetensors.index.json`` are read.

    Args:
        sft_dir: ``Path`` of the checkpoint directory.

    Returns:
        ``(tensors, shard_paths)``. ``tensors`` maps key -> tensor.
        ``shard_paths`` lists the index-referenced shard files that were
        read; it is empty when the single-file checkpoint was used.

    Raises:
        FileNotFoundError: if neither the single file nor the index exists.
    """
    single = sft_dir / "model.safetensors"
    index = sft_dir / "model.safetensors.index.json"
    if single.exists():
        sources, shard_paths = [single], []
    elif index.exists():
        weight_map = json.loads(index.read_text(encoding="utf-8"))["weight_map"]
        # sorted() gives a deterministic read order; a shard appears once per key.
        shard_paths = [sft_dir / name for name in sorted(set(weight_map.values()))]
        sources = shard_paths
    else:
        raise FileNotFoundError(f"no model.safetensors or index in {sft_dir}")

    tensors = {}
    for path in sources:
        with safe_open(path, framework="pt") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
    return tensors, shard_paths


def _rename_visual_keys(tensors):
    """Move every ``model.language_model.visual.*`` key to ``visual.*``.

    Returns:
        ``(new_tensors, n_renamed)``. Non-matching keys are copied unchanged.

    Raises:
        ValueError: if a renamed key collides with a key already present.
            The original silently overwrote one of the two tensors.
    """
    new_tensors = {}
    n_renamed = 0
    for key, tensor in tensors.items():
        if key.startswith(_VISUAL_PREFIX):
            new_key = _VISUAL_TARGET + key[len(_VISUAL_PREFIX):]
            n_renamed += 1
        else:
            new_key = key
        if new_key in new_tensors:
            raise ValueError(
                f"key collision while renaming: {key!r} -> {new_key!r} already exists"
            )
        new_tensors[new_key] = tensor
    return new_tensors, n_renamed


def _save_atomically(tensors, out_path):
    """Write ``tensors`` to ``out_path`` via a temp file plus rename.

    The source checkpoint is never partially overwritten. If ``save_file``
    fails, the old file stays intact and the temp file is cleaned up.
    """
    # The ".tmp" suffix keeps the temp file out of "*.safetensors" globs.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        save_file(tensors, str(tmp_path), metadata={"format": "pt"})
        tmp_path.replace(out_path)  # os.replace semantics: atomic on POSIX
    finally:
        if tmp_path.exists():
            tmp_path.unlink()


def main(argv=None):
    """Patch visual-encoder key names of an SFT checkpoint in place.

    Args:
        argv: optional argument list (defaults to ``sys.argv[1:]``).
            Added for programmatic use; CLI behaviour is unchanged.

    Raises:
        FileNotFoundError: if the directory holds no safetensors checkpoint.
        ValueError: if renaming would make two tensors share one key.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--sft", required=True, help="path to sft model dir (contains model.safetensors)")
    args = ap.parse_args(argv)
    sft_dir = Path(args.sft)
    sft_index = sft_dir / "model.safetensors.index.json"
    out_path = sft_dir / "model.safetensors"

    print(f"[1/3] loading SFT weights from {sft_dir}...")
    # Note: the whole checkpoint is held in RAM at once.
    sft_tensors, merged_shards = _load_tensors(sft_dir)
    print(f" loaded {len(sft_tensors)} tensors")

    print("[2/3] renaming model.language_model.visual.* -> visual.* ...")
    new_tensors, renamed = _rename_visual_keys(sft_tensors)
    print(f" renamed {renamed} keys")
    if renamed == 0:
        print("[OK] nothing to rename; skipping write")
        return

    print(f"[3/3] writing patched checkpoint -> {out_path}")
    _save_atomically(new_tensors, out_path)

    # Clean up only after the merged file is safely in place.
    if sft_index.exists():
        print(f" removing stale index file {sft_index}")
        sft_index.unlink()
    # Merged shards must go too. Without an index, loaders that glob
    # "*.safetensors" would otherwise also load the stale, un-renamed shards.
    for shard in merged_shards:
        if shard != out_path and shard.exists():
            print(f" removing merged shard {shard}")
            shard.unlink()

    sz_gb = out_path.stat().st_size / 1e9
    print(f"[OK] wrote {len(new_tensors)} tensors, size {sz_gb:.2f} GB")
# Script entry point: run the patcher only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())