clique / src /3_train_summary_seeds.py
qingy2024's picture
Upload folder using huggingface_hub
f74dd01 verified
"""
Batch runner for LRMC summary.
Iterates over every XXX.json file in cora_seeds/stage0 (or
cora_seeds/stage0_rand when run with --baseline random) and runs the
2.6_lrmc_summary.py script once per seed, having the script write its result
to cora_seeds/summary/XXX.json (cora_seeds/summary_rand/XXX.json for the
random baseline) via the script's -o output option.
The child script's stdout/stderr are discarded and no .txt logs are written;
the only artifacts are the structured JSON summaries the script produces.
"""
import os
import glob
import subprocess
from pathlib import Path
from rich import print
import argparse
# --------------------------------------------- #
# Configuration
# --------------------------------------------- #
# Summary script launched once per seed file.
SCRIPT_NAME: str = "2.6_lrmc_summary.py"

# Values forwarded on the child's command line; kept as strings because they
# are spliced directly into the argument list.
DATASET: str = "Cora"
HIDDEN: str = "32"
EPOCHS: str = "200"
LR: str = "0.05"
RUNS: str = "20"

# Extra switches appended after the standard arguments.
EXTRA_FLAGS = ["--expand_core_with_train"]
# --------------------------------------------- #
def run_script(seed_path: Path, out_json: Path) -> None:
    """Run the LRMC summary script for a single seed file.

    The child process is started with the same interpreter that is running
    this batch script (``sys.executable``). That way it sees the same
    virtualenv and installed packages, not whatever ``python3`` happens to be
    first on PATH. The child's stdout and stderr are discarded.

    Args:
        seed_path: Seed JSON file, passed to the script via ``--seeds``.
        out_json: Output path, passed to the script via ``-o``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
        OSError: If the interpreter cannot be launched.
    """
    import sys  # local import: only needed to locate the running interpreter

    # sys.executable may be empty in embedded interpreters; fall back to PATH.
    interpreter = sys.executable or "python3"
    cmd = [
        interpreter,
        SCRIPT_NAME,
        "--dataset", DATASET,
        "--seeds", str(seed_path),
        "--hidden", HIDDEN,
        "--epochs", EPOCHS,
        "--lr", LR,
        "--runs", RUNS,
        "-o", str(out_json),
        *EXTRA_FLAGS,
    ]
    # check=True turns a non-zero exit into CalledProcessError; the child
    # inherits the current working directory, so SCRIPT_NAME is resolved there.
    subprocess.run(
        cmd,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
def main() -> None:
    """Run the LRMC summary script once for every seed file and report results.

    Command-line options:
        --baseline random: read seeds from ``cora_seeds/stage0_rand`` and
            write summaries to ``cora_seeds/summary_rand``. Without it, seeds
            come from ``cora_seeds/stage0`` and summaries go to
            ``cora_seeds/summary``.

    Seeds are processed in sorted filename order. If a seed's run fails, the
    error is printed and the loop moves on to the next seed. The closing
    message reports how many seeds failed instead of always claiming full
    success.
    """
    parser = argparse.ArgumentParser(description="Batch runner for LRMC summary.")
    parser.add_argument('--baseline', type=str, choices=['random'], help='Use a baseline method.')
    args = parser.parse_args()

    # Input/output directories depend on whether the random baseline is requested.
    if args.baseline == 'random':
        seeds_dir = Path("cora_seeds/stage0_rand")
        summary_dir = Path("cora_seeds/summary_rand")
    else:
        seeds_dir = Path("cora_seeds/stage0")
        summary_dir = Path("cora_seeds/summary")

    # Make sure the summary directory exists before the child writes into it.
    summary_dir.mkdir(parents=True, exist_ok=True)

    # Grab all *.json files in the seed folder.
    seed_files = sorted(seeds_dir.glob("*.json"))
    if not seed_files:
        print(f"[red]No JSON seed files found in {seeds_dir!s}[/red]")
        return

    # Rich progress bar – one tick per seed file.
    try:
        from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
    except ImportError:
        print("[red]The 'rich' package is required – run: pip install rich[/red]")
        return

    progress = Progress(
        TextColumn("[bold cyan]{task.fields[seed]}"),
        BarColumn(),
        TimeRemainingColumn(),
        transient=True,  # hide the bar once done
    )

    failed = []  # stems of seeds whose run raised an error
    with progress:
        task_id = progress.add_task(
            "Processing seeds", total=len(seed_files), seed="Preparing..."
        )
        for seed_path in seed_files:
            seed_name = seed_path.stem  # XXX (without .json)
            out_json = summary_dir / f"{seed_name}.json"

            # Label the bar with the seed being processed now, not the previous one.
            progress.update(task_id, seed=f"{seed_name}.json")
            try:
                run_script(seed_path, out_json)
            except (subprocess.SubprocessError, OSError) as exc:
                # Record and continue: one bad seed should not abort the batch.
                failed.append(seed_name)
                print(f"[red]Failed: {seed_name}.json -> {exc}[/red]")
            progress.advance(task_id)

    if failed:
        print(
            f"[yellow]{len(failed)} of {len(seed_files)} seeds failed – "
            f"JSON summaries for the rest in {summary_dir}/[/yellow]"
        )
    else:
        print(f"[green]All seeds processed – JSON summaries in {summary_dir}/[/green]")


if __name__ == "__main__":
    main()