File size: 3,548 Bytes
f74dd01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Batch runner for LRMC summary.

Iterates over every XXX.json file in cora_seeds/stage0 (or
cora_seeds/stage0_rand when --baseline random is given) and runs the
2.6_lrmc_summary.py script for each seed, writing results to
cora_seeds/summary/XXX.json (or cora_seeds/summary_rand/XXX.json) via the
script's --output_json flag (passed here as -o).

This version discards the child process's stdout/stderr and writes no .txt
logs; it only directs the underlying script to produce structured JSON
summaries.
"""

import os
import glob
import subprocess
from pathlib import Path
from rich import print
import argparse

# ---------------------------------------------------------------------------
# Configuration
#
# Settings forwarded to the summary script. Every value is a string (or a
# list of strings) because run_script() places them directly into the
# subprocess argument list.
# ---------------------------------------------------------------------------
SCRIPT_NAME = "2.6_lrmc_summary.py"  # script launched once per seed file
DATASET = "Cora"  # value passed with --dataset
HIDDEN = "32"  # value passed with --hidden
EPOCHS = "200"  # value passed with --epochs
LR = "0.05"  # value passed with --lr
RUNS = "20"  # value passed with --runs
EXTRA_FLAGS = ["--expand_core_with_train"]  # appended verbatim after the options above
# ---------------------------------------------------------------------------

def run_script(seed_path: Path, out_json: Path) -> None:
    """Execute the summary script for a single seed, writing JSON to out_json.

    Args:
        seed_path: Seed JSON file, passed to the script via ``--seeds``.
        out_json: Destination path, passed to the script via ``-o``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status (``check=True``).
        OSError: If the interpreter cannot be launched at all.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import sys

    # Launch the script with the interpreter that is running this batch
    # runner, so the child sees the same environment and installed packages.
    # Fall back to "python3" on PATH when sys.executable is empty or None.
    interpreter = sys.executable or "python3"

    cmd = [
        interpreter,
        SCRIPT_NAME,
        "--dataset", DATASET,
        "--seeds", str(seed_path),
        "--hidden", HIDDEN,
        "--epochs", EPOCHS,
        "--lr", LR,
        "--runs", RUNS,
        "-o", str(out_json),
        *EXTRA_FLAGS,
    ]

    # SCRIPT_NAME is a relative path, so it is resolved against the current
    # working directory. The child's stdout/stderr are discarded; a non-zero
    # exit status raises CalledProcessError.
    subprocess.run(
        cmd,
        check=True,
        cwd=Path.cwd(),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

def main() -> None:
    """Run the summary script once per seed file and report the outcome.

    Command-line options:
        --baseline random   Read seeds from ``cora_seeds/stage0_rand`` and
                            write summaries to ``cora_seeds/summary_rand``
                            instead of the default ``cora_seeds/stage0`` and
                            ``cora_seeds/summary`` directories.

    A failure on one seed is reported and does not stop the batch; the final
    message states how many seeds failed, if any.
    """
    parser = argparse.ArgumentParser(description="Batch runner for LRMC summary.")
    parser.add_argument('--baseline', type=str, choices=['random'], help='Use a baseline method.')
    args = parser.parse_args()

    if args.baseline == 'random':
        seeds_dir = Path("cora_seeds/stage0_rand")
        summary_dir = Path("cora_seeds/summary_rand")
    else:
        seeds_dir = Path("cora_seeds/stage0")
        summary_dir = Path("cora_seeds/summary")

    # Make sure the summary directory exists
    summary_dir.mkdir(parents=True, exist_ok=True)

    # Grab all *.json files in the seed folder (sorted for a stable order)
    seed_files = sorted(seeds_dir.glob("*.json"))
    if not seed_files:
        print(f"[red]No JSON seed files found in {seeds_dir!s}[/red]")
        return

    # Rich progress bar – one tick per seed file
    try:
        from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
    except ImportError:
        print("[red]The 'rich' package is required – run: pip install rich[/red]")
        return

    progress = Progress(
        TextColumn("[bold cyan]{task.fields[seed]}"),
        BarColumn(),
        TimeRemainingColumn(),
        transient=True,  # hide the bar once done
    )

    failed = []  # stems of the seed files whose run raised

    with progress:
        task_id = progress.add_task(
            "Processing seeds", total=len(seed_files), seed="Preparing..."
        )

        for seed_path in seed_files:
            seed_name = seed_path.stem  # XXX (without .json)
            out_json = summary_dir / f"{seed_name}.json"

            # Label the bar with the seed that is about to run, so the display
            # names the seed in progress rather than the one just finished.
            progress.update(task_id, seed=f"{seed_name}.json")

            # Run the script; JSON will be written by the script itself.
            # Best-effort batch: report the failure and carry on.
            try:
                run_script(seed_path, out_json)
            except Exception as exc:
                failed.append(seed_name)
                print(f"[red]Failed: {seed_name}.json -> {exc}[/red]")

            # One tick per seed, whether it succeeded or failed
            progress.update(task_id, advance=1)

    if failed:
        # Do not claim full success when some seeds produced no summary.
        succeeded = len(seed_files) - len(failed)
        print(
            f"[yellow]{succeeded}/{len(seed_files)} seeds processed, "
            f"{len(failed)} failed – JSON summaries in {summary_dir}/[/yellow]"
        )
    else:
        print(f"[green]All seeds processed – JSON summaries in {summary_dir}/[/green]")

# Run the batch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()