| """ |
| checkpoint_avg.py — Average weights from multiple OPF checkpoint directories. |
| |
| Usage: |
| python scripts/checkpoint_avg.py \ |
| --checkpoints ckpt/run1 ckpt/run2 ckpt/run3 \ |
| --output ckpt/averaged |
| |
| Each checkpoint directory must contain: |
| - model.safetensors (the weights) |
| |
| The first checkpoint directory should also contain: |
| - config.json (copied verbatim to the output; skipped with a warning if missing) |
| |
| All checkpoints must share the same tensor names and shapes (i.e. same label |
| space and architecture). The averaged weights are saved as bfloat16 by default; |
| pass --dtype float32 to keep full precision. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import shutil |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file, save_file |
|
|
|
|
| |
| |
| |
|
|
def load_tensors(ckpt_dir: Path) -> dict[str, torch.Tensor]:
    """Read ``model.safetensors`` from *ckpt_dir* onto the CPU.

    Args:
        ckpt_dir: Checkpoint directory expected to hold ``model.safetensors``.

    Returns:
        The mapping produced by ``load_file`` for that file (``device="cpu"``).

    Raises:
        FileNotFoundError: If ``model.safetensors`` does not exist in *ckpt_dir*.
    """
    safetensors_file = ckpt_dir / "model.safetensors"
    if safetensors_file.exists():
        return load_file(str(safetensors_file), device="cpu")
    raise FileNotFoundError(f"model.safetensors not found in {ckpt_dir}")
|
|
|
|
def average_checkpoints(
    checkpoint_dirs: list[Path],
    *,
    dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Load all checkpoints and return their element-wise average.

    Tensors are summed in float32 and cast to ``dtype`` only after dividing by
    the number of checkpoints, so the running sum is never rounded to the
    (possibly lower-precision) output dtype.

    Args:
        checkpoint_dirs: Checkpoint directories, each read with
            ``load_tensors``. The first one is the reference every other
            checkpoint is validated against.
        dtype: Dtype of the returned tensors.

    Returns:
        Mapping from tensor name to the averaged tensor, cast to ``dtype``.

    Raises:
        ValueError: If ``checkpoint_dirs`` is empty, or if a checkpoint's
            tensor names or shapes differ from those of the first checkpoint.
        FileNotFoundError: Propagated from ``load_tensors`` when a directory
            has no ``model.safetensors``.
    """
    # Fail with the same exception type as the other validation errors rather
    # than an IndexError from checkpoint_dirs[0].
    if not checkpoint_dirs:
        raise ValueError("no checkpoint directories were given")

    n = len(checkpoint_dirs)
    print(f"Loading {n} checkpoints...")

    reference = load_tensors(checkpoint_dirs[0])
    ref_keys = set(reference)
    print(f" [{1}/{n}] {checkpoint_dirs[0]} ({len(ref_keys)} tensors)")

    # NOTE(review): every tensor is averaged and cast to `dtype`, including any
    # non-floating-point tensors a checkpoint might contain -- confirm that the
    # checkpoints hold only floating-point weights.
    accum: dict[str, torch.Tensor] = {
        k: v.to(torch.float32) for k, v in reference.items()
    }
    # Only the shapes of the reference are needed for validation from here on;
    # dropping the tensors themselves lets their memory be reclaimed.
    ref_shapes = {k: v.shape for k, v in reference.items()}
    del reference

    for i, ckpt_dir in enumerate(checkpoint_dirs[1:], start=2):
        tensors = load_tensors(ckpt_dir)
        keys = set(tensors)

        if keys != ref_keys:
            missing = ref_keys - keys
            extra = keys - ref_keys
            msg_parts = [f"Checkpoint {ckpt_dir} has mismatched tensor keys."]
            if missing:
                msg_parts.append(f" Missing: {sorted(missing)}")
            if extra:
                msg_parts.append(f" Extra: {sorted(extra)}")
            raise ValueError("\n".join(msg_parts))

        # Validate every shape before touching the accumulators; sorted() makes
        # the first reported mismatch the same on every run.
        for k in sorted(ref_keys):
            if tensors[k].shape != ref_shapes[k]:
                raise ValueError(
                    f"Shape mismatch for tensor {k!r}: "
                    f"{ref_shapes[k]} vs {tensors[k].shape} in {ckpt_dir}"
                )

        for k, running_sum in accum.items():
            running_sum.add_(tensors[k].to(torch.float32))

        print(f" [{i}/{n}] {ckpt_dir}")

    print(f"Averaging ({n} checkpoints) -> {dtype}...")
    return {k: (v / n).to(dtype) for k, v in accum.items()}
