"""
checkpoint_avg.py — Average weights from multiple OPF checkpoint directories.
Usage:
python scripts/checkpoint_avg.py \
--checkpoints ckpt/run1 ckpt/run2 ckpt/run3 \
--output ckpt/averaged
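Each checkpoint directory must contain:
- model.safetensors (the weights)
and may optionally contain:
- config.json (copied verbatim from the FIRST checkpoint only; if it is missing
  there, a warning is printed and no config.json is written)
All checkpoints must share the same tensor names and shapes (i.e. same label
space and architecture). The averaged weights are saved as bfloat16 by default;
pass --dtype float32 to keep full precision (float16 is also accepted).
The output directory additionally receives avg_provenance.json, which records the
source checkpoints, their count, and the output dtype. An existing output
directory is only replaced when --overwrite is given.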
"""
from __future__ import annotations
import argparse
import json
import shutil
import sys
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
# ---------------------------------------------------------------------------
# Core averaging logic
# ---------------------------------------------------------------------------
def load_tensors(ckpt_dir: Path) -> dict[str, torch.Tensor]:
    """Read ``model.safetensors`` from *ckpt_dir* with safetensors' ``load_file`` on CPU.

    Args:
        ckpt_dir: Checkpoint directory expected to hold ``model.safetensors``.

    Returns:
        The name -> tensor mapping produced by ``load_file``.

    Raises:
        FileNotFoundError: If ``model.safetensors`` does not exist in *ckpt_dir*.
    """
    safetensors_file = ckpt_dir / "model.safetensors"
    if safetensors_file.exists():
        return load_file(str(safetensors_file), device="cpu")
    raise FileNotFoundError(f"model.safetensors not found in {ckpt_dir}")
def average_checkpoints(
    checkpoint_dirs: list[Path],
    *,
    dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Load all checkpoints and return their element-wise average.

    Floating-point tensors are summed in float32 (whatever their stored dtype),
    divided by the number of checkpoints, and cast to ``dtype``.
    Non-floating-point tensors (integer or bool buffers) cannot be averaged or
    cast to a float dtype without corrupting them. They are therefore taken
    unchanged, with their original dtype, from the first checkpoint. They must
    still be present with matching shapes in every checkpoint.

    Args:
        checkpoint_dirs: Checkpoint directories, each holding ``model.safetensors``.
            The first one defines the reference tensor names, shapes and
            output key order.
        dtype: Output dtype for the averaged floating-point tensors.

    Returns:
        Mapping from tensor name to averaged tensor, in the first checkpoint's
        key order.

    Raises:
        ValueError: If ``checkpoint_dirs`` is empty, or if a checkpoint's tensor
            names or shapes differ from those of the first checkpoint.
        FileNotFoundError: If a directory has no ``model.safetensors``.
    """
    n = len(checkpoint_dirs)
    if n == 0:
        # Previously this surfaced as an IndexError on checkpoint_dirs[0].
        raise ValueError("No checkpoints given to average.")
    print(f"Loading {n} checkpoints...")

    # Load first checkpoint to establish reference keys, shapes and key order
    reference = load_tensors(checkpoint_dirs[0])
    key_order = list(reference)
    ref_keys = set(key_order)
    ref_shapes = {k: v.shape for k, v in reference.items()}
    print(f" [1/{n}] {checkpoint_dirs[0]} ({len(ref_keys)} tensors)")

    # Accumulate floating-point tensors in float32 for numerical stability
    # regardless of output dtype; keep integer/bool tensors verbatim.
    accum: dict[str, torch.Tensor] = {}
    passthrough: dict[str, torch.Tensor] = {}
    for k, v in reference.items():
        if v.is_floating_point():
            accum[k] = v.to(torch.float32)
        else:
            passthrough[k] = v
    # Only shapes are needed from here on; drop the original-dtype copies so
    # they do not stay alive next to the float32 accumulators.
    del reference
    if passthrough:
        print(
            f" keeping {len(passthrough)} non-floating-point tensor(s) "
            "from the first checkpoint unchanged"
        )

    for i, ckpt_dir in enumerate(checkpoint_dirs[1:], start=2):
        tensors = load_tensors(ckpt_dir)
        keys = set(tensors)

        # Validate key sets match
        if keys != ref_keys:
            missing = ref_keys - keys
            extra = keys - ref_keys
            msg_parts = [f"Checkpoint {ckpt_dir} has mismatched tensor keys."]
            if missing:
                msg_parts.append(f" Missing: {sorted(missing)}")
            if extra:
                msg_parts.append(f" Extra: {sorted(extra)}")
            raise ValueError("\n".join(msg_parts))

        # Validate shapes match
        for k in ref_keys:
            if tensors[k].shape != ref_shapes[k]:
                raise ValueError(
                    f"Shape mismatch for tensor {k!r}: "
                    f"{ref_shapes[k]} vs {tensors[k].shape} in {ckpt_dir}"
                )

        # Accumulate floating-point tensors only
        for k in accum:
            accum[k] += tensors[k].to(torch.float32)

        # Release this checkpoint before the next one is loaded to cap peak memory.
        del tensors
        print(f" [{i}/{n}] {ckpt_dir}")

    # Divide and cast to output dtype, preserving the reference key order
    print(f"Averaging ({n} checkpoints) -> {dtype}...")
    return {
        k: (accum[k] / n).to(dtype) if k in accum else passthrough[k]
        for k in key_order
    }
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the checkpoint-averaging CLI.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        Namespace with ``checkpoints`` (list of str), ``output`` (str),
        ``dtype`` (one of the listed dtype names) and ``overwrite`` (bool).
    """
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Average weights from multiple OPF checkpoint directories.",
    )
    p.add_argument(
        "--checkpoints", metavar="DIR", nargs="+", required=True,
        help="Two or more checkpoint directories to average.",
    )
    p.add_argument(
        "--output", metavar="DIR", required=True,
        help="Output directory for the averaged checkpoint.",
    )
    p.add_argument(
        "--dtype", default="bfloat16",
        choices=["bfloat16", "float32", "float16"],
        help="Dtype for saved weights.",
    )
    p.add_argument(
        "--overwrite", action="store_true",
        help="Overwrite output directory if it already exists.",
    )
    return p.parse_args(argv)
# Translates each --dtype CLI choice into the torch dtype used for the saved weights.
DTYPE_MAP: dict[str, torch.dtype] = dict(
    bfloat16=torch.bfloat16,
    float32=torch.float32,
    float16=torch.float16,
)
def main(argv: list[str] | None = None) -> int:
    """Run the checkpoint-averaging CLI.

    The function first validates the inputs and averages the checkpoints. Only
    after that succeeds does it (re)create the output directory and write
    ``model.safetensors``, a copy of the first checkpoint's ``config.json`` (if
    present) and ``avg_provenance.json``. A failed run therefore never
    destroys an existing output directory.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 on any reported error.
    """
    args = parse_args(argv)
    checkpoint_dirs = [Path(d).expanduser().resolve() for d in args.checkpoints]
    output_dir = Path(args.output).expanduser().resolve()
    dtype = DTYPE_MAP[args.dtype]

    # Validate inputs
    if len(checkpoint_dirs) < 2:
        print("error: at least 2 checkpoints are required for averaging", file=sys.stderr)
        return 1
    for d in checkpoint_dirs:
        if not d.is_dir():
            print(f"error: checkpoint directory not found: {d}", file=sys.stderr)
            return 1

    # Refuse an output location that is, or contains, an input checkpoint:
    # with --overwrite the rmtree below would delete that source checkpoint.
    for d in checkpoint_dirs:
        if d == output_dir or output_dir in d.parents:
            print(
                f"error: output directory {output_dir} is or contains "
                f"input checkpoint {d}",
                file=sys.stderr,
            )
            return 1

    # Check output up front (fail fast), but do not modify it yet
    if output_dir.exists():
        if not args.overwrite:
            print(
                f"error: output directory already exists: {output_dir}\n"
                " Pass --overwrite to replace it.",
                file=sys.stderr,
            )
            return 1
        if not output_dir.is_dir():
            # shutil.rmtree cannot remove a plain file; report it clearly.
            print(
                f"error: output path exists and is not a directory: {output_dir}",
                file=sys.stderr,
            )
            return 1

    # Average weights
    try:
        averaged = average_checkpoints(checkpoint_dirs, dtype=dtype)
    except (FileNotFoundError, ValueError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1

    # Only now replace the output directory, so a failed average leaves any
    # previous output intact and no empty directory behind.
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    # Save averaged weights
    weights_path = output_dir / "model.safetensors"
    print(f"Saving averaged weights -> {weights_path}")
    save_file(averaged, str(weights_path))

    # Copy config.json from the first checkpoint (optional: warn if absent)
    src_config = checkpoint_dirs[0] / "config.json"
    if not src_config.exists():
        print(
            f"warning: config.json not found in {checkpoint_dirs[0]}; skipping",
            file=sys.stderr,
        )
    else:
        dst_config = output_dir / "config.json"
        shutil.copy2(src_config, dst_config)
        print(f"Copied config.json from {checkpoint_dirs[0]} -> {dst_config}")

    # Write a small provenance file so it's clear how this checkpoint was made
    provenance = {
        "type": "checkpoint_average",
        "source_checkpoints": [str(d) for d in checkpoint_dirs],
        "num_checkpoints": len(checkpoint_dirs),
        "output_dtype": args.dtype,
    }
    (output_dir / "avg_provenance.json").write_text(
        json.dumps(provenance, indent=2) + "\n", encoding="utf-8"
    )
    print(f"\nDone. Averaged checkpoint saved to: {output_dir}")
    return 0
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())