"""Rename mis-located visual encoder keys in Qwen3.5 SFT checkpoints.

Symptom: After full-parameter SFT with `freeze_vision_tower: true`, the
saved checkpoint nests visual weights under `model.language_model.visual.*`
instead of the expected top-level `visual.*` (which vLLM looks for).

This script renames every key matching `model.language_model.visual.*`
to `visual.*` and writes a new model.safetensors in place.

Usage:
    python3 patch_qwen35_visual_keys.py --sft /path/to/sft/output_dir
"""
| import argparse |
| import json |
| from pathlib import Path |
|
|
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
def main():
    """Consolidate an SFT checkpoint and rename mis-located visual keys.

    Loads all tensors from ``model.safetensors`` (single-file layout) or,
    failing that, from the shards listed in
    ``model.safetensors.index.json``; renames every key under
    ``model.language_model.visual.`` to ``visual.``; then rewrites a single
    ``model.safetensors`` in place.  Stale index and shard files are removed
    *after* a successful write, so loaders that glob ``*.safetensors``
    (e.g. vLLM) cannot pick up the old, un-renamed weights — and a failed
    write leaves the original sharded checkpoint intact.

    Raises:
        FileNotFoundError: if the ``--sft`` directory contains neither a
            single-file nor a sharded safetensors checkpoint.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--sft", required=True, help="path to sft model dir (contains model.safetensors)")
    args = ap.parse_args()

    sft_dir = Path(args.sft)
    sft_st = sft_dir / "model.safetensors"
    sft_index = sft_dir / "model.safetensors.index.json"

    print(f"[1/3] loading SFT weights from {sft_dir}...")
    sft_tensors = {}
    shard_paths = []  # shards we consolidated; removed after the rewrite
    if sft_st.exists():
        with safe_open(sft_st, framework="pt") as f:
            for k in f.keys():
                sft_tensors[k] = f.get_tensor(k)
    elif sft_index.exists():
        idx = json.loads(sft_index.read_text())
        # sorted() gives a deterministic shard order across runs (set
        # iteration order is arbitrary).
        for shard in sorted(set(idx["weight_map"].values())):
            shard_path = sft_dir / shard
            shard_paths.append(shard_path)
            with safe_open(shard_path, framework="pt") as f:
                for k in f.keys():
                    sft_tensors[k] = f.get_tensor(k)
    else:
        raise FileNotFoundError(f"no model.safetensors or index in {sft_dir}")

    print(f" loaded {len(sft_tensors)} tensors")

    print("[2/3] renaming model.language_model.visual.* -> visual.* ...")
    PREFIX = "model.language_model.visual."
    new_tensors = {}
    renamed = 0
    for k, v in sft_tensors.items():
        if k.startswith(PREFIX):
            new_tensors["visual." + k[len(PREFIX):]] = v
            renamed += 1
        else:
            new_tensors[k] = v

    print(f" renamed {renamed} keys")
    if renamed == 0:
        # Nothing matched the bad prefix: the checkpoint (sharded or not)
        # is already in the expected layout, so leave it untouched.
        print("[OK] nothing to rename; skipping write")
        return

    out_path = sft_dir / "model.safetensors"
    print(f"[3/3] writing patched checkpoint -> {out_path}")
    save_file(new_tensors, str(out_path), metadata={"format": "pt"})
    # Clean up the sharded layout only AFTER the consolidated file exists:
    # deleting the index first (as before) would leave an unloadable
    # directory if save_file failed, and leftover shards would be picked up
    # by loaders that glob *.safetensors.
    if sft_index.exists():
        print(f" removing stale index file {sft_index}")
        sft_index.unlink()
    for shard_path in shard_paths:
        if shard_path != out_path and shard_path.exists():
            print(f" removing stale shard {shard_path}")
            shard_path.unlink()
    sz_gb = out_path.stat().st_size / 1e9
    print(f"[OK] wrote {len(new_tensors)} tensors, size {sz_gb:.2f} GB")
|
|
|
|
# Script entry point: only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|