|
|
|
|
| |
| |
| |
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line of the checkpoint-averaging script.

    Args:
        argv: Argument strings to parse, handed to ``ArgumentParser.parse_args``
            as-is (``None`` makes argparse read ``sys.argv[1:]``).

    Returns:
        Namespace with ``checkpoints`` (list of str), ``output`` (str),
        ``dtype`` (one of the ``--dtype`` choices) and ``overwrite`` (bool).
    """
    # Options are registered in this order, which is also the --help order.
    option_table = (
        (
            "--checkpoints",
            dict(
                nargs="+",
                required=True,
                metavar="DIR",
                help="Two or more checkpoint directories to average.",
            ),
        ),
        (
            "--output",
            dict(
                required=True,
                metavar="DIR",
                help="Output directory for the averaged checkpoint.",
            ),
        ),
        (
            "--dtype",
            dict(
                choices=["bfloat16", "float32", "float16"],
                default="bfloat16",
                help="Dtype for saved weights.",
            ),
        ),
        (
            "--overwrite",
            dict(
                action="store_true",
                help="Overwrite output directory if it already exists.",
            ),
        ),
    )

    cli = argparse.ArgumentParser(
        description="Average weights from multiple OPF checkpoint directories.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args(argv)
|
|
|
|
# Maps each accepted --dtype name to the torch dtype attribute of the same name.
DTYPE_MAP: dict[str, torch.dtype] = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("bfloat16", "float32", "float16")
}
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Run the checkpoint-averaging CLI and return a process exit code.

    All inputs are validated and the average is computed *before* the output
    location is modified, so a failed run neither deletes an existing output
    directory nor leaves an empty one behind.

    Args:
        argv: Command-line arguments, forwarded to ``parse_args``.

    Returns:
        0 on success; 1 when validation or averaging fails (the reason is
        printed to stderr).
    """
    args = parse_args(argv)

    checkpoint_dirs = [Path(d).expanduser().resolve() for d in args.checkpoints]
    output_dir = Path(args.output).expanduser().resolve()
    dtype = DTYPE_MAP[args.dtype]

    if len(checkpoint_dirs) < 2:
        print("error: at least 2 checkpoints are required for averaging", file=sys.stderr)
        return 1

    for d in checkpoint_dirs:
        if not d.is_dir():
            print(f"error: checkpoint directory not found: {d}", file=sys.stderr)
            return 1

    if output_dir.exists():
        if not args.overwrite:
            print(
                f"error: output directory already exists: {output_dir}\n"
                " Pass --overwrite to replace it.",
                file=sys.stderr,
            )
            return 1
        # shutil.rmtree cannot remove a plain file; report it instead of
        # crashing with a traceback later.
        if not output_dir.is_dir():
            print(
                f"error: output path exists and is not a directory: {output_dir}",
                file=sys.stderr,
            )
            return 1
        # --overwrite removes output_dir recursively, so it must not be an
        # input checkpoint or a directory that contains one. Paths were
        # resolved above, so this compares real locations.
        for d in checkpoint_dirs:
            if output_dir == d or output_dir in d.parents:
                print(
                    f"error: output directory {output_dir} would overwrite "
                    f"input checkpoint {d}",
                    file=sys.stderr,
                )
                return 1

    try:
        averaged = average_checkpoints(checkpoint_dirs, dtype=dtype)
    except (FileNotFoundError, ValueError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1

    # Only now that averaging succeeded is the previous output replaced.
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    weights_path = output_dir / "model.safetensors"
    print(f"Saving averaged weights -> {weights_path}")
    save_file(averaged, str(weights_path))

    # config.json is taken from the first checkpoint only; its absence is not
    # treated as an error.
    src_config = checkpoint_dirs[0] / "config.json"
    if not src_config.exists():
        print(
            f"warning: config.json not found in {checkpoint_dirs[0]}; skipping",
            file=sys.stderr,
        )
    else:
        dst_config = output_dir / "config.json"
        shutil.copy2(src_config, dst_config)
        print(f"Copied config.json from {checkpoint_dirs[0]} -> {dst_config}")

    # Record which checkpoints were averaged and the dtype they were saved in.
    provenance = {
        "type": "checkpoint_average",
        "source_checkpoints": [str(d) for d in checkpoint_dirs],
        "num_checkpoints": len(checkpoint_dirs),
        "output_dtype": args.dtype,
    }
    (output_dir / "avg_provenance.json").write_text(
        json.dumps(provenance, indent=2) + "\n", encoding="utf-8"
    )

    print(f"\nDone. Averaged checkpoint saved to: {output_dir}")
    return 0
|
|
|
|
# Script entry point: the integer returned by main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|