File size: 9,192 Bytes
4fdc640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#!/usr/bin/env python3
"""Archive older checkpoint .pt files while keeping key recovery points.

Default behavior is a dry run. Use --execute to actually move files.
"""

from __future__ import annotations

import argparse
import errno
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path


# Archive destination used when --archive-root is not given.
DEFAULT_ARCHIVE_ROOT = Path(
    "/apdcephfs_cq10", "share_1603164", "user", "schmittzhu", "code", "ckpts"
)

# Checkpoints (relative to the working directory) that are always kept,
# whatever the retention policy decides.
EXTRA_KEEP_RELATIVE = {
    # run_ov1_v10_phase1_cls.sh defaults to this v9 ep3 checkpoint.
    "/".join(
        (
            "checkpoints",
            "spatial_beats_ov1_local_spatial_v9_ov123_exp",
            "03_ov123_top4",
            "epoch_0003.pt",
        )
    ),
}

# Matches "epoch12.pt", "epoch_12.pt" and "epoch-12.pt"; group 1 is the number.
EPOCH_RE = re.compile(r"^epoch[_-]?(\d+)\.pt$")
# Matches any "<prefix><digits>.pt"; group 1 is the (non-greedy) prefix and
# group 2 the full run of trailing digits.
TRAILING_NUM_RE = re.compile(r"^(.*?)(\d+)\.pt$")


@dataclass(frozen=True)
class PtFile:
    """A single checkpoint file discovered by the filesystem scan.

    Frozen, so instances are immutable and hashable.
    """

    # Path of the file as reported by the scan.
    path: Path
    # File size in bytes.
    size: int
    # Last-modification time in seconds since the epoch.
    mtime: float


