"""
Train a GCN on an L-RMC subgraph and compare it to a full-graph baseline.

Modes:
- core_mode=forward : Train on the core subgraph, then run a forward pass on the
  full graph (the default).
- core_mode=appnp   : Train on the core subgraph, then seed logits on the core and
  APPNP-propagate them over the full graph.

Extras:
- --expand_core_with_train : Ensure all training labels lie inside the core
  (C' = C ∪ train_idx) for a fair train-time comparison.
- --warm_ft_epochs N : Optional short finetune on the full graph starting
  from the core model's weights (measures time-to-target).

It prints:
- Dataset stats
- Core size and coverage of train/val/test inside the core
- Train/Val/Test accuracy for the baseline and the core model
- Wall-clock times
"""

import argparse
import json
import time
import random
from statistics import mean, stdev
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, APPNP
from torch_geometric.utils import subgraph

from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

console = Console()

def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor:
    """
    Expected seeds_json format:
        {"clusters": [{"seed_nodes": [...], "score": float, ...}, ...]}
    Picks the cluster with the highest (score, size) and returns a boolean core mask.

    Seed node IDs in the JSON are assumed to be 1-indexed.
    """
    obj = json.loads(Path(seeds_json).read_text())
    clusters = obj.get("clusters", [])
    if not clusters:
        return torch.zeros(n_nodes, dtype=torch.bool)
    best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", []))))
    ids = best.get("seed_nodes", [])
    ids = [int(x) - 1 for x in ids]  # shift 1-indexed IDs to 0-indexed
    ids = sorted({i for i in ids if 0 <= i < n_nodes})
    mask = torch.zeros(n_nodes, dtype=torch.bool)
    if ids:
        mask[torch.tensor(ids, dtype=torch.long)] = True
    return mask
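
# A minimal example of the expected seeds JSON (values are illustrative):
#   {"clusters": [{"seed_nodes": [1, 4, 7], "score": 0.92},
#                 {"seed_nodes": [2, 3], "score": 0.85}]}
# Here the first cluster wins on score, and the returned mask marks nodes
# 0, 3, and 6 after the 1-indexed -> 0-indexed shift.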


def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor,
                    val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]:
    return {
        "core_size": int(core_mask.sum().item()),
        "train_in_core": int((core_mask & train_mask).sum().item()),
        "val_in_core": int((core_mask & val_mask).sum().item()),
        "test_in_core": int((core_mask & test_mask).sum().item()),
    }


def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    pred = logits[mask].argmax(dim=1)
    return (pred == y[mask]).float().mean().item()


def set_seed(seed: int):
    """Set random seeds for reproducibility across runs."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except Exception:
        pass
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class GCN2(nn.Module):
    def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.c1 = GCNConv(in_dim, hid)
        self.c2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, ei):
        x = self.c1(x, ei)
        x = torch.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.c2(x, ei)
        return x
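
# Note: GCNConv applies the symmetrically normalized propagation of Kipf & Welling
# (self-loops added by default), X' = D^{-1/2} (A + I) D^{-1/2} X W, so this
# 2-layer model aggregates features from up to 2 hops around each node.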


@torch.no_grad()
def eval_all(model: nn.Module, data) -> Dict[str, float]:
    model.eval()
    logits = model(data.x, data.edge_index)
    return {
        "train": accuracy(logits, data.y, data.train_mask),
        "val": accuracy(logits, data.y, data.val_mask),
        "test": accuracy(logits, data.y, data.test_mask),
    }


def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best, best_state, bad = -1.0, None, 0

    with Progress(
        SpinnerColumn(),
        "[progress.description]{task.description}",
        TimeElapsedColumn(),
        transient=True,
    ) as progress:
        task = progress.add_task("Training", total=epochs)

        for ep in range(1, epochs + 1):
            model.train()
            opt.zero_grad(set_to_none=True)
            out = model(data.x, data.edge_index)
            loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            opt.step()

            model.eval()  # disable dropout for the validation pass
            with torch.no_grad():
                val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask)

            if val > best:
                best, bad = val, 0
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                bad += 1
                if bad >= patience:
                    break

            progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()


def subset_data(data, nodes_idx: torch.Tensor):
    """
    Build the induced subgraph on 'nodes_idx', with x, y, and masks restricted to that set.
    Returns a shallow copy with edge_index, features, labels, and masks sliced.
    """
    nodes_idx = nodes_idx.to(torch.long)
    sub_ei, _ = subgraph(nodes_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
    sub = type(data)()
    sub.x = data.x[nodes_idx]
    sub.y = data.y[nodes_idx]
    sub.train_mask = data.train_mask[nodes_idx]
    sub.val_mask = data.val_mask[nodes_idx]
    sub.test_mask = data.test_mask[nodes_idx]
    sub.edge_index = sub_ei
    sub.num_nodes = sub.x.size(0)
    return sub
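
# Note: with relabel_nodes=True, subgraph() renumbers the kept nodes to 0..|C|-1
# following the order of nodes_idx, so row i of sub.x is node nodes_idx[i] of the
# full graph. The APPNP seeding below relies on this correspondence via C_idx.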


def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor:
    """
    logits_seed is [N, C], with zero rows outside the core.
    Propagate these logits with APPNP to fill in the rest of the graph.
    """
    appnp = APPNP(K=K, alpha=alpha)
    return appnp(logits_seed, edge_index)
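
# APPNP runs personalized-PageRank-style propagation (Klicpera et al., 2019):
#   Z(0)   = logits_seed
#   Z(k+1) = (1 - alpha) * A_hat @ Z(k) + alpha * Z(0)
# where A_hat is the symmetrically normalized adjacency with self-loops. Core
# logits diffuse outward while alpha keeps them anchored to the seeded rows.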


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON")
    p.add_argument("--hidden", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.5)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--wd", type=float, default=5e-4)
    p.add_argument("--patience", type=int, default=100)
    p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward",
                   help="How to evaluate the core model on the full graph.")
    p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport probability (core_mode=appnp).")
    p.add_argument("--K", type=int, default=10, help="APPNP propagation steps (core_mode=appnp).")
    p.add_argument("--expand_core_with_train", action="store_true",
                   help="Expand the LRMC core with all training nodes (C' = C ∪ train_idx).")
    p.add_argument("--warm_ft_epochs", type=int, default=0,
                   help="If >0, run a short finetune on the FULL graph starting from the core model.")
    p.add_argument("--warm_ft_lr", type=float, default=0.005)
    p.add_argument("--runs", type=int, default=1,
                   help="Number of runs with different seeds to average results.")
    p.add_argument("-o", "--output_json", type=str, default=None,
                   help="If set, save all computed metrics and settings to this JSON file.")
    args = p.parse_args()

    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    n, e = data.num_nodes, data.edge_index.size(1) // 2  # undirected: each edge stored twice

    console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]")

    results = {
        "args": {
            k: (float(v) if isinstance(v, float) else v)
            for k, v in vars(args).items()
            if k != "output_json"
        },
        "dataset": {
            "name": args.dataset,
            "num_nodes": int(n),
            "num_edges": int(e),
        },
    }

    def maybe_save_results():
        """Write results to JSON if the user requested it."""
        if not args.output_json:
            return
        out_path = Path(args.output_json)
        try:
            out_path.parent.mkdir(parents=True, exist_ok=True)
        except Exception:
            pass
        with out_path.open("w") as f:
            json.dump(results, f, indent=2)

    core_mask = load_top1_assignment(args.seeds, n)
    if args.expand_core_with_train:
        core_mask = core_mask | data.train_mask

    C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1)
    frac = 100.0 * C_idx.numel() / n
    cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask)

    console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]")

    results["core"] = {
        "source": str(args.seeds),
        "expanded_with_train": bool(args.expand_core_with_train),
        "size": int(cov["core_size"]),
        "fraction": float(frac / 100.0),
        "coverage": {
            "train_in_core": int(cov["train_in_core"]),
            "val_in_core": int(cov["val_in_core"]),
            "test_in_core": int(cov["test_in_core"]),
        },
    }

    cov_table = Table(title="LRMC Core Coverage")
    cov_table.add_column("Metric", style="cyan")
    cov_table.add_column("Count", style="magenta")
    cov_table.add_row("Core Size", str(cov["core_size"]))
    cov_table.add_row("Train in Core", str(cov["train_in_core"]))
    cov_table.add_row("Val in Core", str(cov["val_in_core"]))
    cov_table.add_row("Test in Core", str(cov["test_in_core"]))
    console.print(cov_table)

    if args.runs == 1:
        set_seed(0)
        t0 = time.perf_counter()
        base = GCN2(in_dim=ds.num_node_features,
                    hid=args.hidden,
                    out_dim=ds.num_classes,
                    dropout=args.dropout)
        train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
        base_metrics = eval_all(base, data)
        t1 = time.perf_counter()

        console.print("\n[bold]Baseline (trained on full graph):[/bold]")
        base_table = Table(show_header=True, header_style="bold magenta")
        base_table.add_column("Metric", style="cyan")
        base_table.add_column("Value", style="magenta")
        base_table.add_row("Train Accuracy", f"{base_metrics['train']:.4f}")
        base_table.add_row("Validation Accuracy", f"{base_metrics['val']:.4f}")
        base_table.add_row("Test Accuracy", f"{base_metrics['test']:.4f}")
        base_table.add_row("Time (s)", f"{t1 - t0:.2f}")
        console.print(base_table)

        results["single_run"] = {
            "baseline": {
                "train": float(base_metrics["train"]),
                "val": float(base_metrics["val"]),
                "test": float(base_metrics["test"]),
                "time_s": float(t1 - t0),
            }
        }

        if C_idx.numel() == 0:
            console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]")
            results["core_empty"] = True
            maybe_save_results()
            return

        data_C = subset_data(data, C_idx)
        mC = GCN2(in_dim=ds.num_node_features,
                  hid=args.hidden,
                  out_dim=ds.num_classes,
                  dropout=args.dropout)

        t2 = time.perf_counter()
        train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
        t3 = time.perf_counter()

        mC.eval()
        if args.core_mode == "forward":
            # Run the core-trained weights directly on the full graph.
            with torch.no_grad():
                logits_full = mC(data.x, data.edge_index)
        else:
            # Seed logits on the core (zeros elsewhere), then APPNP-propagate.
            with torch.no_grad():
                logits_C = mC(data_C.x, data_C.edge_index)
                logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
                logits_seed[C_idx] = logits_C
                logits_full = appnp_seed_propagate(logits_seed,
                                                   data.edge_index,
                                                   K=args.K,
                                                   alpha=args.alpha)

        core_metrics = {
            "train": accuracy(logits_full, data.y, data.train_mask),
            "val": accuracy(logits_full, data.y, data.val_mask),
            "test": accuracy(logits_full, data.y, data.test_mask),
        }

        console.print("\n[bold]LRMC-core model (trained on core, evaluated on full graph):[/bold]")
        core_table = Table(show_header=True, header_style="bold magenta")
        core_table.add_column("Metric", style="cyan")
        core_table.add_column("Value", style="magenta")
        core_table.add_row("Train Accuracy", f"{core_metrics['train']:.4f}")
        core_table.add_row("Validation Accuracy", f"{core_metrics['val']:.4f}")
        core_table.add_row("Test Accuracy", f"{core_metrics['test']:.4f}")
        core_table.add_row("Core Training Time (s)", f"{t3 - t2:.2f}")
        speedup = (t1 - t0) / (t3 - t2 + 1e-9)
        core_table.add_row("Speedup vs. Baseline", f"{speedup:.2f}×")
        console.print(core_table)

        results["single_run"]["core_model"] = {
            "mode": str(args.core_mode),
            "train": float(core_metrics["train"]),
            "val": float(core_metrics["val"]),
            "test": float(core_metrics["test"]),
            "core_train_time_s": float(t3 - t2),
            "speedup_vs_baseline": float(speedup),
        }

        console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]")
        comparison_table = Table(title="Performance Comparison", show_header=True, header_style="bold magenta")
        comparison_table.add_column("Metric", style="cyan")
        comparison_table.add_column("Baseline", style="magenta")
        comparison_table.add_column("L-RMC-core", style="green")
        comparison_table.add_column("Speedup", style="yellow")

        for metric in ["train", "val", "test"]:
            comparison_table.add_row(
                f"{metric.capitalize()} Accuracy",
                f"{base_metrics[metric]:.4f}",
                f"{core_metrics[metric]:.4f}",
                ""
            )

        baseline_time = t1 - t0
        core_time = t3 - t2
        comparison_table.add_row(
            "Training Time (s)",
            f"{baseline_time:.2f}",
            f"{core_time:.2f}",
            f"{speedup:.2f}x"
        )
        comparison_table.add_row("Speedup", "1x", f"{speedup:.2f}x", "")
        console.print(comparison_table)

        if args.warm_ft_epochs > 0:
            warm = GCN2(in_dim=ds.num_node_features,
                        hid=args.hidden,
                        out_dim=ds.num_classes,
                        dropout=args.dropout)
            warm.load_state_dict(mC.state_dict())

            t4 = time.perf_counter()
            train(warm, data,
                  epochs=args.warm_ft_epochs,
                  lr=args.warm_ft_lr,
                  wd=args.wd,
                  patience=args.warm_ft_epochs + 1)
            t5 = time.perf_counter()
            warm_metrics = eval_all(warm, data)

            console.print("\n[bold]Warm-start finetune (start from core model, train on FULL graph):[/bold]")
            warm_table = Table(show_header=True, header_style="bold magenta")
            warm_table.add_column("Metric", style="cyan")
            warm_table.add_column("Value", style="magenta")
            warm_table.add_row("Train Accuracy", f"{warm_metrics['train']:.4f}")
            warm_table.add_row("Validation Accuracy", f"{warm_metrics['val']:.4f}")
            warm_table.add_row("Test Accuracy", f"{warm_metrics['test']:.4f}")
            warm_table.add_row("Finetune Time (s)", f"{t5 - t4:.2f}")
            warm_table.add_row("Total (core train + warm)", f"{(t3 - t2 + t5 - t4):.2f}s")
            console.print(warm_table)

            results["single_run"]["warm_finetune"] = {
                "train": float(warm_metrics["train"]),
                "val": float(warm_metrics["val"]),
                "test": float(warm_metrics["test"]),
                "finetune_time_s": float(t5 - t4),
                "total_time_s": float((t3 - t2) + (t5 - t4)),
            }

        maybe_save_results()
    else:
        runs = args.runs
        console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]")

        base_train, base_val, base_test, base_time = [], [], [], []
        core_train, core_val, core_test, core_time = [], [], [], []
        speedups = []
        warm_train, warm_val, warm_test, warm_time, warm_total_time = [], [], [], [], []

        data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None
        results["core_empty"] = data_C is None

        for r in range(runs):
            set_seed(r)

            t0 = time.perf_counter()
            base = GCN2(in_dim=ds.num_node_features,
                        hid=args.hidden,
                        out_dim=ds.num_classes,
                        dropout=args.dropout)
            train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
            bm = eval_all(base, data)
            t1 = time.perf_counter()

            base_train.append(bm["train"])
            base_val.append(bm["val"])
            base_test.append(bm["test"])
            base_time.append(t1 - t0)

            if data_C is None:
                continue

            t2 = time.perf_counter()
            mC = GCN2(in_dim=ds.num_node_features,
                      hid=args.hidden,
                      out_dim=ds.num_classes,
                      dropout=args.dropout)
            train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
            t3 = time.perf_counter()

            mC.eval()
            if args.core_mode == "forward":
                with torch.no_grad():
                    logits_full = mC(data.x, data.edge_index)
            else:
                with torch.no_grad():
                    logits_C = mC(data_C.x, data_C.edge_index)
                    logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
                    logits_seed[C_idx] = logits_C
                    logits_full = appnp_seed_propagate(logits_seed,
                                                       data.edge_index,
                                                       K=args.K,
                                                       alpha=args.alpha)

            cm = {
                "train": accuracy(logits_full, data.y, data.train_mask),
                "val": accuracy(logits_full, data.y, data.val_mask),
                "test": accuracy(logits_full, data.y, data.test_mask),
            }

            core_train.append(cm["train"])
            core_val.append(cm["val"])
            core_test.append(cm["test"])
            core_time.append(t3 - t2)
            speedups.append((t1 - t0) / (t3 - t2 + 1e-9))

            if args.warm_ft_epochs > 0:
                warm = GCN2(in_dim=ds.num_node_features,
                            hid=args.hidden,
                            out_dim=ds.num_classes,
                            dropout=args.dropout)
                warm.load_state_dict(mC.state_dict())

                t4 = time.perf_counter()
                train(warm, data,
                      epochs=args.warm_ft_epochs,
                      lr=args.warm_ft_lr,
                      wd=args.wd,
                      patience=args.warm_ft_epochs + 1)
                t5 = time.perf_counter()
                wm = eval_all(warm, data)
                warm_train.append(wm["train"])
                warm_val.append(wm["val"])
                warm_test.append(wm["test"])
                warm_time.append(t5 - t4)
                warm_total_time.append((t3 - t2) + (t5 - t4))

        def fmt(values, prec=4):
            if not values:
                return "n/a"
            if len(values) == 1:
                return f"{values[0]:.{prec}f}"
            try:
                return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}"
            except Exception:
                m = sum(values) / len(values)
                var = sum((v - m) ** 2 for v in values) / max(1, len(values) - 1)
                return f"{m:.{prec}f} ± {var ** 0.5:.{prec}f}"

        def stats(values):
            """Return a dict with values, mean, std, and count for JSON output."""
            d = {
                "values": [float(v) for v in values],
                "count": int(len(values)),
            }
            if len(values) >= 1:
                d["mean"] = float(mean(values))
            if len(values) >= 2:
                d["std"] = float(stdev(values))
            else:
                d["std"] = None
            return d

        console.print("\n[bold]Baseline (averaged over runs):[/bold]")
        base_table = Table(show_header=True, header_style="bold magenta")
        base_table.add_column("Metric", style="cyan")
        base_table.add_column("Mean ± Std", style="magenta")
        base_table.add_row("Train Accuracy", fmt(base_train))
        base_table.add_row("Validation Accuracy", fmt(base_val))
        base_table.add_row("Test Accuracy", fmt(base_test))
        base_table.add_row("Time (s)", fmt(base_time, prec=2))
        console.print(base_table)

        results["multi_run"] = {
            "runs": int(runs),
            "baseline": {
                "train": stats(base_train),
                "val": stats(base_val),
                "test": stats(base_test),
                "time_s": stats(base_time),
            }
        }

        if data_C is None:
            console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]")
            maybe_save_results()
            return

        console.print("\n[bold]LRMC-core (averaged over runs):[/bold]")
        core_table = Table(show_header=True, header_style="bold magenta")
        core_table.add_column("Metric", style="cyan")
        core_table.add_column("Mean ± Std", style="magenta")
        core_table.add_row("Train Accuracy", fmt(core_train))
        core_table.add_row("Validation Accuracy", fmt(core_val))
        core_table.add_row("Test Accuracy", fmt(core_test))
        core_table.add_row("Core Training Time (s)", fmt(core_time, prec=2))
        core_table.add_row("Speedup vs. Baseline", fmt(speedups, prec=2))
        console.print(core_table)

        results["multi_run"]["core_model"] = {
            "mode": str(args.core_mode),
            "train": stats(core_train),
            "val": stats(core_val),
            "test": stats(core_test),
            "core_train_time_s": stats(core_time),
            "speedup_vs_baseline": stats(speedups),
        }

        console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]")
        comparison_table = Table(title="Performance Comparison (Mean ± Std)", show_header=True, header_style="bold magenta")
        comparison_table.add_column("Metric", style="cyan")
        comparison_table.add_column("Baseline", style="magenta")
        comparison_table.add_column("L-RMC-core", style="green")
        comparison_table.add_column("Speedup", style="yellow")

        for metric, b_vals, c_vals in [
            ("Train Accuracy", base_train, core_train),
            ("Validation Accuracy", base_val, core_val),
            ("Test Accuracy", base_test, core_test),
        ]:
            comparison_table.add_row(metric, fmt(b_vals), fmt(c_vals), "")

        comparison_table.add_row("Training Time (s)", fmt(base_time, prec=2), fmt(core_time, prec=2), fmt(speedups, prec=2))
        comparison_table.add_row("Speedup", "1x", fmt(speedups, prec=2), "")
        console.print(comparison_table)

        if args.warm_ft_epochs > 0 and warm_time:
            console.print("\n[bold]Warm-start finetune (averaged over runs):[/bold]")
            warm_table = Table(show_header=True, header_style="bold magenta")
            warm_table.add_column("Metric", style="cyan")
            warm_table.add_column("Mean ± Std", style="magenta")
            warm_table.add_row("Train Accuracy", fmt(warm_train))
            warm_table.add_row("Validation Accuracy", fmt(warm_val))
            warm_table.add_row("Test Accuracy", fmt(warm_test))
            warm_table.add_row("Finetune Time (s)", fmt(warm_time, prec=2))
            warm_table.add_row("Total (core train + warm)", fmt(warm_total_time, prec=2))
            console.print(warm_table)

            results["multi_run"]["warm_finetune"] = {
                "train": stats(warm_train),
                "val": stats(warm_val),
                "test": stats(warm_test),
                "finetune_time_s": stats(warm_time),
                "total_time_s": stats(warm_total_time),
            }

        maybe_save_results()


if __name__ == "__main__":
    main()