# arcspan / scripts/checkpoint_avg.py
# Uploaded by chairulridjal via "Add files using upload-large-folder tool" (commit 3dac39e, verified).
"""
checkpoint_avg.py — Average weights from multiple OPF checkpoint directories.
Usage:
python scripts/checkpoint_avg.py \
--checkpoints ckpt/run1 ckpt/run2 ckpt/run3 \
--output ckpt/averaged
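Each checkpoint directory must contain:
- model.safetensors (the weights)
The first checkpoint directory should also contain:
- config.json (copied verbatim to the output; a warning is printed if missing)
All checkpoints must share the same tensor names and shapes (i.e. same label
space and architecture). The averaged weights are saved as bfloat16 by default;
pass --dtype float32 to keep full precision. An avg_provenance.json file that
records the source checkpoints and output dtype is written alongside them.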
"""
from __future__ import annotations
import argparse
import json
import shutil
import sys
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
# ---------------------------------------------------------------------------
# Core averaging logic
# ---------------------------------------------------------------------------
def load_tensors(ckpt_dir: Path) -> dict[str, torch.Tensor]:
    """Read the ``model.safetensors`` file of *ckpt_dir* onto the CPU.

    Args:
        ckpt_dir: Checkpoint directory expected to hold ``model.safetensors``.

    Returns:
        Mapping of tensor name to tensor, as returned by ``load_file``.

    Raises:
        FileNotFoundError: If ``model.safetensors`` does not exist in *ckpt_dir*.
    """
    target = ckpt_dir.joinpath("model.safetensors")
    if target.exists():
        return load_file(str(target), device="cpu")
    raise FileNotFoundError(f"model.safetensors not found in {ckpt_dir}")
def average_checkpoints(
    checkpoint_dirs: list[Path],
    *,
    dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Load all checkpoints and return their element-wise average.

    Floating-point tensors are summed in float32 (for numerical stability,
    regardless of the output dtype), divided by the number of checkpoints and
    cast to ``dtype``. Non-floating-point tensors (integer or boolean buffers)
    cannot be averaged into a float dtype without corrupting them, so they are
    taken unchanged, with their original dtype, from the first checkpoint.

    Args:
        checkpoint_dirs: Checkpoint directories, each holding a
            ``model.safetensors`` file. The first one is the reference for
            tensor names, shapes and output key order.
        dtype: Dtype of the averaged floating-point tensors.

    Returns:
        Mapping from tensor name to averaged tensor, in the first checkpoint's
        key order.

    Raises:
        ValueError: If ``checkpoint_dirs`` is empty, or a checkpoint's tensor
            names or shapes differ from those of the first checkpoint.
        FileNotFoundError: If a directory lacks ``model.safetensors``.
    """
    n = len(checkpoint_dirs)
    if n == 0:
        raise ValueError("at least one checkpoint directory is required")
    print(f"Loading {n} checkpoints...")

    # The first checkpoint fixes the expected tensor names, shapes and order.
    reference = load_tensors(checkpoint_dirs[0])
    key_order = list(reference)
    ref_keys = set(key_order)
    ref_shapes = {k: v.shape for k, v in reference.items()}
    print(f" [1/{n}] {checkpoint_dirs[0]} ({len(ref_keys)} tensors)")

    # Accumulate floating-point tensors in float32. copy=True matters: .to()
    # returns the very same tensor when it is already float32, and the
    # in-place adds below must not write into the loaded reference tensor.
    accum: dict[str, torch.Tensor] = {
        k: v.to(torch.float32, copy=True)
        for k, v in reference.items()
        if v.is_floating_point()
    }
    # Integer/bool tensors are carried over verbatim from the first checkpoint.
    passthrough: dict[str, torch.Tensor] = {
        k: v for k, v in reference.items() if not v.is_floating_point()
    }
    # The float originals now only live on as float32 copies; free them.
    del reference

    for i, ckpt_dir in enumerate(checkpoint_dirs[1:], start=2):
        tensors = load_tensors(ckpt_dir)
        keys = set(tensors)

        # Validate key sets match
        if keys != ref_keys:
            missing = ref_keys - keys
            extra = keys - ref_keys
            msg_parts = [f"Checkpoint {ckpt_dir} has mismatched tensor keys."]
            if missing:
                msg_parts.append(f" Missing: {sorted(missing)}")
            if extra:
                msg_parts.append(f" Extra: {sorted(extra)}")
            raise ValueError("\n".join(msg_parts))

        # Validate shapes match; sorted so the reported mismatch is deterministic.
        for k in sorted(ref_keys):
            if tensors[k].shape != ref_shapes[k]:
                raise ValueError(
                    f"Shape mismatch for tensor {k!r}: "
                    f"{ref_shapes[k]} vs {tensors[k].shape} in {ckpt_dir}"
                )

        # Accumulate only the floating-point tensors.
        for k, total in accum.items():
            total.add_(tensors[k].to(torch.float32))
        # Drop this checkpoint before loading the next one to bound peak memory.
        del tensors
        print(f" [{i}/{n}] {ckpt_dir}")

    # Divide in place, cast to the output dtype, and restore the original order.
    print(f"Averaging ({n} checkpoints) -> {dtype}...")
    return {
        k: accum[k].div_(n).to(dtype) if k in accum else passthrough[k]
        for k in key_order
    }
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Average weights from multiple OPF checkpoint directories.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--checkpoints",
nargs="+",
required=True,
metavar="DIR",
help="Two or more checkpoint directories to average.",
)
parser.add_argument(
"--output",
required=True,
metavar="DIR",
help="Output directory for the averaged checkpoint.",
)
parser.add_argument(
"--dtype",
choices=["bfloat16", "float32", "float16"],
default="bfloat16",
help="Dtype for saved weights.",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite output directory if it already exists.",
)
return parser.parse_args(argv)
# Maps each --dtype choice name to the torch dtype of the same name.
DTYPE_MAP = {
    name: getattr(torch, name) for name in ("bfloat16", "float32", "float16")
}
def main(argv: list[str] | None = None) -> int:
    """Run the checkpoint-averaging CLI and return a process exit code.

    Returns 0 on success and 1 on any validation or averaging error (the
    reason is printed to stderr). The averaged weights are computed before the
    output directory is touched, so a failed run leaves an existing output
    directory intact even when --overwrite is given.
    """
    args = parse_args(argv)
    checkpoint_dirs = [Path(d).expanduser().resolve() for d in args.checkpoints]
    output_dir = Path(args.output).expanduser().resolve()
    dtype = DTYPE_MAP[args.dtype]

    # Validate inputs
    if len(checkpoint_dirs) < 2:
        print("error: at least 2 checkpoints are required for averaging", file=sys.stderr)
        return 1
    for d in checkpoint_dirs:
        if not d.is_dir():
            print(f"error: checkpoint directory not found: {d}", file=sys.stderr)
            return 1
        # --overwrite deletes the output directory, so it must never be (or
        # contain) one of the source checkpoints.
        if d == output_dir or output_dir in d.parents:
            print(
                f"error: output directory {output_dir} would overwrite "
                f"input checkpoint {d}",
                file=sys.stderr,
            )
            return 1

    # Fail fast, before the (potentially slow) averaging, if the output exists.
    if output_dir.exists() and not args.overwrite:
        print(
            f"error: output directory already exists: {output_dir}\n"
            " Pass --overwrite to replace it.",
            file=sys.stderr,
        )
        return 1

    # Average weights
    try:
        averaged = average_checkpoints(checkpoint_dirs, dtype=dtype)
    except (FileNotFoundError, ValueError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1

    # Only now that averaging succeeded is it safe to replace the old output.
    if output_dir.exists():
        if output_dir.is_dir():
            shutil.rmtree(output_dir)
        else:
            # A plain file at the output path; rmtree would raise on it.
            output_dir.unlink()
    output_dir.mkdir(parents=True)

    # Save averaged weights
    weights_path = output_dir / "model.safetensors"
    print(f"Saving averaged weights -> {weights_path}")
    save_file(averaged, str(weights_path))

    # Copy config.json from the first checkpoint
    src_config = checkpoint_dirs[0] / "config.json"
    if not src_config.exists():
        print(
            f"warning: config.json not found in {checkpoint_dirs[0]}; skipping",
            file=sys.stderr,
        )
    else:
        dst_config = output_dir / "config.json"
        shutil.copy2(src_config, dst_config)
        print(f"Copied config.json from {checkpoint_dirs[0]} -> {dst_config}")

    # Write a small provenance file so it's clear how this checkpoint was made
    provenance = {
        "type": "checkpoint_average",
        "source_checkpoints": [str(d) for d in checkpoint_dirs],
        "num_checkpoints": len(checkpoint_dirs),
        "output_dtype": args.dtype,
    }
    (output_dir / "avg_provenance.json").write_text(
        json.dumps(provenance, indent=2) + "\n", encoding="utf-8"
    )
    print(f"\nDone. Averaged checkpoint saved to: {output_dir}")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())