#!/usr/bin/env python3
"""Archive older checkpoint .pt files while keeping key recovery points.

Default behavior is a dry run. Use --execute to actually move files.
"""

from __future__ import annotations

import argparse
import errno
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path


DEFAULT_ARCHIVE_ROOT = Path(
    "/apdcephfs_cq10/share_1603164/user/schmittzhu/code/ckpts"
)

# Paths (relative to the current working directory) that are always kept,
# regardless of the retention policy.
EXTRA_KEEP_RELATIVE = {
    # run_ov1_v10_phase1_cls.sh defaults to this v9 ep3 checkpoint.
    "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/"
    "03_ov123_top4/epoch_0003.pt",
}

# "epoch_0003.pt", "epoch-3.pt", "epoch3.pt" -> group(1) is the epoch number.
EPOCH_RE = re.compile(r"^epoch[_-]?(\d+)\.pt$")
# Any "<prefix><digits>.pt"; the lazy prefix makes group(2) the full trailing
# number.
TRAILING_NUM_RE = re.compile(r"^(.*?)(\d+)\.pt$")


@dataclass(frozen=True)
class PtFile:
    """One checkpoint file as reported by the directory scan."""

    path: Path  # path as printed by find (relative if the root was relative)
    size: int  # size in bytes
    mtime: float  # modification time, seconds since the epoch


