#!/usr/bin/env python3
"""Rename mis-located visual encoder keys in Qwen3.5 SFT checkpoints.

Symptom: After full-parameter SFT with `freeze_vision_tower: true`, the saved
checkpoint nests visual weights under `model.language_model.visual.*` instead
of the expected top-level `visual.*` (which vLLM looks for).

This script renames every key matching `model.language_model.visual.*` to
`visual.*` and writes the result to `model.safetensors` in the same directory.

If the checkpoint is sharded (`model.safetensors.index.json` plus shard
files), all shards are consolidated into a single `model.safetensors`. Only
after that file has been written successfully are the index and the merged
shard files deleted. Otherwise, a loader that globs `*.safetensors` could also
read the stale shards, which still carry the old key names.

The output is written to a temporary file first and then moved over
`model.safetensors`, so a failed run does not leave a truncated checkpoint.
If a renamed key would clash with a key that already exists, the script
aborts instead of silently dropping one of the tensors.

Usage:
    python3 patch_qwen35_visual_keys.py --sft /path/to/sft/output_dir
"""
import argparse
import json
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file

# Key prefix produced by the SFT run, and the prefix it is rewritten to.
_OLD_PREFIX = "model.language_model.visual."
_NEW_PREFIX = "visual."

_SINGLE_FILE = "model.safetensors"
_INDEX_FILE = "model.safetensors.index.json"


def _load_tensors(sft_dir):
    """Load every tensor of the checkpoint stored in *sft_dir*.

    A single ``model.safetensors`` takes precedence. Otherwise, every shard
    named in the ``weight_map`` of ``model.safetensors.index.json`` is read.

    Returns:
        ``(tensors, source_files)``: a ``{key: tensor}`` dict and the list of
        safetensors file paths the tensors were read from.

    Raises:
        FileNotFoundError: if neither file is present in *sft_dir*.
    """
    single = sft_dir / _SINGLE_FILE
    index = sft_dir / _INDEX_FILE
    if single.exists():
        source_files = [single]
    elif index.exists():
        idx = json.loads(index.read_text())
        # sorted() makes the shard read order deterministic (a bare set is not).
        shard_names = sorted(set(idx["weight_map"].values()))
        source_files = [sft_dir / shard for shard in shard_names]
    else:
        raise FileNotFoundError(f"no model.safetensors or index in {sft_dir}")

    tensors = {}
    for path in source_files:
        with safe_open(path, framework="pt") as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
    return tensors, source_files


def _rename_visual_keys(tensors):
    """Map every ``model.language_model.visual.*`` key to ``visual.*``.

    Returns:
        ``(new_tensors, renamed)``: a new dict that preserves insertion order
        and shares the tensor objects of *tensors*, and the number of keys
        that were renamed.

    Raises:
        ValueError: if a renamed key equals a key already in the checkpoint.
            Without this check, one of the two tensors would be silently
            dropped, and which one would depend on key order.
    """
    new_tensors = {}
    renamed = 0
    for key, tensor in tensors.items():
        if key.startswith(_OLD_PREFIX):
            new_key = _NEW_PREFIX + key[len(_OLD_PREFIX):]
            renamed += 1
        else:
            new_key = key
        if new_key in new_tensors:
            raise ValueError(
                f"renaming would overwrite existing key {new_key!r} "
                f"(while processing {key!r}); refusing to drop a tensor"
            )
        new_tensors[new_key] = tensor
    return new_tensors, renamed


def _save_atomically(tensors, out_path):
    """Write *tensors* to *out_path* without leaving a partial file there.

    The data is saved to a sibling ``<name>.tmp`` file first and then moved
    over *out_path* with ``Path.replace``. If anything fails, the temp file
    is removed and the exception is re-raised. *out_path* is replaced only
    after the save has completed.
    """
    # ".tmp" (not ".safetensors") so a leftover file is not matched by tools
    # that glob *.safetensors.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        save_file(tensors, str(tmp_path), metadata={"format": "pt"})
        tmp_path.replace(out_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def main():
    """CLI entry point: rename the visual keys of the checkpoint given by --sft."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--sft", required=True, help="path to sft model dir (contains model.safetensors)")
    args = ap.parse_args()

    sft_dir = Path(args.sft)
    sft_index = sft_dir / _INDEX_FILE

    print(f"[1/3] loading SFT weights from {sft_dir}...")
    sft_tensors, source_files = _load_tensors(sft_dir)
    print(f" loaded {len(sft_tensors)} tensors")

    print("[2/3] renaming model.language_model.visual.* -> visual.* ...")
    new_tensors, renamed = _rename_visual_keys(sft_tensors)
    print(f" renamed {renamed} keys")
    if renamed == 0:
        print("[OK] nothing to rename; skipping write")
        return

    out_path = sft_dir / _SINGLE_FILE
    print(f"[3/3] writing patched checkpoint -> {out_path}")
    _save_atomically(new_tensors, out_path)
    n_written = len(new_tensors)

    # Drop references to the loaded tensors before deleting their source
    # files. They may still be backed by memory maps of those files.
    del sft_tensors, new_tensors

    # Clean up only now that the new file is safely in place. Remove the index
    # and every shard merged into model.safetensors. Without the index, a
    # loader that globs *.safetensors would otherwise also read the stale
    # shards, which still carry the old model.language_model.visual.* keys.
    if sft_index.exists():
        print(f" removing stale index file {sft_index}")
        sft_index.unlink()
    for shard in source_files:
        if shard != out_path and shard.exists():
            print(f" removing merged shard {shard}")
            shard.unlink()

    sz_gb = out_path.stat().st_size / 1e9
    print(f"[OK] wrote {n_written} tensors, size {sz_gb:.2f} GB")


if __name__ == "__main__":
    main()