File size: 4,416 Bytes
e50b018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
"""Patch a Gemma-4 SFT checkpoint that is missing some weights.

DeepSpeed ZeRO-3 sometimes drops sliding-window-layer K/V weights when saving.
This script copies the missing weights from the base model into the SFT
checkpoint and writes every tensor into a single model.safetensors, so the
model can be loaded by vLLM. No sharded index is written; when patching in
place, the now-stale model.safetensors.index.json is removed.

Usage:
  python3 patch_gemma_checkpoint.py \
      --base /path/to/Gemma-4-E4B-it \
      --sft  /path/to/sft/output_dir \
      [--out /path/to/output_dir]   # default: in-place
"""
import argparse
import json
import os
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file


_WEIGHTS_NAME = "model.safetensors"
_INDEX_NAME = "model.safetensors.index.json"
# Top-level SFT files that are NOT mirrored into a separate --out dir: weight
# or training-state payloads (old shards, *.bin/*.pt state, shard indexes)
# that are either superseded by the merged model.safetensors or not needed
# for inference. Everything else (config, tokenizer, ...) is copied.
_NO_COPY_SUFFIXES = (".safetensors", ".bin", ".pt", ".pth", ".ckpt", ".index.json")


def _key_to_file(model_dir):
    """Map each tensor name in a checkpoint dir to the file that stores it.

    A single ``model.safetensors`` wins if present; otherwise the
    ``weight_map`` of ``model.safetensors.index.json`` is used.

    Args:
        model_dir: Path of the checkpoint directory.

    Returns:
        dict mapping tensor name -> Path of the ``.safetensors`` file.

    Raises:
        FileNotFoundError: if neither file exists (``main`` validates first).
    """
    single = model_dir / _WEIGHTS_NAME
    if single.exists():
        with safe_open(str(single), framework="pt") as f:
            return {k: single for k in f.keys()}
    index = json.loads((model_dir / _INDEX_NAME).read_text())
    return {k: model_dir / shard for k, shard in index["weight_map"].items()}


def _load_tensors(files):
    """Load every tensor stored in *files* into one ``name -> tensor`` dict."""
    tensors = {}
    for path in files:
        with safe_open(str(path), framework="pt") as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
    return tensors


def _dominant_float_dtype(tensors, default=torch.bfloat16):
    """Return the most common floating-point dtype in *tensors*, else *default*."""
    counts = {}
    for t in tensors.values():
        if t.is_floating_point():
            counts[t.dtype] = counts.get(t.dtype, 0) + 1
    return max(counts, key=counts.get) if counts else default


def _copy_missing(missing, key_to_file, dst, dtype):
    """Read the *missing* tensors from their base files into *dst*.

    Keys are grouped by source file so each file is opened once. Only
    floating-point tensors are cast to *dtype*; integer/bool buffers keep
    their dtype because a float cast would corrupt them.
    """
    by_file = {}
    for k in missing:
        by_file.setdefault(key_to_file[k], []).append(k)
    for path, keys in by_file.items():
        with safe_open(str(path), framework="pt") as f:
            for k in keys:
                t = f.get_tensor(k)
                if t.is_floating_point() and t.dtype != dtype:
                    t = t.to(dtype)
                dst[k] = t


def _copy_side_files(src_dir, dst_dir):
    """Copy top-level non-weight files (config, tokenizer, ...) into *dst_dir*.

    Existing files in *dst_dir* are never overwritten; subdirectories are
    skipped. Returns the list of copied file names.
    """
    copied = []
    for src in sorted(src_dir.iterdir()):
        if not src.is_file() or src.name.endswith(_NO_COPY_SUFFIXES):
            continue
        dst = dst_dir / src.name
        if dst.exists():
            continue
        shutil.copy2(src, dst)
        copied.append(src.name)
    return copied


