#!/usr/bin/env python3
"""Patch a Gemma-4 SFT checkpoint that is missing some weights.
DeepSpeed ZeRO3 sometimes drops sliding-window-layer K/V weights when saving.
This script copies the missing weights from the base model into the SFT
checkpoint, producing a complete model.safetensors plus an updated
model.safetensors.index.json (if needed) so the model can be loaded by vLLM.
Usage:
python3 patch_gemma_checkpoint.py \
--base /path/to/Gemma-4-E4B-it \
--sft /path/to/sft/output_dir \
[--out /path/to/output_dir] # default: in-place
"""
import argparse
import json
import os
import shutil
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
def _require_checkpoint(model_dir: Path) -> None:
    """Raise FileNotFoundError unless *model_dir* holds a safetensors checkpoint.

    Accepts either a single model.safetensors file or a sharded checkpoint
    with a model.safetensors.index.json.  (The original used `assert`, which
    is silently stripped under `python -O`.)
    """
    st = model_dir / "model.safetensors"
    if not st.exists() and not (model_dir / "model.safetensors.index.json").exists():
        raise FileNotFoundError(f"missing {st} or its index")


def _load_all_tensors(model_dir: Path) -> dict:
    """Load every tensor from *model_dir*, single-file or sharded."""
    single = model_dir / "model.safetensors"
    tensors = {}
    if single.exists():
        with safe_open(single, framework="pt") as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
    else:
        idx = json.loads((model_dir / "model.safetensors.index.json").read_text())
        # weight_map values repeat per key; visit each shard file only once.
        for shard in set(idx["weight_map"].values()):
            with safe_open(model_dir / shard, framework="pt") as f:
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
    return tensors


def _map_keys_to_files(model_dir: Path) -> dict:
    """Map each tensor key in *model_dir*'s checkpoint to the file holding it.

    Does not load any tensor data — for a sharded checkpoint this is a pure
    index-file read.
    """
    single = model_dir / "model.safetensors"
    key_to_file = {}
    if single.exists():
        with safe_open(single, framework="pt") as f:
            for k in f.keys():
                key_to_file[k] = single
    else:
        idx = json.loads((model_dir / "model.safetensors.index.json").read_text())
        for k, shard in idx["weight_map"].items():
            key_to_file[k] = model_dir / shard
    return key_to_file


def main():
    """Patch an SFT checkpoint by copying weights it is missing from the base model.

    CLI: --base BASE_DIR --sft SFT_DIR [--out OUT_DIR] (default: patch --sft
    in place).  Writes a single complete model.safetensors to the output dir.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", required=True, help="path to base model dir")
    ap.add_argument("--sft", required=True, help="path to sft model dir")
    ap.add_argument("--out", default=None, help="output dir (default: --sft)")
    args = ap.parse_args()

    base_dir = Path(args.base)
    sft_dir = Path(args.sft)
    out_dir = Path(args.out) if args.out else sft_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fail fast with a real exception instead of `assert` (see helper).
    _require_checkpoint(base_dir)
    _require_checkpoint(sft_dir)

    # [1/4] load all SFT tensors into memory — they are re-saved at the end.
    print(f"[1/4] loading SFT weights from {sft_dir}...")
    sft_tensors = _load_all_tensors(sft_dir)

    # [2/4] only *scan* the base checkpoint: we need its key set and the file
    # each key lives in, not the tensor data itself.
    print(f"[2/4] scanning base weights from {base_dir}...")
    base_keys_to_files = _map_keys_to_files(base_dir)

    base_keys = set(base_keys_to_files)
    sft_keys = set(sft_tensors)
    missing = sorted(base_keys - sft_keys)
    extra = sorted(sft_keys - base_keys)
    print(f" base keys: {len(base_keys)}")
    print(f" sft keys: {len(sft_keys)}")
    print(f" missing in sft: {len(missing)} (will copy from base)")
    print(f" extra in sft : {len(extra)} (kept as-is)")
    if not missing:
        print("[OK] nothing to patch; sft is already complete")
        return

    # Cast copied weights to the dtype the SFT checkpoint already uses
    # (bfloat16 for Gemma, which is also the fallback for an all-integer
    # checkpoint).  The original hard-coded bfloat16, which would silently
    # down-cast fp16/fp32 checkpoints and corrupt non-float tensors.
    ref_dtype = next(
        (t.dtype for t in sft_tensors.values() if t.is_floating_point()),
        torch.bfloat16,
    )

    # [3/4] group missing keys by source shard so each shard is opened once.
    print(f"[3/4] copying {len(missing)} missing weights from base...")
    by_shard = {}
    for k in missing:
        by_shard.setdefault(base_keys_to_files[k], []).append(k)
    for shard_path, keys in by_shard.items():
        with safe_open(shard_path, framework="pt") as f:
            for k in keys:
                t = f.get_tensor(k)
                # Only floating-point tensors are cast; integer/bool buffers
                # are copied verbatim.
                if t.is_floating_point() and t.dtype != ref_dtype:
                    t = t.to(ref_dtype)
                sft_tensors[k] = t

    # [4/4] write everything back as a single safetensors file.
    out_path = out_dir / "model.safetensors"
    print(f"[4/4] writing patched checkpoint -> {out_path}")
    # When patching in place, drop the now-stale index file so loaders do not
    # find a sharded index that disagrees with the single patched file.
    if out_dir == sft_dir:
        for stale in [out_dir / "model.safetensors.index.json"]:
            if stale.exists():
                print(f" removing stale {stale}")
                stale.unlink()
    save_file(sft_tensors, str(out_path), metadata={"format": "pt"})
    print(f"[OK] saved {len(sft_tensors)} tensors to {out_path}")
    print(f" size: {out_path.stat().st_size / 1e9:.2f} GB")
if __name__ == "__main__":
main()
|