def format_size(num_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Values below 1024 are printed as an exact integer with a ``B`` suffix.
    Larger values are divided by 1024 until they drop below 1024 (or the
    ``PiB`` unit is reached) and printed with one decimal place.
    """
    if float(num_bytes) < 1024:
        return f"{num_bytes} B"

    scaled = float(num_bytes)
    for suffix in ("KiB", "MiB", "GiB", "TiB"):
        scaled /= 1024
        if scaled < 1024:
            return f"{scaled:.1f} {suffix}"

    # Largest unit: no further promotion, however big the value is.
    return f"{scaled / 1024:.1f} PiB"


def _non_negative_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list lets tests and
            other scripts drive the parser.

    Returns:
        Namespace with ``checkpoints_root``, ``archive_root``, ``execute``,
        ``max_depth``, ``list_files`` and ``policy``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Move old checkpoints to an archive directory, preserving original "
            "relative paths. By default keeps best.pt, last.pt, the "
            "max-numbered epoch checkpoint per directory, and explicit "
            "keep-list exceptions."
        )
    )
    parser.add_argument(
        "--checkpoints-root",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint root to scan. Default: checkpoints",
    )
    parser.add_argument(
        "--archive-root",
        type=Path,
        default=DEFAULT_ARCHIVE_ROOT,
        help=f"Archive root. Default: {DEFAULT_ARCHIVE_ROOT}",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Actually move files. Without this flag, only prints a dry run.",
    )
    parser.add_argument(
        "--max-depth",
        # Rejected here rather than surfacing later as an opaque find error.
        type=_non_negative_int,
        default=3,
        help=(
            "Maximum depth under the checkpoint root to scan (passed to "
            "find -maxdepth). Default: 3, matching the current checkpoint "
            "layout."
        ),
    )
    parser.add_argument(
        "--list-files",
        action="store_true",
        help="Print every move candidate, not just summary and top directories.",
    )
    parser.add_argument(
        "--policy",
        choices=("best-last-max", "minimal"),
        default="best-last-max",
        help=(
            "Retention policy. best-last-max keeps best.pt, last.pt, and the "
            "max-numbered epoch checkpoint per directory. minimal keeps best.pt "
            "when present, otherwise last.pt, otherwise the max-numbered "
            "checkpoint, plus explicit keep-list exceptions."
        ),
    )
    return parser.parse_args(argv)


def relative_to_cwd(path: Path) -> Path:
    """Express *path* relative to the current working directory.

    Relative paths without ``..`` components are returned unchanged. Absolute
    paths, and relative paths containing ``..``, are anchored at the cwd and
    made relative to it. The lexical form is tried first; if that fails, the
    resolved path is tried so that symlinked parents still match. ``..``
    components are always resolved, because ``archive_root / "../x"`` would
    point outside the archive root.

    Raises:
        ValueError: if *path* does not live under the current working
            directory.
    """
    if not path.is_absolute() and ".." not in path.parts:
        return path

    cwd = Path.cwd().resolve()
    anchored = path if path.is_absolute() else cwd / path

    if ".." not in anchored.parts:
        try:
            return anchored.relative_to(cwd)
        except ValueError:
            pass  # may still match once symlinks in the path are resolved

    try:
        return anchored.resolve().relative_to(cwd)
    except ValueError:
        raise ValueError(
            f"path is not inside the current working directory {cwd}: {path}"
        ) from None


def collect_pt_files(root: Path, max_depth: int) -> dict[Path, list[PtFile]]:
    """Find ``*.pt`` files under *root* and group them by parent directory.

    Runs GNU ``find`` (``-printf`` is a GNU extension) to list regular files
    named ``*.pt`` at most *max_depth* levels below *root*. The size is taken
    from ``%s`` (bytes) and the mtime from ``%T@`` (seconds since the epoch).

    Args:
        root: Directory to scan; may be a symlink to a directory.
        max_depth: Value passed to ``find -maxdepth``.

    Returns:
        Mapping from parent directory to the ``PtFile`` entries found in it.
        Paths are spelled as ``find`` prints them, i.e. prefixed by *root*.

    Raises:
        FileNotFoundError: if *root* is not a directory.
        RuntimeError: if ``find`` exits non-zero (e.g. a permission error).
    """
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint root does not exist: {root}")

    result = subprocess.run(
        [
            "find",
            # -H follows *root* itself when it is a symlink (is_dir() above
            # accepts that); under the default -P, find would not descend
            # into it and would silently report nothing.
            "-H",
            str(root),
            "-maxdepth",
            str(max_depth),
            "-type",
            "f",
            "-name",
            "*.pt",
            "-printf",
            # find expands \t and \0 itself. Putting the path last and ending
            # each record with NUL (never legal in file names) keeps names that
            # contain tabs or newlines from corrupting the parse.
            r"%s\t%T@\t%p\0",
        ],
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # Don't crash on file names that are not valid in the locale encoding.
        errors="surrogateescape",
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "find failed")

    by_dir: dict[Path, list[PtFile]] = {}
    for record in result.stdout.split("\0"):
        if not record:
            continue  # the trailing terminator yields one empty field
        size_text, mtime_text, path_text = record.split("\t", 2)
        path = Path(path_text)
        item = PtFile(path=path, size=int(size_text), mtime=float(mtime_text))
        by_dir.setdefault(path.parent, []).append(item)

    return by_dir


def max_numbered_checkpoint(files: list[PtFile]) -> Path | None:
    """Return the path of the highest-numbered checkpoint in *files*.

    Files matching ``EPOCH_RE`` (``epochN.pt`` / ``epoch_N.pt`` /
    ``epoch-N.pt``) take priority: if any exist, the largest N wins.
    Otherwise every other file whose name ends in digits before ``.pt`` (per
    ``TRAILING_NUM_RE``) competes on that trailing number. ``best.pt`` and
    ``last.pt`` are never returned.

    Ties on the number (e.g. ``epoch_3.pt`` and ``epoch_003.pt``) are broken
    by file name. Previously the winner depended on ``find``'s listing order.

    Returns:
        The selected path, or ``None`` if no file carries a number.
    """
    epoch_candidates: list[tuple[int, str, Path]] = []
    other_candidates: list[tuple[int, str, Path]] = []

    for item in files:
        name = item.path.name
        if name in ("best.pt", "last.pt"):
            continue

        match = EPOCH_RE.match(name)
        if match:
            epoch_candidates.append((int(match.group(1)), name, item.path))
            continue

        match = TRAILING_NUM_RE.match(name)
        if match:
            other_candidates.append((int(match.group(2)), name, item.path))

    pool = epoch_candidates or other_candidates
    if not pool:
        return None
    # Key on (number, name) only so Path objects are never compared.
    return max(pool, key=lambda entry: (entry[0], entry[1]))[2]


def build_keep_set(by_dir: dict[Path, list[PtFile]], policy: str) -> set[Path]:
    """Return the checkpoint paths that must stay in place.

    Args:
        by_dir: Checkpoint files grouped by parent directory.
        policy: ``"best-last-max"`` keeps ``best.pt``, ``last.pt`` and the
            max-numbered checkpoint of every directory. ``"minimal"`` keeps a
            single file per directory: ``best.pt`` if present, else
            ``last.pt``, else the max-numbered checkpoint.

    Files listed in ``EXTRA_KEEP_RELATIVE`` (compared relative to the current
    working directory) are kept under either policy.

    Raises:
        ValueError: if *policy* is unknown. This is checked up front, so it is
            reported even when *by_dir* is empty.
    """
    if policy not in ("best-last-max", "minimal"):
        raise ValueError(f"unknown policy: {policy}")

    extra_keep = {Path(item) for item in EXTRA_KEEP_RELATIVE}
    keep: set[Path] = set()

    for files in by_dir.values():
        by_name = {item.path.name: item.path for item in files}

        if policy == "best-last-max":
            keep.update(
                by_name[name] for name in ("best.pt", "last.pt") if name in by_name
            )
            highest = max_numbered_checkpoint(files)
            if highest is not None:
                keep.add(highest)
        else:  # "minimal": exactly one recovery point per directory
            chosen = (
                by_name.get("best.pt")
                or by_name.get("last.pt")
                or max_numbered_checkpoint(files)
            )
            if chosen is not None:
                keep.add(chosen)

        keep.update(
            item.path for item in files if relative_to_cwd(item.path) in extra_keep
        )

    return keep


def destination_for(source: Path, archive_root: Path) -> Path:
    """Map *source* to its location under *archive_root*.

    The path of *source* relative to the current working directory is
    preserved beneath *archive_root*.

    Raises:
        ValueError: if that relative path is absolute or contains ``..``.
            Joining it onto *archive_root* would then point outside the
            archive, and files would be moved somewhere unexpected.
    """
    relative = relative_to_cwd(source)
    if relative.is_absolute() or ".." in relative.parts:
        raise ValueError(
            f"refusing to archive outside {archive_root}: {source} -> {relative}"
        )
    return archive_root / relative


def move_one(source: Path, destination: Path) -> None:
    """Rename *source* to *destination*, creating missing parent directories.

    Never overwrites: an existing *destination* raises FileExistsError before
    anything is created. Only a same-filesystem rename is attempted. A
    cross-device move (EXDEV) is reported as RuntimeError instead of silently
    degrading into copy-then-delete; any other OSError propagates unchanged.
    """
    target_dir = destination.parent
    if destination.exists():
        raise FileExistsError(f"archive target already exists: {destination}")
    target_dir.mkdir(exist_ok=True, parents=True)

    try:
        source.rename(destination)
    except OSError as err:
        if err.errno != errno.EXDEV:
            raise
        raise RuntimeError(
            "source and archive are on different filesystems; refusing to "
            "copy-then-delete automatically"
        ) from err


def main() -> int:
    """Scan checkpoints, report the archive plan, and optionally execute it.

    Without ``--execute`` nothing is moved. With it, every archive target is
    checked before the first move, so a pre-existing target aborts the run
    instead of leaving the checkpoints half-archived.

    Returns:
        Exit status: 0 on success (including dry runs); 1 if the scan fails,
        an archive target already exists, or a move fails.
    """
    args = parse_args()
    archive_root = args.archive_root.resolve()
    try:
        by_dir = collect_pt_files(args.checkpoints_root, args.max_depth)
    except (FileNotFoundError, RuntimeError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1
    keep = build_keep_set(by_dir, args.policy)

    all_files = [item for files in by_dir.values() for item in files]
    candidates = sorted(
        (item for item in all_files if item.path not in keep),
        key=lambda entry: str(entry.path),
    )
    # (file, archive destination) pairs in the order they will be moved.
    plan = [(item, destination_for(item.path, archive_root)) for item in candidates]

    total_size = sum(item.size for item in all_files)
    keep_size = sum(item.size for item in all_files if item.path in keep)
    move_size = sum(item.size for item in candidates)

    print(f"Mode: {'EXECUTE' if args.execute else 'DRY-RUN'}")
    print(f"Policy: {args.policy}")
    print(f"Checkpoint dirs: {len(by_dir)}")
    print(f"Total .pt files: {len(all_files)} ({format_size(total_size)})")
    print(f"Keep files: {len(all_files) - len(candidates)} ({format_size(keep_size)})")
    print(f"Move candidates: {len(candidates)} ({format_size(move_size)})")
    print(f"Archive root: {archive_root}")
    print()

    print("Extra keep-list:")
    for relpath in sorted(EXTRA_KEEP_RELATIVE):
        path = Path.cwd() / relpath
        print(f"  KEEP {relpath} ({'exists' if path.exists() else 'missing'})")
    print()

    per_dir: list[tuple[int, int, Path]] = []
    for directory, files in by_dir.items():
        moving = [item for item in files if item.path not in keep]
        if moving:
            per_dir.append((sum(item.size for item in moving), len(moving), directory))

    print("Top directories by archived size:")
    for size, count, directory in sorted(per_dir, reverse=True)[:30]:
        print(f"  {format_size(size):>10}  files={count:<3}  {relative_to_cwd(directory)}")
    print()

    if args.list_files:
        print("Move candidates:")
        for item, destination in plan:
            print(f"  {relative_to_cwd(item.path)} -> {destination}")
        print()

    # move_one refuses to overwrite. Detecting collisions before the first
    # move means a clash cannot stop the run after some files already moved.
    clashes = [destination for _, destination in plan if destination.exists()]

    if not args.execute:
        if clashes:
            print(
                f"Warning: {len(clashes)} archive target(s) already exist; "
                "--execute would refuse to move anything."
            )
        print("Dry run only. Re-run with --execute to move files.")
        return 0

    if clashes:
        print(
            f"error: {len(clashes)} archive target(s) already exist; "
            "nothing was moved:",
            file=sys.stderr,
        )
        for destination in clashes[:20]:
            print(f"  {destination}", file=sys.stderr)
        return 1

    moved = 0
    moved_bytes = 0
    for item, destination in plan:
        try:
            move_one(item.path, destination)
        except (OSError, RuntimeError) as exc:
            # Report how far we got so the partial state is easy to inspect.
            print(f"error: could not move {item.path}: {exc}", file=sys.stderr)
            print(
                f"Moved {moved}/{len(plan)} files ({format_size(moved_bytes)}) "
                "before stopping.",
                file=sys.stderr,
            )
            return 1
        moved += 1
        moved_bytes += item.size
        if moved % 25 == 0:
            print(f"Moved {moved}/{len(candidates)} files ({format_size(moved_bytes)})")

    print(f"Done. Moved {moved} files ({format_size(moved_bytes)}).")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())