# clique/src/2.6_lrmc_summary.py
"""
Train a GCN on an L‑RMC subgraph and compare to a full‑graph baseline.
Modes:
- core_mode=forward : Train on the core subgraph, then run a forward pass on the full graph (default).
- core_mode=appnp   : Train on the core subgraph, then seed logits on the core and APPNP-propagate them over the full graph.
Extras:
- --expand_core_with_train : Ensure all training nodes lie inside the core
    (C' = C ∪ train_idx) for a fair train-time comparison.
- --warm_ft_epochs N : Optional short finetune on the full graph, starting
    from the core model's weights, to measure time-to-target.
It prints:
- Dataset stats
- Core size and coverage of train/val/test inside the core
- Train/Val/Test accuracy for baseline and core model
- Wall‑clock times
"""
import argparse
import json
import time
import random
from statistics import mean, stdev
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, APPNP
from torch_geometric.utils import subgraph
# ------------------------------------------------------------
# Rich imports
# ------------------------------------------------------------
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
# Rich console instance
console = Console()
# ------------------------------------------------------------
# Utilities
# ------------------------------------------------------------
def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor:
"""
seeds_json format (expected):
{"clusters": [{"seed_nodes":[...], "score": float, ...}, ...]}
We pick the cluster with max (score, size) and return a boolean core mask.
Always assume that the seeds json nodes are 1-indexed.
"""
obj = json.loads(Path(seeds_json).read_text())
clusters = obj.get("clusters", [])
if not clusters:
return torch.zeros(n_nodes, dtype=torch.bool)
best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", []))))
ids = best.get("seed_nodes", [])
ids = [int(x) - 1 for x in ids] # Convert 1-indexed to 0-indexed
    ids = sorted({i for i in ids if 0 <= i < n_nodes})  # dedupe and drop out-of-range IDs
mask = torch.zeros(n_nodes, dtype=torch.bool)
if ids:
mask[torch.tensor(ids, dtype=torch.long)] = True
return mask
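# Illustrative seeds file accepted by load_top1_assignment (node IDs 1-indexed):
#   {"clusters": [{"seed_nodes": [1, 5, 42], "score": 0.93},
#                 {"seed_nodes": [7, 8], "score": 0.51}]}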
def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor,
val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]:
return {
"core_size": int(core_mask.sum().item()),
"train_in_core": int((core_mask & train_mask).sum().item()),
"val_in_core": int((core_mask & val_mask).sum().item()),
"test_in_core": int((core_mask & test_mask).sum().item()),
}
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
pred = logits[mask].argmax(dim=1)
return (pred == y[mask]).float().mean().item()
def set_seed(seed: int):
"""Set random seeds for reproducibility across runs."""
random.seed(seed)
try:
import numpy as np # optional
np.random.seed(seed)
except Exception:
pass
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Make CUDA/CuDNN deterministic where applicable
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ------------------------------------------------------------
# Models
# ------------------------------------------------------------
class GCN2(nn.Module):
    """Standard two-layer GCN (Kipf & Welling, 2017) with ReLU and dropout."""
    def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5):
super().__init__()
self.c1 = GCNConv(in_dim, hid)
self.c2 = GCNConv(hid, out_dim)
self.dropout = dropout
def forward(self, x, ei):
x = self.c1(x, ei)
x = torch.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.c2(x, ei)
return x
# ------------------------------------------------------------
# Training / evaluation
# ------------------------------------------------------------
@torch.no_grad()
def eval_all(model: nn.Module, data) -> Dict[str, float]:
model.eval()
logits = model(data.x, data.edge_index)
return {
"train": accuracy(logits, data.y, data.train_mask),
"val": accuracy(logits, data.y, data.val_mask),
"test": accuracy(logits, data.y, data.test_mask),
}
def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100):
    """Train with Adam, early-stopping on validation accuracy; restores the best weights."""
opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
best, best_state, bad = -1.0, None, 0
# Optional progress bar
with Progress(
SpinnerColumn(),
"[progress.description]{task.description}",
TimeElapsedColumn(),
transient=True,
) as progress:
task = progress.add_task("Training", total=epochs)
for ep in range(1, epochs + 1):
model.train()
opt.zero_grad(set_to_none=True)
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
opt.step()
            # early stop on val (eval mode disables dropout for a clean estimate)
            model.eval()
            with torch.no_grad():
                val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask)
if val > best:
best, bad = val, 0
best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
else:
bad += 1
if bad >= patience:
break
progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}")
if best_state is not None:
model.load_state_dict(best_state)
model.eval()
def subset_data(data, nodes_idx: torch.Tensor):
"""
Build an induced subgraph on 'nodes_idx'. Keeps x,y,masks restricted to that set.
Returns a shallow copy with edge_index/feature/labels/masks sliced.
"""
nodes_idx = nodes_idx.to(torch.long)
sub_ei, _ = subgraph(nodes_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
sub = type(data)()
sub.x = data.x[nodes_idx]
sub.y = data.y[nodes_idx]
sub.train_mask = data.train_mask[nodes_idx]
sub.val_mask = data.val_mask[nodes_idx]
sub.test_mask = data.test_mask[nodes_idx]
sub.edge_index = sub_ei
sub.num_nodes = sub.x.size(0)
return sub
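# Usage sketch (hypothetical indices): subset_data(data, torch.tensor([0, 2, 5]))
# returns the induced subgraph on those three nodes, with edge_index relabeled
# to the local 0..2 range by torch_geometric.utils.subgraph.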
# ------------------------------------------------------------
# APPNP seeding (Mode B)
# ------------------------------------------------------------
def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor:
"""
logits_seed is [N, C] where rows outside the core are zeros.
We propagate these logits with APPNP to fill the graph.
"""
appnp = APPNP(K=K, alpha=alpha) # no trainable params
return appnp(logits_seed, edge_index)
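# For reference: PyG's APPNP iterates Z^(k+1) = (1 - alpha) * Â @ Z^(k) + alpha * Z^(0)
# for K steps, where Â is the symmetrically normalized adjacency (with self-loops)
# and Z^(0) = logits_seed, so zero rows outside the core get filled in by diffusion.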
# ------------------------------------------------------------
# Main
# ------------------------------------------------------------
def main():
p = argparse.ArgumentParser()
p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON")
p.add_argument("--hidden", type=int, default=64)
p.add_argument("--dropout", type=float, default=0.5)
p.add_argument("--epochs", type=int, default=200)
p.add_argument("--lr", type=float, default=0.01)
p.add_argument("--wd", type=float, default=5e-4)
p.add_argument("--patience", type=int, default=100)
p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward",
help="How to evaluate the core model on the full graph.")
p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport prob (Mode B).")
p.add_argument("--K", type=int, default=10, help="APPNP steps (Mode B).")
p.add_argument("--expand_core_with_train", action="store_true",
help="Expand LRMC core with all training nodes (C' = C ∪ train_idx).")
p.add_argument("--warm_ft_epochs", type=int, default=0,
help="If >0, run a short finetune on the FULL graph starting from the core model.")
p.add_argument("--warm_ft_lr", type=float, default=0.005)
p.add_argument("--runs", type=int, default=1,
help="Number of runs with different seeds to average results.")
p.add_argument("-o", "--output_json", type=str, default=None,
help="If set, save all computed metrics and settings to this JSON file.")
args = p.parse_args()
# ------------------------------------------------------------
# Load data
# ------------------------------------------------------------
ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
data = ds[0]
    n, e = data.num_nodes, data.edge_index.size(1) // 2  # PyG stores each undirected edge twice
console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]")
# Results accumulator for optional JSON output
results = {
"args": {
k: (float(v) if isinstance(v, float) else v)
for k, v in vars(args).items()
if k != "output_json"
},
"dataset": {
"name": args.dataset,
"num_nodes": int(n),
"num_edges": int(e),
},
}
def maybe_save_results():
"""Write results to JSON if the user requested it."""
if not args.output_json:
return
out_path = Path(args.output_json)
try:
out_path.parent.mkdir(parents=True, exist_ok=True)
except Exception:
pass
with out_path.open("w") as f:
json.dump(results, f, indent=2)
# ------------------------------------------------------------
# Load LRMC core
# ------------------------------------------------------------
core_mask = load_top1_assignment(args.seeds, n)
if args.expand_core_with_train:
core_mask = core_mask | data.train_mask
C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1)
frac = 100.0 * C_idx.numel() / n
cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask)
console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]")
# Record core coverage info
results["core"] = {
"source": str(args.seeds),
"expanded_with_train": bool(args.expand_core_with_train),
"size": int(cov["core_size"]),
"fraction": float(frac / 100.0),
"coverage": {
"train_in_core": int(cov["train_in_core"]),
"val_in_core": int(cov["val_in_core"]),
"test_in_core": int(cov["test_in_core"]),
},
}
# Coverage table
cov_table = Table(title="LRMC Core Coverage")
cov_table.add_column("Metric", style="cyan")
cov_table.add_column("Count", style="magenta")
cov_table.add_row("Core Size", str(cov["core_size"]))
cov_table.add_row("Train in Core", str(cov["train_in_core"]))
cov_table.add_row("Val in Core", str(cov["val_in_core"]))
cov_table.add_row("Test in Core", str(cov["test_in_core"]))
console.print(cov_table)
# ------------------------------------------------------------
# Single-run or multi-run execution
# ------------------------------------------------------------
if args.runs == 1:
# ---------------------
# Baseline (full graph)
# ---------------------
set_seed(0)
t0 = time.perf_counter()
base = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
base_metrics = eval_all(base, data)
t1 = time.perf_counter()
console.print("\n[bold]Baseline (trained on full graph):[/bold]")
base_table = Table(show_header=True, header_style="bold magenta")
base_table.add_column("Metric", style="cyan")
base_table.add_column("Value", style="magenta")
base_table.add_row("Train Accuracy", f"{base_metrics['train']:.4f}")
base_table.add_row("Validation Accuracy", f"{base_metrics['val']:.4f}")
base_table.add_row("Test Accuracy", f"{base_metrics['test']:.4f}")
base_table.add_row("Time (s)", f"{t1 - t0:.2f}")
console.print(base_table)
# Save baseline single-run metrics
results["single_run"] = {
"baseline": {
"train": float(base_metrics["train"]),
"val": float(base_metrics["val"]),
"test": float(base_metrics["test"]),
"time_s": float(t1 - t0),
}
}
# ---------------------
# Core model (train on subgraph)
# ---------------------
if C_idx.numel() == 0:
console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]")
results["core_empty"] = True
maybe_save_results()
return
data_C = subset_data(data, C_idx)
mC = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
t2 = time.perf_counter()
train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
t3 = time.perf_counter()
# Evaluate core model on FULL graph
        if args.core_mode == "forward":
            # Mode A: run a standard forward pass on the full graph (no grad needed at eval)
            mC.eval()
            with torch.no_grad():
                logits_full = mC(data.x, data.edge_index)
else:
# Mode B: seed logits on core and propagate with APPNP
mC.eval()
with torch.no_grad():
logits_C = mC(data_C.x, data_C.edge_index) # [|C|, num_classes]
logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
logits_seed[C_idx] = logits_C
logits_full = appnp_seed_propagate(logits_seed,
data.edge_index,
K=args.K,
alpha=args.alpha)
core_metrics = {
"train": accuracy(logits_full, data.y, data.train_mask),
"val": accuracy(logits_full, data.y, data.val_mask),
"test": accuracy(logits_full, data.y, data.test_mask),
}
console.print("\n[bold]LRMC‑core model (trained on core, evaluated on full graph):[/bold]")
core_table = Table(show_header=True, header_style="bold magenta")
core_table.add_column("Metric", style="cyan")
core_table.add_column("Value", style="magenta")
core_table.add_row("Train Accuracy", f"{core_metrics['train']:.4f}")
core_table.add_row("Validation Accuracy", f"{core_metrics['val']:.4f}")
core_table.add_row("Test Accuracy", f"{core_metrics['test']:.4f}")
core_table.add_row("Core Training Time (s)", f"{t3 - t2:.2f}")
speedup = (t1 - t0) / (t3 - t2 + 1e-9)
core_table.add_row("Speedup vs. Baseline", f"{speedup:.2f}×")
console.print(core_table)
# Save core single-run metrics
results["single_run"]["core_model"] = {
"mode": str(args.core_mode),
"train": float(core_metrics["train"]),
"val": float(core_metrics["val"]),
"test": float(core_metrics["test"]),
"core_train_time_s": float(t3 - t2),
"speedup_vs_baseline": float(speedup),
}
# ------------------------------------------------------------------------------------
console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]")
# Create comparison table
comparison_table = Table(title="Performance Comparison", show_header=True, header_style="bold magenta")
comparison_table.add_column("Metric", style="cyan")
comparison_table.add_column("Baseline", style="magenta")
comparison_table.add_column("L-RMC-core", style="green")
comparison_table.add_column("Speedup", style="yellow")
# Add performance metrics
for metric in ["train", "val", "test"]:
comparison_table.add_row(
f"{metric.capitalize()} Accuracy",
f"{base_metrics[metric]:.4f}",
f"{core_metrics[metric]:.4f}",
"" # Speedup is not applicable for accuracy
)
# Add timing and speedup
baseline_time = t1 - t0
core_time = t3 - t2
speedup = baseline_time / core_time if core_time > 0 else float('inf')
comparison_table.add_row(
"Training Time (s)",
f"{baseline_time:.2f}",
f"{core_time:.2f}",
f"{speedup:.2f}x"
)
comparison_table.add_row(
"Speedup",
"1x",
f"{speedup:.2f}x",
""
)
console.print(comparison_table)
# Optional warm‑start finetune (single run)
if args.warm_ft_epochs > 0:
warm = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
warm.load_state_dict(mC.state_dict())
t4 = time.perf_counter()
train(warm, data,
epochs=args.warm_ft_epochs,
lr=args.warm_ft_lr,
wd=args.wd,
patience=args.warm_ft_epochs + 1)
t5 = time.perf_counter()
warm_metrics = eval_all(warm, data)
console.print("\n[bold]Warm‑start finetune (start from core model, train on FULL graph):[/bold]")
warm_table = Table(show_header=True, header_style="bold magenta")
warm_table.add_column("Metric", style="cyan")
warm_table.add_column("Value", style="magenta")
warm_table.add_row("Train Accuracy", f"{warm_metrics['train']:.4f}")
warm_table.add_row("Validation Accuracy", f"{warm_metrics['val']:.4f}")
warm_table.add_row("Test Accuracy", f"{warm_metrics['test']:.4f}")
warm_table.add_row("Finetune Time (s)", f"{t5 - t4:.2f}")
warm_table.add_row("Total (core train + warm)", f"{(t3 - t2 + t5 - t4):.2f}s")
console.print(warm_table)
# Save warm single-run metrics
results["single_run"]["warm_finetune"] = {
"train": float(warm_metrics["train"]),
"val": float(warm_metrics["val"]),
"test": float(warm_metrics["test"]),
"finetune_time_s": float(t5 - t4),
"total_time_s": float((t3 - t2) + (t5 - t4)),
}
# Emit results for single-run
maybe_save_results()
else:
# --------------------------------------------------------
# Multi-run: average metrics across different seeds
# --------------------------------------------------------
runs = args.runs
console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]")
# Storage for metrics across runs
base_train, base_val, base_test, base_time = [], [], [], []
core_train, core_val, core_test, core_time = [], [], [], []
speedups = []
warm_train, warm_val, warm_test, warm_time, warm_total_time = [], [], [], [], []
data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None
results["core_empty"] = data_C is None
for r in range(runs):
set_seed(r)
# Baseline
t0 = time.perf_counter()
base = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
bm = eval_all(base, data)
t1 = time.perf_counter()
            base_train.append(bm["train"])
            base_val.append(bm["val"])
            base_test.append(bm["test"])
            base_time.append(t1 - t0)
# Core model
if data_C is None:
continue # no core available
t2 = time.perf_counter()
mC = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
t3 = time.perf_counter()
            if args.core_mode == "forward":
                mC.eval()
                with torch.no_grad():
                    logits_full = mC(data.x, data.edge_index)
else:
mC.eval()
with torch.no_grad():
logits_C = mC(data_C.x, data_C.edge_index)
logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
logits_seed[C_idx] = logits_C
logits_full = appnp_seed_propagate(logits_seed,
data.edge_index,
K=args.K,
alpha=args.alpha)
cm = {
"train": accuracy(logits_full, data.y, data.train_mask),
"val": accuracy(logits_full, data.y, data.val_mask),
"test": accuracy(logits_full, data.y, data.test_mask),
}
            core_train.append(cm["train"])
            core_val.append(cm["val"])
            core_test.append(cm["test"])
            core_time.append(t3 - t2)
            speedups.append((t1 - t0) / (t3 - t2 + 1e-9))
# Optional warm finetune per run
if args.warm_ft_epochs > 0:
warm = GCN2(in_dim=ds.num_node_features,
hid=args.hidden,
out_dim=ds.num_classes,
dropout=args.dropout)
warm.load_state_dict(mC.state_dict())
t4 = time.perf_counter()
train(warm, data,
epochs=args.warm_ft_epochs,
lr=args.warm_ft_lr,
wd=args.wd,
patience=args.warm_ft_epochs + 1)
t5 = time.perf_counter()
wm = eval_all(warm, data)
                warm_train.append(wm["train"])
                warm_val.append(wm["val"])
                warm_test.append(wm["test"])
                warm_time.append(t5 - t4)
                warm_total_time.append((t3 - t2) + (t5 - t4))
# Helper to format mean ± std
def fmt(values, prec=4):
if not values:
return "n/a"
if len(values) == 1:
return f"{values[0]:.{prec}f}"
try:
return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}"
except Exception:
m = sum(values) / len(values)
var = sum((v - m) ** 2 for v in values) / max(1, len(values) - 1)
return f"{m:.{prec}f} ± {var ** 0.5:.{prec}f}"
def stats(values):
"""Return dict with list, mean, std, count for JSON."""
d = {
"values": [float(v) for v in values],
"count": int(len(values)),
}
if len(values) >= 1:
d["mean"] = float(mean(values))
if len(values) >= 2:
d["std"] = float(stdev(values))
else:
d["std"] = None
return d
# Baseline summary
console.print("\n[bold]Baseline (averaged over runs):[/bold]")
base_table = Table(show_header=True, header_style="bold magenta")
base_table.add_column("Metric", style="cyan")
base_table.add_column("Mean ± Std", style="magenta")
base_table.add_row("Train Accuracy", fmt(base_train))
base_table.add_row("Validation Accuracy", fmt(base_val))
base_table.add_row("Test Accuracy", fmt(base_test))
base_table.add_row("Time (s)", fmt(base_time, prec=2))
console.print(base_table)
# Save baseline multi-run summary
results["multi_run"] = {
"runs": int(runs),
"baseline": {
"train": stats(base_train),
"val": stats(base_val),
"test": stats(base_test),
"time_s": stats(base_time),
}
}
if data_C is None:
console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]")
maybe_save_results()
return
# Core summary
console.print("\n[bold]LRMC‑core (averaged over runs):[/bold]")
core_table = Table(show_header=True, header_style="bold magenta")
core_table.add_column("Metric", style="cyan")
core_table.add_column("Mean ± Std", style="magenta")
core_table.add_row("Train Accuracy", fmt(core_train))
core_table.add_row("Validation Accuracy", fmt(core_val))
core_table.add_row("Test Accuracy", fmt(core_test))
core_table.add_row("Core Training Time (s)", fmt(core_time, prec=2))
core_table.add_row("Speedup vs. Baseline", fmt(speedups, prec=2))
console.print(core_table)
# Save core multi-run summary
results["multi_run"]["core_model"] = {
"mode": str(args.core_mode),
"train": stats(core_train),
"val": stats(core_val),
"test": stats(core_test),
"core_train_time_s": stats(core_time),
"speedup_vs_baseline": stats(speedups),
}
# Comparison summary
console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]")
comparison_table = Table(title="Performance Comparison (Mean ± Std)", show_header=True, header_style="bold magenta")
comparison_table.add_column("Metric", style="cyan")
comparison_table.add_column("Baseline", style="magenta")
comparison_table.add_column("L-RMC-core", style="green")
comparison_table.add_column("Speedup", style="yellow")
for metric, b_vals, c_vals in [
("Train Accuracy", base_train, core_train),
("Validation Accuracy", base_val, core_val),
("Test Accuracy", base_test, core_test),
]:
comparison_table.add_row(metric, fmt(b_vals), fmt(c_vals), "")
comparison_table.add_row("Training Time (s)", fmt(base_time, prec=2), fmt(core_time, prec=2), fmt(speedups, prec=2))
comparison_table.add_row("Speedup", "1x", fmt(speedups, prec=2), "")
console.print(comparison_table)
# Optional warm summary
if args.warm_ft_epochs > 0 and warm_time:
console.print("\n[bold]Warm‑start finetune (averaged over runs):[/bold]")
warm_table = Table(show_header=True, header_style="bold magenta")
warm_table.add_column("Metric", style="cyan")
warm_table.add_column("Mean ± Std", style="magenta")
warm_table.add_row("Train Accuracy", fmt(warm_train))
warm_table.add_row("Validation Accuracy", fmt(warm_val))
warm_table.add_row("Test Accuracy", fmt(warm_test))
warm_table.add_row("Finetune Time (s)", fmt(warm_time, prec=2))
warm_table.add_row("Total (core train + warm)", fmt(warm_total_time, prec=2))
console.print(warm_table)
# Save warm multi-run summary
results["multi_run"]["warm_finetune"] = {
"train": stats(warm_train),
"val": stats(warm_val),
"test": stats(warm_test),
"finetune_time_s": stats(warm_time),
"total_time_s": stats(warm_total_time),
}
# Emit results for multi-run
maybe_save_results()
if __name__ == "__main__":
main()