File size: 3,809 Bytes
29e5bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Run Tasks 1,2,3,5 for every available checkpoint (excluding Task 4).

Usage:
  python analysis/run_tasks_except4_all_models.py
  python analysis/run_tasks_except4_all_models.py --input "dharmo rakṣati rakṣitaḥ"
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path


# Directory one level above the folder containing this script. Per the usage
# lines in the module docstring the script lives in analysis/, which makes
# this the project root.
ROOT = Path(__file__).resolve().parents[1]
# Default value of --out_root: where main() creates the per-checkpoint output
# folders and writes summary.json.
DEFAULT_OUT_ROOT = ROOT / "analysis" / "outputs_multi"


def discover_checkpoints(roots: list[Path] | None = None) -> list[Path]:
    """Collect every ``<root>/<experiment>/best_model.pt`` checkpoint.

    Args:
        roots: Directories to scan, each searched exactly one level deep.
            When omitted, the ``results``, ``results7`` and
            ``ablation_results`` folders under ``ROOT`` are used.

    Returns:
        Checkpoint paths grouped by root in the order the roots were given
        and sorted within each root. Roots that are missing or are not
        directories are skipped, so the result may be empty.
    """
    if roots is None:
        roots = [ROOT / "results", ROOT / "results7", ROOT / "ablation_results"]

    found: list[Path] = []
    for base in roots:
        base = Path(base)
        # is_dir() rather than exists(): a stray file with the root's name
        # cannot contain checkpoints and should not be globbed.
        if not base.is_dir():
            continue
        found.extend(sorted(base.glob("*/best_model.pt")))
    return found


def slug_for_checkpoint(ckpt: Path) -> str:
    """Build a ``<results-folder>__<experiment>`` label for a checkpoint path.

    The label joins the names of the checkpoint's grandparent and parent
    directories with a double underscore.
    """
    experiment_dir = ckpt.parent
    results_dir = experiment_dir.parent
    return "__".join((results_dir.name, experiment_dir.name))


def run_task(task: str, ckpt: Path, input_text: str, out_dir: Path) -> tuple[int, float]:
    """Run one analysis task for one checkpoint in a child Python process.

    Launches ``analysis/run_analysis.py`` (resolved against ``ROOT``) with the
    current interpreter and ``ROOT`` as the working directory. A failing child
    does not raise; its exit code is returned to the caller.

    Args:
        task: Task identifier forwarded as ``--task``.
        ckpt: Checkpoint file forwarded as ``--checkpoint``.
        input_text: Text forwarded as ``--input``; only sent when ``task`` is
            ``"2"``.
        out_dir: Directory forwarded as ``--output_dir``.

    Returns:
        ``(exit_code, seconds)``: the child's return code and how long the
        child process took, in seconds.
    """
    cmd = [
        sys.executable,
        str(ROOT / "analysis" / "run_analysis.py"),
        "--task", task,
        "--checkpoint", str(ckpt),
        "--output_dir", str(out_dir),
    ]
    if task == "2":
        cmd.extend(["--input", input_text])

    # Give the child default cache locations under /tmp, without overriding
    # anything the caller's environment already sets, and make sure each
    # directory exists before the child starts.
    # NOTE(review): presumably /tmp is used because the default cache
    # location is not writable where this runs -- confirm.
    env = os.environ.copy()
    cache_defaults = {
        "HF_HOME": "/tmp/hf_home",
        "HF_DATASETS_CACHE": "/tmp/hf_datasets",
        "HF_HUB_CACHE": "/tmp/hf_hub",
        "TRANSFORMERS_CACHE": "/tmp/hf_transformers",
    }
    for var, default in cache_defaults.items():
        os.makedirs(env.setdefault(var, default), exist_ok=True)

    # perf_counter() is monotonic; subtracting naive datetime.now() values
    # gives wrong durations when the wall clock is adjusted or DST changes
    # while the child is running.
    start = time.perf_counter()
    proc = subprocess.run(cmd, cwd=str(ROOT), env=env, check=False)
    seconds = time.perf_counter() - start
    return proc.returncode, seconds


def main() -> None:
    """Command-line entry point: run tasks 1, 2, 3 and 5 on every checkpoint.

    Options:
        --input: text handed to ``run_task`` for every task.
        --out_root: folder that receives one sub-folder per checkpoint (with
            one ``task<N>`` folder inside each) plus ``summary.json``.

    A failing task is reported as ``FAILED`` and recorded with its exit code;
    the remaining tasks and checkpoints still run.

    Raises:
        FileNotFoundError: if ``discover_checkpoints()`` returns nothing.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", default="dharmo rakṣati rakṣitaḥ")
    cli.add_argument("--out_root", default=str(DEFAULT_OUT_ROOT))
    opts = cli.parse_args()

    all_ckpts = discover_checkpoints()
    if not all_ckpts:
        raise FileNotFoundError("No checkpoints found under results/results7/ablation_results.")

    base_dir = Path(opts.out_root)
    base_dir.mkdir(parents=True, exist_ok=True)

    task_ids = ["1", "2", "3", "5"]
    # Captured before any task runs, so the summary records the start time.
    started_at = datetime.now().isoformat(timespec="seconds")
    per_checkpoint = []

    for ckpt_path in all_ckpts:
        ckpt_dir = base_dir / slug_for_checkpoint(ckpt_path)
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n=== Checkpoint: {ckpt_path} ===")

        task_records = []
        for task_id in task_ids:
            task_dir = ckpt_dir / f"task{task_id}"
            task_dir.mkdir(parents=True, exist_ok=True)
            print(f"-> Running task {task_id} ...")
            exit_code, elapsed = run_task(task_id, ckpt_path, opts.input, task_dir)
            task_records.append(
                {
                    "task": task_id,
                    "exit_code": exit_code,
                    "seconds": round(elapsed, 2),
                    "output_dir": str(task_dir),
                }
            )
            if exit_code == 0:
                outcome = "OK"
            else:
                outcome = "FAILED"
            print(f"   {outcome} ({elapsed:.1f}s)")

        per_checkpoint.append(
            {
                "checkpoint": str(ckpt_path),
                "output_dir": str(ckpt_dir),
                "tasks": task_records,
            }
        )

    report = {
        "timestamp": started_at,
        "tasks": task_ids,
        "checkpoints": per_checkpoint,
    }
    report_path = base_dir / "summary.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(f"\nSaved summary: {report_path}")


# Run the batch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()