def format_size(num_bytes: int) -> str:
    """Return *num_bytes* as a human-readable string using binary units.

    Byte counts below 1024 are printed exactly ("512 B"); larger values are
    printed with one decimal place ("1.5 KiB"). Values beyond the PiB range
    stay in PiB.
    """
    units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
    size = float(num_bytes)
    for unit in units:
        if size < 1024 or unit == units[-1]:
            if unit == "B":
                return f"{num_bytes} B"
            return f"{size:.1f} {unit}"
        size /= 1024
    raise AssertionError("unreachable")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the archiver."""
    parser = argparse.ArgumentParser(
        description=(
            "Move old checkpoints to an archive directory, preserving original "
            "relative paths. Keeps best.pt, last.pt, the max-numbered epoch "
            "checkpoint per directory, and explicit keep-list exceptions."
        )
    )
    parser.add_argument(
        "--checkpoints-root",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint root to scan. Default: checkpoints",
    )
    parser.add_argument(
        "--archive-root",
        type=Path,
        default=DEFAULT_ARCHIVE_ROOT,
        help=f"Archive root. Default: {DEFAULT_ARCHIVE_ROOT}",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Actually move files. Without this flag, only prints a dry run.",
    )
    parser.add_argument(
        "--max-depth",
        type=int,
        default=3,
        help=(
            # The value is handed to `find -maxdepth`, which counts levels
            # below the starting point, i.e. the checkpoint root.
            "Maximum depth under the checkpoint root to scan. Default: 3, "
            "matching the current checkpoint layout."
        ),
    )
    parser.add_argument(
        "--list-files",
        action="store_true",
        help="Print every move candidate, not just summary and top directories.",
    )
    parser.add_argument(
        "--policy",
        choices=("best-last-max", "minimal"),
        default="best-last-max",
        help=(
            "Retention policy. best-last-max keeps best.pt, last.pt, and the "
            "max-numbered epoch checkpoint per directory. minimal keeps best.pt "
            "when present, otherwise last.pt, otherwise the max-numbered "
            "checkpoint, plus explicit keep-list exceptions."
        ),
    )
    return parser.parse_args()


def relative_to_cwd(path: Path) -> Path:
    """Return *path* expressed relative to the current working directory.

    Relative paths are returned unchanged.

    Raises:
        ValueError: if *path* is absolute and not located under the current
            working directory.
    """
    if not path.is_absolute():
        return path
    cwd = Path.cwd().resolve()
    try:
        return path.relative_to(cwd)
    except ValueError:
        raise ValueError(
            f"path is outside the current working directory ({cwd}): {path}"
        ) from None


def collect_pt_files(root: Path, max_depth: int) -> dict[Path, list[PtFile]]:
    """Scan *root* for ``*.pt`` files and group them by parent directory.

    The scan is delegated to ``find`` (the ``-printf`` action used here is a
    GNU findutils extension), limited to *max_depth* levels below *root*.

    Raises:
        FileNotFoundError: if *root* is not an existing directory.
        RuntimeError: if ``find`` exits with a non-zero status.
    """
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint root does not exist: {root}")

    result = subprocess.run(
        [
            "find",
            str(root),
            "-maxdepth",
            str(max_depth),
            "-type",
            "f",
            "-name",
            "*.pt",
            "-printf",
            # NUL-terminated records: a file name may legally contain a
            # newline, but never a NUL byte. "\\0" reaches find as the
            # two-character escape, which -printf expands to NUL.
            "%p\t%s\t%T@\\0",
        ],
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "find failed")

    by_dir: dict[Path, list[PtFile]] = {}
    for record in result.stdout.split("\0"):
        if not record:
            # Trailing terminator (or no matches at all).
            continue
        # Split from the right so tabs inside the path stay in the path.
        path_text, size_text, mtime_text = record.rsplit("\t", 2)
        path = Path(path_text)
        item = PtFile(path=path, size=int(size_text), mtime=float(mtime_text))
        by_dir.setdefault(path.parent, []).append(item)
    return by_dir


def max_numbered_checkpoint(files: list[PtFile]) -> Path | None:
    """Return the highest-numbered checkpoint among *files*, if any.

    Files named like ``epoch_<N>.pt`` take precedence; only when none exist
    are other ``<prefix><N>.pt`` names considered. Returns ``None`` when no
    file name ends in a number.
    """
    epoch_candidates: list[tuple[int, Path]] = []
    other_candidates: list[tuple[int, Path]] = []
    for item in files:
        name = item.path.name
        match = EPOCH_RE.match(name)
        if match:
            epoch_candidates.append((int(match.group(1)), item.path))
            continue
        # best.pt / last.pt have no digits before ".pt", so they never match.
        match = TRAILING_NUM_RE.match(name)
        if match:
            other_candidates.append((int(match.group(2)), item.path))

    for candidates in (epoch_candidates, other_candidates):
        if candidates:
            return max(candidates, key=lambda pair: pair[0])[1]
    return None


def build_keep_set(by_dir: dict[Path, list[PtFile]], policy: str) -> set[Path]:
    """Return the paths that must NOT be archived under *policy*.

    ``best-last-max`` keeps best.pt, last.pt and the max-numbered checkpoint
    of every directory. ``minimal`` keeps a single file per directory:
    best.pt, else last.pt, else the max-numbered checkpoint. In both cases
    every file listed in ``EXTRA_KEEP_RELATIVE`` is kept as well.

    Raises:
        ValueError: if *policy* is unknown, or (via ``relative_to_cwd``) if a
            scanned path is absolute and outside the working directory.
    """
    keep: set[Path] = set()
    extra_keep = {Path(item) for item in EXTRA_KEEP_RELATIVE}

    for files in by_dir.values():
        by_name = {item.path.name: item.path for item in files}

        if policy == "best-last-max":
            for name in ("best.pt", "last.pt"):
                if name in by_name:
                    keep.add(by_name[name])
            max_epoch = max_numbered_checkpoint(files)
            if max_epoch is not None:
                keep.add(max_epoch)
        elif policy == "minimal":
            if "best.pt" in by_name:
                keep.add(by_name["best.pt"])
            elif "last.pt" in by_name:
                keep.add(by_name["last.pt"])
            else:
                max_epoch = max_numbered_checkpoint(files)
                if max_epoch is not None:
                    keep.add(max_epoch)
        else:
            raise ValueError(f"unknown policy: {policy}")

        for item in files:
            if relative_to_cwd(item.path) in extra_keep:
                keep.add(item.path)

    return keep


def destination_for(source: Path, archive_root: Path) -> Path:
    """Return where *source* lands under *archive_root*.

    The path of *source* relative to the working directory is preserved
    below *archive_root*.

    Raises:
        ValueError: if *source* is outside the working directory or contains
            ``..`` components, because the joined path would then point
            outside *archive_root*.
    """
    relative = relative_to_cwd(source)
    if ".." in relative.parts:
        raise ValueError(
            f"refusing to archive path that escapes the archive root: {source}"
        )
    return archive_root / relative


def move_one(source: Path, destination: Path) -> None:
    """Move *source* to *destination* with a same-filesystem rename.

    Raises:
        FileExistsError: if *destination* already exists (including as a
            dangling symlink); nothing is overwritten.
        RuntimeError: if source and destination are on different
            filesystems; no copy-then-delete fallback is attempted.
        OSError: any other failure reported by the rename.
    """
    # exists() follows symlinks and is False for a dangling one, which
    # rename() would silently replace, so check for the link itself too.
    if destination.exists() or destination.is_symlink():
        raise FileExistsError(f"archive target already exists: {destination}")
    destination.parent.mkdir(parents=True, exist_ok=True)
    try:
        source.rename(destination)
    except OSError as exc:
        if exc.errno == errno.EXDEV:
            raise RuntimeError(
                "source and archive are on different filesystems; refusing to "
                "copy-then-delete automatically"
            ) from exc
        raise


def main() -> int:
    """Run the archiver: print a summary and, with --execute, move files.

    Returns the process exit status (0 on success). Errors propagate as
    exceptions.
    """
    args = parse_args()
    archive_root = args.archive_root.resolve()
    by_dir = collect_pt_files(args.checkpoints_root, args.max_depth)
    keep = build_keep_set(by_dir, args.policy)

    all_files = [item for files in by_dir.values() for item in files]
    candidates = [item for item in all_files if item.path not in keep]
    total_size = sum(item.size for item in all_files)
    keep_size = sum(item.size for item in all_files if item.path in keep)
    move_size = sum(item.size for item in candidates)

    print(f"Mode: {'EXECUTE' if args.execute else 'DRY-RUN'}")
    print(f"Policy: {args.policy}")
    print(f"Checkpoint dirs: {len(by_dir)}")
    print(f"Total .pt files: {len(all_files)} ({format_size(total_size)})")
    print(f"Keep files: {len(all_files) - len(candidates)} ({format_size(keep_size)})")
    print(f"Move candidates: {len(candidates)} ({format_size(move_size)})")
    print(f"Archive root: {archive_root}")
    print()

    print("Extra keep-list:")
    for relpath in sorted(EXTRA_KEEP_RELATIVE):
        path = Path.cwd() / relpath
        print(f"  KEEP {relpath} ({'exists' if path.exists() else 'missing'})")
    print()

    per_dir: list[tuple[int, int, Path]] = []
    for directory, files in by_dir.items():
        move_count = sum(1 for item in files if item.path not in keep)
        move_bytes = sum(item.size for item in files if item.path not in keep)
        if move_count:
            per_dir.append((move_bytes, move_count, directory))

    print("Top directories by archived size:")
    for size, count, directory in sorted(per_dir, reverse=True)[:30]:
        print(f"  {format_size(size):>10}  files={count:<3}  {relative_to_cwd(directory)}")
    print()

    ordered = sorted(candidates, key=lambda entry: str(entry.path))

    if args.list_files:
        print("Move candidates:")
        for item in ordered:
            print(
                f"  {relative_to_cwd(item.path)} -> "
                f"{destination_for(item.path, archive_root)}"
            )
        print()

    if not args.execute:
        print("Dry run only. Re-run with --execute to move files.")
        return 0

    # Resolve every destination before touching the filesystem so that an
    # invalid path aborts the run before anything has been moved.
    planned = [(item, destination_for(item.path, archive_root)) for item in ordered]

    moved = 0
    moved_bytes = 0
    for item, destination in planned:
        move_one(item.path, destination)
        moved += 1
        moved_bytes += item.size
        if moved % 25 == 0:
            print(f"Moved {moved}/{len(candidates)} files ({format_size(moved_bytes)})")

    print(f"Done. Moved {moved} files ({format_size(moved_bytes)}).")
    return 0


if __name__ == "__main__":
    sys.exit(main())