""" Train a GCN on an L‑RMC subgraph and compare to a full‑graph baseline. Modes: - core_mode=forward : Train on core subgraph, then forward on full graph (your current approach). - core_mode=appnp : Train on core subgraph, then seed logits on core and APPNP‑propagate on full graph. Extras: - --expand_core_with_train : Make sure all training labels lie inside the core (C' = C ∪ train_idx) for fair train‑time comparison. - --warm_ft_epochs N : Optional short finetune on the full graph starting from the core model's weights (measure time‑to‑target). It prints: - Dataset stats - Core size and coverage of train/val/test inside the core - Train/Val/Test accuracy for baseline and core model - Wall‑clock times """ import argparse import json import time import random from statistics import mean, stdev from pathlib import Path from typing import Dict import torch import torch.nn.functional as F from torch import nn, Tensor from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv, APPNP from torch_geometric.utils import subgraph # ------------------------------------------------------------ # Rich imports # ------------------------------------------------------------ from rich.console import Console from rich.table import Table from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn # Rich console instance console = Console() # ------------------------------------------------------------ # Utilities # ------------------------------------------------------------ def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor: """ seeds_json format (expected): {"clusters": [{"seed_nodes":[...], "score": float, ...}, ...]} We pick the cluster with max (score, size) and return a boolean core mask. Always assume that the seeds json nodes are 1-indexed. """ obj = json.loads(Path(seeds_json).read_text()) clusters = obj.get("clusters", []) if not clusters: return torch.zeros(n_nodes, dtype=torch.bool) best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", [])))) ids = best.get("seed_nodes", []) ids = [int(x) - 1 for x in ids] # Convert 1-indexed to 0-indexed ids = sorted(set([i for i in ids if 0 <= i < n_nodes])) mask = torch.zeros(n_nodes, dtype=torch.bool) if ids: mask[torch.tensor(ids, dtype=torch.long)] = True return mask def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor, val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]: return { "core_size": int(core_mask.sum().item()), "train_in_core": int((core_mask & train_mask).sum().item()), "val_in_core": int((core_mask & val_mask).sum().item()), "test_in_core": int((core_mask & test_mask).sum().item()), } def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float: pred = logits[mask].argmax(dim=1) return (pred == y[mask]).float().mean().item() def set_seed(seed: int): """Set random seeds for reproducibility across runs.""" random.seed(seed) try: import numpy as np # optional np.random.seed(seed) except Exception: pass torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # Make CUDA/CuDNN deterministic where applicable torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # ------------------------------------------------------------ # Models # ------------------------------------------------------------ class GCN2(nn.Module): def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5): super().__init__() self.c1 = GCNConv(in_dim, hid) self.c2 = GCNConv(hid, out_dim) self.dropout = dropout def forward(self, x, ei): x = self.c1(x, ei) x = torch.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) x = self.c2(x, ei) return x # ------------------------------------------------------------ # Training / evaluation # ------------------------------------------------------------ @torch.no_grad() def eval_all(model: nn.Module, data) -> Dict[str, float]: model.eval() logits = model(data.x, data.edge_index) return { "train": accuracy(logits, data.y, data.train_mask), "val": accuracy(logits, data.y, data.val_mask), "test": accuracy(logits, data.y, data.test_mask), } def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100): opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd) best, best_state, bad = -1.0, None, 0 # Optional progress bar with Progress( SpinnerColumn(), "[progress.description]{task.description}", TimeElapsedColumn(), transient=True, ) as progress: task = progress.add_task("Training", total=epochs) for ep in range(1, epochs + 1): model.train() opt.zero_grad(set_to_none=True) out = model(data.x, data.edge_index) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() opt.step() # early stop on val with torch.no_grad(): val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask) if val > best: best, bad = val, 0 best_state = {k: v.detach().clone() for k, v in model.state_dict().items()} else: bad += 1 if bad >= patience: break progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}") if best_state is not None: model.load_state_dict(best_state) model.eval() def subset_data(data, nodes_idx: torch.Tensor): """ Build an induced subgraph on 'nodes_idx'. Keeps x,y,masks restricted to that set. Returns a shallow copy with edge_index/feature/labels/masks sliced. """ nodes_idx = nodes_idx.to(torch.long) sub_ei, _ = subgraph(nodes_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes) sub = type(data)() sub.x = data.x[nodes_idx] sub.y = data.y[nodes_idx] sub.train_mask = data.train_mask[nodes_idx] sub.val_mask = data.val_mask[nodes_idx] sub.test_mask = data.test_mask[nodes_idx] sub.edge_index = sub_ei sub.num_nodes = sub.x.size(0) return sub # ------------------------------------------------------------ # APPNP seeding (Mode B) # ------------------------------------------------------------ def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor: """ logits_seed is [N, C] where rows outside the core are zeros. We propagate these logits with APPNP to fill the graph. """ appnp = APPNP(K=K, alpha=alpha) # no trainable params return appnp(logits_seed, edge_index) # ------------------------------------------------------------ # Main # ------------------------------------------------------------ def main(): p = argparse.ArgumentParser() p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"]) p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON") p.add_argument("--hidden", type=int, default=64) p.add_argument("--dropout", type=float, default=0.5) p.add_argument("--epochs", type=int, default=200) p.add_argument("--lr", type=float, default=0.01) p.add_argument("--wd", type=float, default=5e-4) p.add_argument("--patience", type=int, default=100) p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward", help="How to evaluate the core model on the full graph.") p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport prob (Mode B).") p.add_argument("--K", type=int, default=10, help="APPNP steps (Mode B).") p.add_argument("--expand_core_with_train", action="store_true", help="Expand LRMC core with all training nodes (C' = C ∪ train_idx).") p.add_argument("--warm_ft_epochs", type=int, default=0, help="If >0, run a short finetune on the FULL graph starting from the core model.") p.add_argument("--warm_ft_lr", type=float, default=0.005) p.add_argument("--runs", type=int, default=1, help="Number of runs with different seeds to average results.") p.add_argument("-o", "--output_json", type=str, default=None, help="If set, save all computed metrics and settings to this JSON file.") args = p.parse_args() # ------------------------------------------------------------ # Load data # ------------------------------------------------------------ ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset) data = ds[0] n, e = data.num_nodes, data.edge_index.size(1) // 2 console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]") # Results accumulator for optional JSON output results = { "args": { k: (float(v) if isinstance(v, float) else v) for k, v in vars(args).items() if k != "output_json" }, "dataset": { "name": args.dataset, "num_nodes": int(n), "num_edges": int(e), }, } def maybe_save_results(): """Write results to JSON if the user requested it.""" if not args.output_json: return out_path = Path(args.output_json) try: out_path.parent.mkdir(parents=True, exist_ok=True) except Exception: pass with out_path.open("w") as f: json.dump(results, f, indent=2) # ------------------------------------------------------------ # Load LRMC core # ------------------------------------------------------------ core_mask = load_top1_assignment(args.seeds, n) if args.expand_core_with_train: core_mask = core_mask | data.train_mask C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1) frac = 100.0 * C_idx.numel() / n cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask) console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]") # Record core coverage info results["core"] = { "source": str(args.seeds), "expanded_with_train": bool(args.expand_core_with_train), "size": int(cov["core_size"]), "fraction": float(frac / 100.0), "coverage": { "train_in_core": int(cov["train_in_core"]), "val_in_core": int(cov["val_in_core"]), "test_in_core": int(cov["test_in_core"]), }, } # Coverage table cov_table = Table(title="LRMC Core Coverage") cov_table.add_column("Metric", style="cyan") cov_table.add_column("Count", style="magenta") cov_table.add_row("Core Size", str(cov["core_size"])) cov_table.add_row("Train in Core", str(cov["train_in_core"])) cov_table.add_row("Val in Core", str(cov["val_in_core"])) cov_table.add_row("Test in Core", str(cov["test_in_core"])) console.print(cov_table) # ------------------------------------------------------------ # Single-run or multi-run execution # ------------------------------------------------------------ if args.runs == 1: # --------------------- # Baseline (full graph) # --------------------- set_seed(0) t0 = time.perf_counter() base = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience) base_metrics = eval_all(base, data) t1 = time.perf_counter() console.print("\n[bold]Baseline (trained on full graph):[/bold]") base_table = Table(show_header=True, header_style="bold magenta") base_table.add_column("Metric", style="cyan") base_table.add_column("Value", style="magenta") base_table.add_row("Train Accuracy", f"{base_metrics['train']:.4f}") base_table.add_row("Validation Accuracy", f"{base_metrics['val']:.4f}") base_table.add_row("Test Accuracy", f"{base_metrics['test']:.4f}") base_table.add_row("Time (s)", f"{t1 - t0:.2f}") console.print(base_table) # Save baseline single-run metrics results["single_run"] = { "baseline": { "train": float(base_metrics["train"]), "val": float(base_metrics["val"]), "test": float(base_metrics["test"]), "time_s": float(t1 - t0), } } # --------------------- # Core model (train on subgraph) # --------------------- if C_idx.numel() == 0: console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]") results["core_empty"] = True maybe_save_results() return data_C = subset_data(data, C_idx) mC = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) t2 = time.perf_counter() train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience) t3 = time.perf_counter() # Evaluate core model on FULL graph if args.core_mode == "forward": # Mode A: run a standard forward pass on the full graph mC.eval() logits_full = mC(data.x, data.edge_index) else: # Mode B: seed logits on core and propagate with APPNP mC.eval() with torch.no_grad(): logits_C = mC(data_C.x, data_C.edge_index) # [|C|, num_classes] logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device) logits_seed[C_idx] = logits_C logits_full = appnp_seed_propagate(logits_seed, data.edge_index, K=args.K, alpha=args.alpha) core_metrics = { "train": accuracy(logits_full, data.y, data.train_mask), "val": accuracy(logits_full, data.y, data.val_mask), "test": accuracy(logits_full, data.y, data.test_mask), } console.print("\n[bold]LRMC‑core model (trained on core, evaluated on full graph):[/bold]") core_table = Table(show_header=True, header_style="bold magenta") core_table.add_column("Metric", style="cyan") core_table.add_column("Value", style="magenta") core_table.add_row("Train Accuracy", f"{core_metrics['train']:.4f}") core_table.add_row("Validation Accuracy", f"{core_metrics['val']:.4f}") core_table.add_row("Test Accuracy", f"{core_metrics['test']:.4f}") core_table.add_row("Core Training Time (s)", f"{t3 - t2:.2f}") speedup = (t1 - t0) / (t3 - t2 + 1e-9) core_table.add_row("Speedup vs. Baseline", f"{speedup:.2f}×") console.print(core_table) # Save core single-run metrics results["single_run"]["core_model"] = { "mode": str(args.core_mode), "train": float(core_metrics["train"]), "val": float(core_metrics["val"]), "test": float(core_metrics["test"]), "core_train_time_s": float(t3 - t2), "speedup_vs_baseline": float(speedup), } # ------------------------------------------------------------------------------------ console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]") # Create comparison table comparison_table = Table(title="Performance Comparison", show_header=True, header_style="bold magenta") comparison_table.add_column("Metric", style="cyan") comparison_table.add_column("Baseline", style="magenta") comparison_table.add_column("L-RMC-core", style="green") comparison_table.add_column("Speedup", style="yellow") # Add performance metrics for metric in ["train", "val", "test"]: comparison_table.add_row( f"{metric.capitalize()} Accuracy", f"{base_metrics[metric]:.4f}", f"{core_metrics[metric]:.4f}", "" # Speedup is not applicable for accuracy ) # Add timing and speedup baseline_time = t1 - t0 core_time = t3 - t2 speedup = baseline_time / core_time if core_time > 0 else float('inf') comparison_table.add_row( "Training Time (s)", f"{baseline_time:.2f}", f"{core_time:.2f}", f"{speedup:.2f}x" ) comparison_table.add_row( "Speedup", "1x", f"{speedup:.2f}x", "" ) console.print(comparison_table) # Optional warm‑start finetune (single run) if args.warm_ft_epochs > 0: warm = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) warm.load_state_dict(mC.state_dict()) t4 = time.perf_counter() train(warm, data, epochs=args.warm_ft_epochs, lr=args.warm_ft_lr, wd=args.wd, patience=args.warm_ft_epochs + 1) t5 = time.perf_counter() warm_metrics = eval_all(warm, data) console.print("\n[bold]Warm‑start finetune (start from core model, train on FULL graph):[/bold]") warm_table = Table(show_header=True, header_style="bold magenta") warm_table.add_column("Metric", style="cyan") warm_table.add_column("Value", style="magenta") warm_table.add_row("Train Accuracy", f"{warm_metrics['train']:.4f}") warm_table.add_row("Validation Accuracy", f"{warm_metrics['val']:.4f}") warm_table.add_row("Test Accuracy", f"{warm_metrics['test']:.4f}") warm_table.add_row("Finetune Time (s)", f"{t5 - t4:.2f}") warm_table.add_row("Total (core train + warm)", f"{(t3 - t2 + t5 - t4):.2f}s") console.print(warm_table) # Save warm single-run metrics results["single_run"]["warm_finetune"] = { "train": float(warm_metrics["train"]), "val": float(warm_metrics["val"]), "test": float(warm_metrics["test"]), "finetune_time_s": float(t5 - t4), "total_time_s": float((t3 - t2) + (t5 - t4)), } # Emit results for single-run maybe_save_results() else: # -------------------------------------------------------- # Multi-run: average metrics across different seeds # -------------------------------------------------------- runs = args.runs console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]") # Storage for metrics across runs base_train, base_val, base_test, base_time = [], [], [], [] core_train, core_val, core_test, core_time = [], [], [], [] speedups = [] warm_train, warm_val, warm_test, warm_time, warm_total_time = [], [], [], [], [] data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None results["core_empty"] = data_C is None for r in range(runs): set_seed(r) # Baseline t0 = time.perf_counter() base = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience) bm = eval_all(base, data) t1 = time.perf_counter() base_train.append(bm["train"]) ; base_val.append(bm["val"]) ; base_test.append(bm["test"]) ; base_time.append(t1 - t0) # Core model if data_C is None: continue # no core available t2 = time.perf_counter() mC = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience) t3 = time.perf_counter() if args.core_mode == "forward": mC.eval() logits_full = mC(data.x, data.edge_index) else: mC.eval() with torch.no_grad(): logits_C = mC(data_C.x, data_C.edge_index) logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device) logits_seed[C_idx] = logits_C logits_full = appnp_seed_propagate(logits_seed, data.edge_index, K=args.K, alpha=args.alpha) cm = { "train": accuracy(logits_full, data.y, data.train_mask), "val": accuracy(logits_full, data.y, data.val_mask), "test": accuracy(logits_full, data.y, data.test_mask), } core_train.append(cm["train"]) ; core_val.append(cm["val"]) ; core_test.append(cm["test"]) ; core_time.append(t3 - t2) speedups.append((t1 - t0) / (t3 - t2 + 1e-9)) # Optional warm finetune per run if args.warm_ft_epochs > 0: warm = GCN2(in_dim=ds.num_node_features, hid=args.hidden, out_dim=ds.num_classes, dropout=args.dropout) warm.load_state_dict(mC.state_dict()) t4 = time.perf_counter() train(warm, data, epochs=args.warm_ft_epochs, lr=args.warm_ft_lr, wd=args.wd, patience=args.warm_ft_epochs + 1) t5 = time.perf_counter() wm = eval_all(warm, data) warm_train.append(wm["train"]) ; warm_val.append(wm["val"]) ; warm_test.append(wm["test"]) ; warm_time.append(t5 - t4) warm_total_time.append((t3 - t2) + (t5 - t4)) # Helper to format mean ± std def fmt(values, prec=4): if not values: return "n/a" if len(values) == 1: return f"{values[0]:.{prec}f}" try: return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}" except Exception: m = sum(values) / len(values) var = sum((v - m) ** 2 for v in values) / max(1, len(values) - 1) return f"{m:.{prec}f} ± {var ** 0.5:.{prec}f}" def stats(values): """Return dict with list, mean, std, count for JSON.""" d = { "values": [float(v) for v in values], "count": int(len(values)), } if len(values) >= 1: d["mean"] = float(mean(values)) if len(values) >= 2: d["std"] = float(stdev(values)) else: d["std"] = None return d # Baseline summary console.print("\n[bold]Baseline (averaged over runs):[/bold]") base_table = Table(show_header=True, header_style="bold magenta") base_table.add_column("Metric", style="cyan") base_table.add_column("Mean ± Std", style="magenta") base_table.add_row("Train Accuracy", fmt(base_train)) base_table.add_row("Validation Accuracy", fmt(base_val)) base_table.add_row("Test Accuracy", fmt(base_test)) base_table.add_row("Time (s)", fmt(base_time, prec=2)) console.print(base_table) # Save baseline multi-run summary results["multi_run"] = { "runs": int(runs), "baseline": { "train": stats(base_train), "val": stats(base_val), "test": stats(base_test), "time_s": stats(base_time), } } if data_C is None: console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]") maybe_save_results() return # Core summary console.print("\n[bold]LRMC‑core (averaged over runs):[/bold]") core_table = Table(show_header=True, header_style="bold magenta") core_table.add_column("Metric", style="cyan") core_table.add_column("Mean ± Std", style="magenta") core_table.add_row("Train Accuracy", fmt(core_train)) core_table.add_row("Validation Accuracy", fmt(core_val)) core_table.add_row("Test Accuracy", fmt(core_test)) core_table.add_row("Core Training Time (s)", fmt(core_time, prec=2)) core_table.add_row("Speedup vs. Baseline", fmt(speedups, prec=2)) console.print(core_table) # Save core multi-run summary results["multi_run"]["core_model"] = { "mode": str(args.core_mode), "train": stats(core_train), "val": stats(core_val), "test": stats(core_test), "core_train_time_s": stats(core_time), "speedup_vs_baseline": stats(speedups), } # Comparison summary console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]") comparison_table = Table(title="Performance Comparison (Mean ± Std)", show_header=True, header_style="bold magenta") comparison_table.add_column("Metric", style="cyan") comparison_table.add_column("Baseline", style="magenta") comparison_table.add_column("L-RMC-core", style="green") comparison_table.add_column("Speedup", style="yellow") for metric, b_vals, c_vals in [ ("Train Accuracy", base_train, core_train), ("Validation Accuracy", base_val, core_val), ("Test Accuracy", base_test, core_test), ]: comparison_table.add_row(metric, fmt(b_vals), fmt(c_vals), "") comparison_table.add_row("Training Time (s)", fmt(base_time, prec=2), fmt(core_time, prec=2), fmt(speedups, prec=2)) comparison_table.add_row("Speedup", "1x", fmt(speedups, prec=2), "") console.print(comparison_table) # Optional warm summary if args.warm_ft_epochs > 0 and warm_time: console.print("\n[bold]Warm‑start finetune (averaged over runs):[/bold]") warm_table = Table(show_header=True, header_style="bold magenta") warm_table.add_column("Metric", style="cyan") warm_table.add_column("Mean ± Std", style="magenta") warm_table.add_row("Train Accuracy", fmt(warm_train)) warm_table.add_row("Validation Accuracy", fmt(warm_val)) warm_table.add_row("Test Accuracy", fmt(warm_test)) warm_table.add_row("Finetune Time (s)", fmt(warm_time, prec=2)) warm_table.add_row("Total (core train + warm)", fmt(warm_total_time, prec=2)) console.print(warm_table) # Save warm multi-run summary results["multi_run"]["warm_finetune"] = { "train": stats(warm_train), "val": stats(warm_val), "test": stats(warm_test), "finetune_time_s": stats(warm_time), "total_time_s": stats(warm_total_time), } # Emit results for multi-run maybe_save_results() if __name__ == "__main__": main()