# clique/src/grl/analyze_lrmc_sweep.py
# Uploaded with huggingface_hub (commit f74dd01, verified).
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Analyze hyperparameter sweeps for L-RMC/GCN on Cora in a "rich" format similar to analyze_stats.py.
Usage:
python analyze_lrmc_sweep.py /path/to/hyperparam_sweep.csv --topk 10 --variant pool
"""
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np
import pandas as pd
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.text import Text
from rich import box
# ------------------------
# Helpers
# ------------------------
def _fmt(x, digits=3):
if x is None or (isinstance(x, float) and (np.isnan(x) or np.isinf(x))):
return "-"
if isinstance(x, float):
return f"{x:.{digits}f}"
return str(x)
def _to_float(series):
try:
return series.astype(float)
except Exception:
return pd.to_numeric(series, errors="coerce")
def _unique_counts(df, cols):
return {c: df[c].nunique() for c in cols if c in df.columns}
def _has_cols(df, cols):
return all(c in df.columns for c in cols)
def _get_common_hparams():
# Hyperparameters that exist for both 'baseline' and 'pool'
return ["hidden", "lr", "dropout", "self_loop_scale", "use_a2"]
# ------------------------
# Core analysis
# ------------------------
def analyze(df: pd.DataFrame):
    """Validate, cast, and aggregate a hyperparameter-sweep DataFrame.

    Expects one row per (config, seed) run with the columns listed in
    ``expected`` below. Returns a dict bundling the cleaned frame, summary
    stats, top configurations, per-hyperparameter marginal means, and matched
    baseline-vs-pool comparisons — the inputs for the ``render_*`` functions
    in this module.

    Raises:
        ValueError: if any expected column is missing.
    """
    # Normalize column names if needed
    df = df.copy()
    # Expected columns
    expected = [
        "dataset","variant","hidden","epochs","lr","wd","dropout",
        "self_loop_scale","use_a2",
        "lrmc_inv_weight","lrmc_gamma",
        "seed","val_accuracy","test_accuracy"
    ]
    missing = [c for c in expected if c not in df.columns]
    if missing:
        raise ValueError(f"CSV missing expected columns: {missing}")
    # Cast numerics (unparseable entries become NaN via _to_float).
    num_cols = ["hidden","epochs","lr","wd","dropout","self_loop_scale",
                "lrmc_inv_weight","lrmc_gamma","val_accuracy","test_accuracy"]
    for c in num_cols:
        df[c] = _to_float(df[c])
    # Basic summary
    summary = {
        "file_rows": len(df),
        "dataset": ", ".join(sorted(df["dataset"].astype(str).unique())),
        "variants": ", ".join(sorted(df["variant"].astype(str).unique())),
        "columns": list(df.columns),
        "n_columns": len(df.columns),
        "unique_counts": _unique_counts(df, ["hidden","lr","dropout","self_loop_scale","use_a2","lrmc_inv_weight","lrmc_gamma"]),
        # Pearson correlation: how well val accuracy predicts test accuracy.
        "val_test_corr_all": df["val_accuracy"].corr(df["test_accuracy"]),
        "val_test_corr_by_variant": df.groupby("variant").apply(lambda g: g["val_accuracy"].corr(g["test_accuracy"])).to_dict(),
    }
    # Best configs (global and per variant)
    def best_rows(by_cols=("val_accuracy","test_accuracy"), ascending=(False, False), topk=10, variant=None):
        # Sort by val accuracy descending, tie-breaking on test accuracy.
        sub = df if variant is None else df[df["variant"]==variant]
        if len(sub)==0:
            return sub
        return sub.sort_values(list(by_cols), ascending=list(ascending)).head(topk)
    best_overall = best_rows(topk=10)
    best_by_variant = {v: best_rows(topk=5, variant=v) for v in df["variant"].unique()}
    # Per-hparam aggregates (marginal means).
    # NOTE(review): groupby silently drops rows whose hparam value is NaN
    # (e.g. lrmc_* columns for baseline runs) — presumably intended; confirm.
    hparams = ["hidden","lr","dropout","self_loop_scale","use_a2","lrmc_inv_weight","lrmc_gamma"]
    per_param = {}
    for hp in hparams:
        g = df.groupby(hp).agg(
            mean_val=("val_accuracy","mean"),
            std_val=("val_accuracy","std"),
            mean_test=("test_accuracy","mean"),
            std_test=("test_accuracy","std"),
            n=("val_accuracy","count"),
        ).sort_values("mean_val", ascending=False)
        per_param[hp] = g
    # Same marginal-mean tables, restricted to each variant.
    per_param_by_variant = {}
    for v in df["variant"].unique():
        sub = df[df["variant"]==v]
        per_param_by_variant[v] = {}
        for hp in hparams:
            g = sub.groupby(hp).agg(
                mean_val=("val_accuracy","mean"),
                std_val=("val_accuracy","std"),
                mean_test=("test_accuracy","mean"),
                std_test=("test_accuracy","std"),
                n=("val_accuracy","count"),
            ).sort_values("mean_val", ascending=False)
            per_param_by_variant[v][hp] = g
    # Matched comparisons (baseline vs pool), using common hyperparameters
    commons = _get_common_hparams()
    matched = df.groupby(["variant"]+commons).agg(
        mean_val=("val_accuracy","mean"),
        mean_test=("test_accuracy","mean"),
        best_val=("val_accuracy","max"),
        best_test=("test_accuracy","max"),
        n=("val_accuracy","count"),
    ).reset_index()
    baseline_mean = matched[matched["variant"]=="baseline"].set_index(commons)
    pool_mean = matched[matched["variant"]=="pool"].set_index(commons)
    # Compare mean vs mean.
    # Inner join keeps only hyperparameter settings run under BOTH variants.
    comp_mean = pool_mean[["mean_val","mean_test"]].join(
        baseline_mean[["mean_val","mean_test"]],
        lsuffix="_pool", rsuffix="_base", how="inner"
    )
    comp_mean["delta_val"] = comp_mean["mean_val_pool"] - comp_mean["mean_val_base"]
    comp_mean["delta_test"] = comp_mean["mean_test_pool"] - comp_mean["mean_test_base"]
    # Compare best vs best per setting (max over seeds within each setting).
    baseline_best = matched[matched["variant"]=="baseline"].set_index(commons)[["best_val","best_test"]]
    pool_best = matched[matched["variant"]=="pool"].set_index(commons)[["best_val","best_test"]]
    comp_best = pool_best.join(baseline_best, lsuffix="_pool", rsuffix="_base", how="inner")
    comp_best["delta_best_val"] = comp_best["best_val_pool"] - comp_best["best_val_base"]
    comp_best["delta_best_test"] = comp_best["best_test_pool"] - comp_best["best_test_base"]
    return {
        "df": df,
        "summary": summary,
        "best_overall": best_overall,
        "best_by_variant": best_by_variant,
        "per_param": per_param,
        "per_param_by_variant": per_param_by_variant,
        "comp_mean": comp_mean,
        "comp_best": comp_best,
        "commons": commons,
    }
# ------------------------
# Rendering (rich style)
# ------------------------
def render_summary(console: Console, A):
    """Print a panel summarizing the sweep: row count, datasets, variants,
    val/test correlations, and the number of unique values per hyperparameter."""
    s = A["summary"]
    heading = Text("L‑RMC/GCN Hyperparameter Sweep — Summary", style="bold white")
    body = Text()

    def line(label, value):
        # Bold-green label, plain value, one entry per line.
        body.append(label, style="bold green")
        body.append(f"{value}\n")

    line("File rows: ", s["file_rows"])
    line("Dataset(s): ", s["dataset"])
    line("Variants: ", s["variants"])
    line("Val/Test Corr (all): ", _fmt(s["val_test_corr_all"]))
    for variant, corr in s["val_test_corr_by_variant"].items():
        line(f"Val/Test Corr ({variant}): ", _fmt(corr))
    pairs = ", ".join(f"{k}={v}" for k, v in s["unique_counts"].items())
    line("Unique values per hparam: ", pairs)
    console.print(Panel(body, title=heading, border_style="cyan", box=box.ROUNDED))
def render_top_configs(console: Console, A, topk=10):
    """Render the best configurations: one overall table (sorted by val
    accuracy, tie-broken by test accuracy) and one table per variant.

    Note: ``topk`` only affects the displayed title; the row limits (10
    overall, 5 per variant) were fixed upstream in ``analyze``.
    """
    hparam_cols = ["hidden", "lr", "dropout", "self_loop_scale", "use_a2",
                   "lrmc_inv_weight", "lrmc_gamma"]

    def add_columns(table, cols):
        # Accuracy columns: yellow (val) / green (test), right-justified.
        # Everything else: cyan, left-justified.
        # BUG FIX: the per-variant tables used a dead ternary
        # ('"right" if "accuracy" in col else "right"') that justified every
        # column right; aligned with the overall table's convention instead.
        for col in cols:
            if col == "val_accuracy":
                style = "yellow"
            elif col == "test_accuracy":
                style = "green"
            else:
                style = "cyan"
            table.add_column(col, style=style,
                             justify="right" if "accuracy" in col else "left")

    def config_cells(row):
        # Formatted hyperparameter cells shared by both table kinds.
        return [
            _fmt(row["hidden"]), _fmt(row["lr"]), _fmt(row["dropout"]),
            _fmt(row["self_loop_scale"]), str(row["use_a2"]),
            _fmt(row["lrmc_inv_weight"]), _fmt(row["lrmc_gamma"]),
        ]

    # Overall (by val desc, tiebreak by test desc)
    df = A["best_overall"]
    table = Table(title=f"Top {min(topk, len(df))} Configs by Val Accuracy (overall)",
                  box=box.SIMPLE_HEAVY)
    add_columns(table, ["variant", "val_accuracy", "test_accuracy"] + hparam_cols)
    for _, row in df.iterrows():
        table.add_row(str(row["variant"]),
                      _fmt(row["val_accuracy"]), _fmt(row["test_accuracy"]),
                      *config_cells(row))
    console.print(table)

    # Per-variant bests
    for v, sub in A["best_by_variant"].items():
        table = Table(title=f"Top {min(5, len(sub))} Configs for variant='{v}' (by Val)",
                      box=box.MINIMAL_HEAVY_HEAD)
        add_columns(table, ["val_accuracy", "test_accuracy"] + hparam_cols)
        for _, row in sub.iterrows():
            table.add_row(_fmt(row["val_accuracy"]), _fmt(row["test_accuracy"]),
                          *config_cells(row))
        console.print(table)
def render_per_param(console: Console, A, variant=None):
    """Render one table per hyperparameter showing its marginal mean/std
    val/test accuracy, with the best value per metric highlighted."""
    suffix = "" if variant is None else f" — variant={variant}"
    header = f"Per‑Hyperparameter Effects (marginal means){suffix}"
    console.print(Panel(Text(header, style="bold white"), border_style="magenta", box=box.ROUNDED))
    tables = A["per_param"] if variant is None else A["per_param_by_variant"][variant]

    def highlight(cell, is_best):
        # Bold-green markup for the winning hyperparameter value.
        return f"[bold green]{cell}[/]" if is_best else cell

    for hp, g in tables.items():
        table = Table(title=f"{hp}", box=box.SIMPLE_HEAD)
        table.add_column(hp, style="cyan")
        table.add_column("n", style="white", justify="right")
        table.add_column("val_mean", style="yellow", justify="right")
        table.add_column("val_std", style="yellow", justify="right")
        table.add_column("test_mean", style="green", justify="right")
        table.add_column("test_std", style="green", justify="right")
        # Identify the hyperparameter value with the best marginal val/test mean.
        best_val_key = g["mean_val"].idxmax() if len(g) > 0 else None
        best_test_key = g["mean_test"].idxmax() if len(g) > 0 else None
        for _, row in g.reset_index().iterrows():
            key = row[hp]
            table.add_row(
                f"{key}",
                _fmt(row["n"]),
                highlight(_fmt(row["mean_val"]), key == best_val_key),
                _fmt(row["std_val"]),
                highlight(_fmt(row["mean_test"]), key == best_test_key),
                _fmt(row["std_test"]),
            )
        console.print(table)
def render_matched_comparisons(console: Console, A):
    """Render baseline-vs-pool comparisons over identical hyperparameter settings.

    Prints a win-rate panel (fraction of matched settings where pool beats
    baseline on test accuracy), then top-gain / largest-drop tables for both
    the mean-vs-mean and best-vs-best deltas.
    """
    console.print(Panel(Text("Baseline vs Pool — Matched Hyperparameter Comparisons", style="bold white"), border_style="blue", box=box.ROUNDED))
    comp_mean = A["comp_mean"].reset_index()
    comp_best = A["comp_best"].reset_index()
    commons = A["commons"]
    # High-level stats; NaN when there are no matched settings at all.
    win_rate_mean = float((comp_mean["delta_test"] > 0).mean()) if len(comp_mean)>0 else float("nan")
    win_rate_best = float((comp_best["delta_best_test"] > 0).mean()) if len(comp_best)>0 else float("nan")
    stats_text = Text()
    stats_text.append("Pool > Baseline (by mean test across same settings): ", style="bold green")
    stats_text.append(f"{_fmt(100*win_rate_mean, 1)}% of settings\n")
    stats_text.append("Pool best > Baseline best (per setting): ", style="bold green")
    stats_text.append(f"{_fmt(100*win_rate_best, 1)}% of settings\n")
    console.print(Panel(stats_text, border_style="green", box=box.SQUARE))
    # Show the top positive / negative settings
    def table_from_df(df, deltas_col, title):
        df = df.sort_values(deltas_col, ascending=False)
        head = df.head(8)
        tail = df.tail(8)
        for part, name in [(head, "Top Gains (Pool minus Baseline)"), (tail, "Largest Drops (Pool minus Baseline)")]:
            # BUG FIX: the title and section name were concatenated with no
            # separator (rendered "...(matched)Top Gains..."); join with an
            # em dash to match the best-vs-best tables below.
            table = Table(title=f"{title} — {name}", box=box.MINIMAL_HEAVY_HEAD)
            for c in commons + ["mean_test_pool","mean_test_base","delta_test"]:
                style = "green" if c in ("mean_test_pool","mean_test_base","delta_test") else "cyan"
                table.add_column(c, style=style, justify="right")
            for _,r in part.iterrows():
                row = [str(r[c]) for c in commons] + [_fmt(r["mean_test_pool"]), _fmt(r["mean_test_base"]), _fmt(r["delta_test"])]
                table.add_row(*row)
            console.print(table)
    if len(comp_mean)>0:
        table_from_df(comp_mean, "delta_test", "Mean Test Accuracy (matched)")
    if len(comp_best)>0:
        # For best-vs-best we only need the deltas
        df = comp_best.sort_values("delta_best_test", ascending=False)
        head = df.head(8)
        tail = df.tail(8)
        for part, name in [(head, "Top Gains"), (tail, "Largest Drops")]:
            table = Table(title=f"Best-vs-Best Test Accuracy — {name}", box=box.SIMPLE_HEAVY)
            for c in commons + ["best_test_pool","best_test_base","delta_best_test"]:
                style = "green" if c in ("best_test_pool","best_test_base","delta_best_test") else "cyan"
                table.add_column(c, style=style, justify="right")
            for _,r in part.iterrows():
                row = [str(r[c]) for c in commons] + [_fmt(r["best_test_pool"]), _fmt(r["best_test_base"]), _fmt(r["delta_best_test"])]
                table.add_row(*row)
            console.print(table)
def recommend_settings(console: Console, A):
    """Recommend a configuration per variant based on marginal means and sanity checks."""
    per_pool = A["per_param_by_variant"].get("pool", {})
    per_base = A["per_param_by_variant"].get("baseline", {})

    def pick_best(per_param, keys):
        # For each hyperparameter, take the value whose marginal mean test
        # accuracy is highest (favoring generalization over val accuracy);
        # None when the hyperparameter was never swept for this variant.
        rec = {}
        for key in keys:
            if key in per_param:
                rec[key] = per_param.get(key, pd.DataFrame()).get("mean_test", pd.Series()).idxmax()
            else:
                rec[key] = None
        return rec

    rec_pool = pick_best(per_pool, ["hidden", "lr", "dropout", "self_loop_scale",
                                    "use_a2", "lrmc_inv_weight", "lrmc_gamma"])
    rec_base = pick_best(per_base, ["hidden", "lr", "dropout", "self_loop_scale",
                                    "use_a2"])

    def show_panel(title, rec):
        body = Text()
        for name, value in rec.items():
            body.append(f"{name}: ", style="bold cyan")
            body.append(f"{value}\n")
        console.print(Panel(body, title=title, border_style="green", box=box.ROUNDED))

    show_panel("Recommended settings — variant=pool (by mean test)", rec_pool)
    show_panel("Recommended settings — variant=baseline (by mean test)", rec_base)
def main():
    """CLI entry point: parse args, load the sweep CSV, and render all reports."""
    parser = argparse.ArgumentParser()
    parser.add_argument("csv_path", type=str, help="Path to hyperparam_sweep.csv")
    parser.add_argument("--topk", type=int, default=10, help="How many top configs to display")
    parser.add_argument("--variant", type=str, default=None, help="Filter to a specific variant (baseline or pool) for per-parameter tables")
    args = parser.parse_args()

    csv_path = Path(args.csv_path)
    if not csv_path.exists():
        raise SystemExit(f"CSV not found: {csv_path}")

    A = analyze(pd.read_csv(csv_path))
    console = Console()
    render_summary(console, A)
    render_top_configs(console, A, topk=args.topk)
    # Marginal effects over all runs, then per-variant when requested.
    render_per_param(console, A, variant=None)
    if args.variant is not None:
        render_per_param(console, A, variant=args.variant)
    render_matched_comparisons(console, A)
    recommend_settings(console, A)
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()