""" checkpoint_avg.py — Average weights from multiple OPF checkpoint directories. Usage: python scripts/checkpoint_avg.py \ --checkpoints ckpt/run1 ckpt/run2 ckpt/run3 \ --output ckpt/averaged Each checkpoint directory must contain: - model.safetensors (the weights) - config.json (copied verbatim from the first checkpoint) All checkpoints must share the same tensor names and shapes (i.e. same label space and architecture). The averaged weights are saved as bfloat16 by default; pass --dtype float32 to keep full precision. """ from __future__ import annotations import argparse import json import shutil import sys from pathlib import Path import torch from safetensors.torch import load_file, save_file # --------------------------------------------------------------------------- # Core averaging logic # --------------------------------------------------------------------------- def load_tensors(ckpt_dir: Path) -> dict[str, torch.Tensor]: weights_path = ckpt_dir / "model.safetensors" if not weights_path.exists(): raise FileNotFoundError(f"model.safetensors not found in {ckpt_dir}") return load_file(str(weights_path), device="cpu") def average_checkpoints( checkpoint_dirs: list[Path], *, dtype: torch.dtype = torch.bfloat16, ) -> dict[str, torch.Tensor]: """Load all checkpoints and return their element-wise average.""" n = len(checkpoint_dirs) print(f"Loading {n} checkpoints...") # Load first checkpoint to establish reference keys and shapes reference = load_tensors(checkpoint_dirs[0]) ref_keys = set(reference.keys()) print(f" [{1}/{n}] {checkpoint_dirs[0]} ({len(ref_keys)} tensors)") # Accumulate in float32 for numerical stability regardless of output dtype accum: dict[str, torch.Tensor] = { k: v.to(torch.float32) for k, v in reference.items() } for i, ckpt_dir in enumerate(checkpoint_dirs[1:], start=2): tensors = load_tensors(ckpt_dir) keys = set(tensors.keys()) # Validate key sets match if keys != ref_keys: missing = ref_keys - keys extra = keys - ref_keys msg_parts = [f"Checkpoint {ckpt_dir} has mismatched tensor keys."] if missing: msg_parts.append(f" Missing: {sorted(missing)}") if extra: msg_parts.append(f" Extra: {sorted(extra)}") raise ValueError("\n".join(msg_parts)) # Validate shapes match for k in ref_keys: if tensors[k].shape != reference[k].shape: raise ValueError( f"Shape mismatch for tensor {k!r}: " f"{reference[k].shape} vs {tensors[k].shape} in {ckpt_dir}" ) # Accumulate for k in ref_keys: accum[k] += tensors[k].to(torch.float32) print(f" [{i}/{n}] {ckpt_dir}") # Divide and cast to output dtype print(f"Averaging ({n} checkpoints) -> {dtype}...") averaged: dict[str, torch.Tensor] = { k: (v / n).to(dtype) for k, v in accum.items() } return averaged # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( description="Average weights from multiple OPF checkpoint directories.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--checkpoints", nargs="+", required=True, metavar="DIR", help="Two or more checkpoint directories to average.", ) parser.add_argument( "--output", required=True, metavar="DIR", help="Output directory for the averaged checkpoint.", ) parser.add_argument( "--dtype", choices=["bfloat16", "float32", "float16"], default="bfloat16", help="Dtype for saved weights.", ) parser.add_argument( "--overwrite", action="store_true", help="Overwrite output directory if it already exists.", ) return parser.parse_args(argv) DTYPE_MAP = { "bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16, } def main(argv: list[str] | None = None) -> int: args = parse_args(argv) checkpoint_dirs = [Path(d).expanduser().resolve() for d in args.checkpoints] output_dir = Path(args.output).expanduser().resolve() dtype = DTYPE_MAP[args.dtype] # Validate inputs if len(checkpoint_dirs) < 2: print("error: at least 2 checkpoints are required for averaging", file=sys.stderr) return 1 for d in checkpoint_dirs: if not d.is_dir(): print(f"error: checkpoint directory not found: {d}", file=sys.stderr) return 1 # Check output if output_dir.exists(): if not args.overwrite: print( f"error: output directory already exists: {output_dir}\n" " Pass --overwrite to replace it.", file=sys.stderr, ) return 1 shutil.rmtree(output_dir) output_dir.mkdir(parents=True) # Average weights try: averaged = average_checkpoints(checkpoint_dirs, dtype=dtype) except (FileNotFoundError, ValueError) as exc: print(f"error: {exc}", file=sys.stderr) return 1 # Save averaged weights weights_path = output_dir / "model.safetensors" print(f"Saving averaged weights -> {weights_path}") save_file(averaged, str(weights_path)) # Copy config.json from the first checkpoint src_config = checkpoint_dirs[0] / "config.json" if not src_config.exists(): print( f"warning: config.json not found in {checkpoint_dirs[0]}; skipping", file=sys.stderr, ) else: dst_config = output_dir / "config.json" shutil.copy2(src_config, dst_config) print(f"Copied config.json from {checkpoint_dirs[0]} -> {dst_config}") # Write a small provenance file so it's clear how this checkpoint was made provenance = { "type": "checkpoint_average", "source_checkpoints": [str(d) for d in checkpoint_dirs], "num_checkpoints": len(checkpoint_dirs), "output_dtype": args.dtype, } (output_dir / "avg_provenance.json").write_text( json.dumps(provenance, indent=2) + "\n", encoding="utf-8" ) print(f"\nDone. Averaged checkpoint saved to: {output_dir}") return 0 if __name__ == "__main__": sys.exit(main())