def _write_single_file(tensors, out_dir):
    """Atomically write *tensors* to ``out_dir/model.safetensors``; return its path.

    The data goes to a temp file first and is then ``os.replace``-d into
    place. This means an interrupted or failed save never leaves a truncated
    model.safetensors. It also avoids truncating, in place, an SFT file whose
    tensors were just read via ``safe_open`` (and may still be mapped).
    """
    out_path = out_dir / _WEIGHTS_NAME
    tmp_path = out_dir / (_WEIGHTS_NAME + ".tmp")
    try:
        save_file(tensors, str(tmp_path), metadata={"format": "pt"})
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    os.replace(tmp_path, out_path)
    return out_path


def main():
    """CLI entry point: fill tensors missing from an SFT checkpoint from the base model.

    Steps:
      1. load all SFT tensors (single file or index-listed shards);
      2. map base tensor names to their files;
      3. copy base tensors absent from the SFT checkpoint, casting floating
         ones to the SFT checkpoint's most common float dtype (bf16 if none);
      4. write everything to ``<out>/model.safetensors`` and drop any stale
         ``model.safetensors.index.json`` there.

    With a separate ``--out``, top-level config/tokenizer files are copied
    from ``--sft`` so the output directory is loadable on its own.

    Exits with an argparse usage error if ``--base`` or ``--sft`` contains
    neither ``model.safetensors`` nor ``model.safetensors.index.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", required=True, help="path to base model dir")
    ap.add_argument("--sft", required=True, help="path to sft model dir")
    ap.add_argument("--out", default=None, help="output dir (default: --sft)")
    args = ap.parse_args()

    base_dir = Path(args.base)
    sft_dir = Path(args.sft)
    out_dir = Path(args.out) if args.out else sft_dir
    # Resolve so that e.g. "./sft" and "sft" are recognised as the same dir.
    in_place = out_dir.resolve() == sft_dir.resolve()

    # Explicit validation: `assert` would be stripped under `python -O`.
    for flag, model_dir in (("--base", base_dir), ("--sft", sft_dir)):
        if not (model_dir / _WEIGHTS_NAME).exists() and not (model_dir / _INDEX_NAME).exists():
            ap.error(f"{flag}: missing {model_dir / _WEIGHTS_NAME} or its index")

    print(f"[1/4] loading SFT weights from {sft_dir}...")
    sft_files = sorted(set(_key_to_file(sft_dir).values()))
    sft_tensors = _load_tensors(sft_files)

    print(f"[2/4] scanning base weights from {base_dir}...")
    base_keys_to_files = _key_to_file(base_dir)

    base_keys = set(base_keys_to_files)
    sft_keys = set(sft_tensors)
    missing = sorted(base_keys - sft_keys)
    extra = sorted(sft_keys - base_keys)

    print(f"  base keys: {len(base_keys)}")
    print(f"  sft  keys: {len(sft_keys)}")
    print(f"  missing in sft: {len(missing)}  (will copy from base)")
    print(f"  extra in sft  : {len(extra)}    (kept as-is)")

    if not missing:
        print("[OK] nothing to patch; sft is already complete")
        return

    print(f"[3/4] copying {len(missing)} missing weights from base...")
    _copy_missing(missing, base_keys_to_files, sft_tensors, _dominant_float_dtype(sft_tensors))

    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / _WEIGHTS_NAME
    print(f"[4/4] writing patched checkpoint -> {out_path}")
    _write_single_file(sft_tensors, out_dir)

    # Only now that the merged file is safely in place: an index left in
    # out_dir would point loaders at shards instead of the merged file.
    stale_index = out_dir / _INDEX_NAME
    if stale_index.exists():
        print(f"  removing stale {stale_index}")
        stale_index.unlink()

    if in_place:
        # Old shards are not deleted (non-destructive), only reported.
        leftovers = [p.name for p in sft_files if p.resolve() != out_path.resolve()]
        if leftovers:
            print(f"  note: old shard files are no longer needed: {', '.join(leftovers)}")
    else:
        copied = _copy_side_files(sft_dir, out_dir)
        if copied:
            print(f"  copied {len(copied)} config/tokenizer files from {sft_dir}")

    print(f"[OK] saved {len(sft_tensors)} tensors to {out_path}")
    print(f"     size: {out_path.stat().st_size / 1e9:.2f} GB")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status stays 0.
    raise SystemExit(main())