File size: 1,497 Bytes
951f760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
from __future__ import annotations

import json


def build_checkpoint_report(
    files: list[str],
    *,
    priority: tuple[str, ...] = ("pretrain_final.pt", "best_bpb.pt", "latest.pt"),
) -> dict[str, object]:
    """Group checkpoint paths by job and pick a preferred checkpoint per job.

    A path is considered only when it has the shape
    ``jobs/<job_id>/.../<filename>`` (at least three ``/``-separated
    components, the first being ``jobs``) and ``<filename>`` appears in
    *priority*. All other paths are ignored.

    Args:
        files: ``/``-separated checkpoint paths to classify.
        priority: Recognised checkpoint filenames, most preferred first.
            This single tuple both filters which files count and ranks
            them, so the two cannot drift apart. The default prefers
            ``pretrain_final.pt`` over ``best_bpb.pt`` over ``latest.pt``.

    Returns:
        ``{"n_candidates": int, "candidates": list[dict]}`` where
        ``candidates`` has one entry per job, ordered by ``job_id``
        descending. Each entry holds ``job_id``, ``preferred_path`` (the
        path whose filename ranks highest; ties among paths sharing a
        filename go to the lexicographically smallest path) and
        ``available_paths`` (sorted, duplicates removed).
    """
    # First occurrence wins if a name is repeated in ``priority``.
    rank: dict[str, int] = {}
    for index, name in enumerate(priority):
        rank.setdefault(name, index)

    # Sets collapse paths that are listed more than once in ``files``.
    paths_by_job: dict[str, set[str]] = {}
    for path in files:
        parts = path.split("/")
        if len(parts) < 3 or parts[0] != "jobs" or parts[-1] not in rank:
            continue
        paths_by_job.setdefault(parts[1], set()).add(path)

    candidates: list[dict[str, object]] = []
    # Job IDs are unique keys, so iterating them in descending order gives
    # the same ordering as sorting the finished rows by job_id descending.
    for job_id in sorted(paths_by_job, reverse=True):
        available = sorted(paths_by_job[job_id])
        # ``min`` returns the first minimal element and ``available`` is
        # sorted, so the choice does not depend on the input order of
        # ``files``.
        preferred = min(available, key=lambda p: rank[p.rsplit("/", 1)[-1]])
        candidates.append({
            "job_id": job_id,
            "preferred_path": preferred,
            "available_paths": available,
        })

    return {
        "n_candidates": len(candidates),
        "candidates": candidates,
    }


def main() -> int:
    """Print the checkpoint report for an empty file list as indented JSON.

    Keys are emitted in sorted order with two-space indentation.

    Returns:
        Process exit status; always 0.
    """
    report = build_checkpoint_report([])
    rendered = json.dumps(report, sort_keys=True, indent=2)
    print(rendered)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)