# Source: clique/src/2.5_lrmc_bilevel.py
# Hugging Face Hub page residue kept for provenance: "qingy2024's picture",
# "Upload folder using huggingface_hub", commit f74dd01 (verified).
"""
L-RMC Anchored GCN vs. Plain GCN (dynamic robustness evaluation)
==============================================================
This script trains a baseline two‑layer GCN and a new **anchor‑gated** GCN on
Planetoid citation networks (Cora/Citeseer/Pubmed). The anchor‑gated GCN uses
the top‑1 L‑RMC cluster (loaded from a provided JSON file) as a *decentralized
core*. During message passing it blends standard neighborhood aggregation
(`h_base`) with aggregation restricted to the core (`h_core`) via a per‑node
gating network. Cross‑boundary edges are optionally down‑weighted by a
damping factor `γ`.
After training on the static graph, the script evaluates *robustness over
time*. Starting from the original adjacency, it repeatedly performs random
edge rewires (removes a fraction of existing edges and adds the same number
of random new edges) and measures test accuracy at each step **without
retraining**. The area under the accuracy–time curve (AUC‑AT) is reported
for both the baseline and the anchored model. A higher AUC‑AT indicates
longer resilience to graph churn.
Usage examples::
# Train only baseline and report dynamic AUC
python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant baseline
# Train baseline and anchor models, evaluate AUC‑over‑time on 30 steps with 5% rewiring
python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant anchor \
--dynamic_steps 30 --flip_fraction 0.05 --gamma 0.8
Notes
-----
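* The seeds JSON must contain an entry ``"clusters"`` with a list of clusters, each
providing ``"seed_nodes"`` (1-indexed node IDs) and, optionally, a ``"score"``
(treated as 0.0 when missing); the cluster with maximum (score, size) is chosen
as the core.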
* For fairness, both models are trained on the identical training mask and
evaluated on the same dynamic perturbations.
* Random rewiring is undirected: an edge (u,v) is treated as the same as (v,u).
* Cross‑boundary damping and the gating network use only structural
information; features are left unchanged during perturbations.
"""
import argparse
import json
import random
from pathlib import Path
from typing import Tuple, List, Optional, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from rich import print
# -----------------------------------------------------------------------------
# Utilities for loading LRMC core assignment
# -----------------------------------------------------------------------------
def _pick_top1_cluster(obj: dict) -> List[int]:
    """
    Select the best cluster from an LRMC seeds object and return its members.

    The expected structure is
    ``{"clusters": [{"seed_nodes": [...], "score": float, ...}, ...]}``.
    Clusters are ranked by ``(score, number of seed nodes)``; a missing score
    counts as 0.0 and missing seed nodes as an empty list. On a tie the
    earliest cluster in the list wins (``max`` keeps the first maximum).

    Returns
    -------
    List[int]
        The winning cluster's ``seed_nodes`` shifted down by one (1-indexed
        IDs become 0-indexed). Empty when there are no clusters.
    """
    def _rank(cluster: dict) -> Tuple[float, int]:
        # Primary key: score; secondary key: cluster size.
        return float(cluster.get("score", 0.0)), len(cluster.get("seed_nodes", []))

    candidates = obj.get("clusters", [])
    if candidates:
        winner = max(candidates, key=_rank)
        return [node_id - 1 for node_id in winner.get("seed_nodes", [])]
    return []
def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor]:
    """
    Load the top-1 LRMC cluster from a seeds JSON file as a core assignment.

    Nodes outside the returned core form the periphery. If the JSON has no
    clusters, the core is empty.

    Parameters
    ----------
    seeds_json : str
        Path to the LRMC seeds JSON file (read as UTF-8).
    n_nodes : int
        Total number of nodes in the graph; sets the length of ``core_mask``.

    Returns
    -------
    core_mask : Tensor
        Bool tensor of shape [n_nodes]; True marks membership in the top-1
        LRMC cluster.
    core_nodes : Tensor
        Long tensor holding the sorted, de-duplicated core node indices.

    Raises
    ------
    IndexError
        If any core index lies outside ``[0, n_nodes)``. This catches a seed
        ID of 0 in the JSON (which becomes -1 after the 1-to-0 index shift and
        would otherwise silently mark the *last* node as core).
    """
    obj = json.loads(Path(seeds_json).read_text(encoding="utf-8"))
    members = sorted(set(_pick_top1_cluster(obj)))

    # The list is sorted, so checking both ends covers every element.
    if members and (members[0] < 0 or members[-1] >= n_nodes):
        offender = members[0] if members[0] < 0 else members[-1]
        raise IndexError(
            f"LRMC core node index {offender} is out of range for a graph with "
            f"{n_nodes} nodes (seed_nodes are expected to be 1-indexed)."
        )

    core_nodes = torch.tensor(members, dtype=torch.long)
    core_mask = torch.zeros(n_nodes, dtype=torch.bool)
    if core_nodes.numel() > 0:
        core_mask[core_nodes] = True
    return core_mask, core_nodes
# -----------------------------------------------------------------------------
# Baseline GCN: standard two‑layer GCN
# -----------------------------------------------------------------------------
class GCN2(nn.Module):
    """
    Plain two-layer GCN used as the baseline.

    Parameters
    ----------
    in_dim : int
        Dimensionality of input node features.
    hid_dim : int
        Dimensionality of the hidden layer.
    out_dim : int
        Number of output classes.
    dropout : float, optional
        Dropout probability applied after the first layer. Default: 0.5.
    """

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
        """Return class logits: conv1 -> ReLU -> dropout -> conv2."""
        # Both convolutions are built with GCNConv's default settings and
        # receive the same (optional) edge weights.
        hidden = self.conv1(x, edge_index, edge_weight)
        hidden = F.dropout(F.relu(hidden), p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index, edge_weight)
# -----------------------------------------------------------------------------
# Anchor‑gated GCN
# -----------------------------------------------------------------------------
class AnchorGCN(nn.Module):
    """
    A two-layer GCN that injects a core-restricted aggregation channel and
    down-weights edges crossing the core boundary. After the first GCN layer
    computes base features, a gating network mixes them with features
    aggregated only over core-to-core edges.

    Parameters
    ----------
    in_dim : int
        Dimensionality of input node features.
    hid_dim : int
        Dimensionality of hidden layer.
    out_dim : int
        Number of output classes.
    core_mask : Tensor[bool]
        Boolean mask of shape [N] indicating which nodes belong to the L-RMC
        core. A detached copy is stored as a non-persistent buffer, so it
        moves with the module on ``.to(device)`` but is not part of
        ``state_dict``.
    gamma : float, optional
        Damping factor for edges that connect core and non-core nodes.
        Values <1.0 reduce the influence of boundary edges. Values >=1.0
        disable damping (all weights are 1). Default is 1.0.
    dropout : float, optional
        Dropout probability applied after the first layer.
    """

    core_mask: Tensor

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 core_mask: Tensor,
                 gamma: float = 1.0,
                 dropout: float = 0.5):
        super().__init__()
        # Buffer instead of a plain attribute: the mask is indexed with edge
        # indices, so it must live on the same device as the inputs.
        # persistent=False keeps state_dict keys identical to earlier
        # checkpoints of this model.
        self.register_buffer("core_mask", core_mask.clone().detach(), persistent=False)
        self.gamma = float(gamma)
        self.dropout = dropout
        # Base and core convolutions for the first layer.
        # Base conv uses self loops; core conv disables self loops to avoid
        # spurious core contributions on non-core nodes.
        self.base1 = GCNConv(in_dim, hid_dim, add_self_loops=True)
        self.core1 = GCNConv(in_dim, hid_dim, add_self_loops=False)
        # Second layer: standard GCN on mixed features.
        self.conv2 = GCNConv(hid_dim, out_dim)
        # Gating network: maps 3 structural features to alpha in [0, 1].
        self.gate = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def _compute_edge_weights(self, edge_index: Tensor) -> Tensor:
        """
        Return one float32 weight per column of ``edge_index``.

        Edges with exactly one endpoint in the core get weight ``gamma``; all
        other edges (including self loops, whose endpoints coincide) get 1.
        When ``gamma >= 1.0`` no damping is applied and all weights are 1.
        """
        weights = torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
        if self.gamma >= 1.0:
            return weights
        src, dst = edge_index[0], edge_index[1]
        # XOR is True only when exactly one endpoint is a core node.
        cross = self.core_mask[src] ^ self.core_mask[dst]
        weights[cross] *= self.gamma
        return weights

    def _compute_structural_features(self, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute the structural features consumed by the gating network.

        Returns a tuple of three float tensors of shape [N, 1]:

        * ``in_core`` - 1 if the node is in the core, else 0
        * ``frac_core_nbrs`` - fraction of the node's outgoing edges whose
          destination is a core node
        * ``is_boundary`` - 1 if the node has both core and non-core neighbors

        Nodes with zero degree get ``frac_core_nbrs = 0`` and
        ``is_boundary = 0``.
        """
        n_nodes = self.core_mask.size(0)
        device = edge_index.device
        src = edge_index[0]
        dst = edge_index[1]

        deg = torch.zeros(n_nodes, dtype=torch.float32, device=device)
        core_deg = torch.zeros(n_nodes, dtype=torch.float32, device=device)
        # Count contributions of directed edges; duplicates will double-count
        # but the ratio remains stable if the graph is symmetric.
        deg.index_add_(0, src, torch.ones_like(src, dtype=torch.float32))
        # Core neighbors: credit the source only when the destination is core.
        core_deg.index_add_(0, src, self.core_mask[dst].float())

        # Divide only where the degree is non-zero to avoid 0/0.
        frac_core = torch.zeros(n_nodes, dtype=torch.float32, device=device)
        nonzero = deg > 0
        frac_core[nonzero] = core_deg[nonzero] / deg[nonzero]

        # Boundary: at least one core neighbor AND at least one non-core one.
        has_core = core_deg > 0
        has_non_core = (deg - core_deg) > 0
        is_boundary = (has_core & has_non_core).float()

        in_core = self.core_mask.float()
        return in_core.view(-1, 1), frac_core.view(-1, 1), is_boundary.view(-1, 1)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """
        Return class logits for all nodes given features ``x`` and the current
        ``edge_index``. Edge weights and gating features are recomputed from
        ``edge_index`` on every call, so the graph may change between calls.
        """
        # Damped edge weights, shared by the base channel and the second layer.
        w = self._compute_edge_weights(edge_index)

        # First layer, base channel: standard neighbors with self loops.
        h_base = F.relu(self.base1(x, edge_index, w))

        # First layer, core channel: keep only edges with both endpoints in
        # the core (no self loops).
        src, dst = edge_index
        mask_core_edges = self.core_mask[src] & self.core_mask[dst]
        ei_core = edge_index[:, mask_core_edges]
        if ei_core.numel() == 0:
            # No core edges at all: the core channel contributes zeros.
            h_core = torch.zeros_like(h_base)
        else:
            # NOTE(review): nodes without any core-core edge presumably get
            # only the layer bias here (not zeros as in the branch above) -
            # confirm this asymmetry is intended.
            h_core = F.relu(self.core1(x, ei_core))

        # Structural features -> per-node mixing coefficient alpha, shape [N].
        in_core, frac_core, is_boundary = self._compute_structural_features(edge_index)
        feats = torch.cat([in_core, frac_core, is_boundary], dim=1)
        alpha = self.gate(feats).view(-1)
        # Force alpha=0 for nodes with no core neighbors to avoid modifying
        # true periphery. frac_core == 0 means zero core neighbors.
        no_core_neighbors = (frac_core.view(-1) == 0)
        alpha = torch.where(no_core_neighbors, torch.zeros_like(alpha), alpha)

        # h1 = h_base + alpha * (h_core - h_base)
        #    = (1 - alpha) * h_base + alpha * h_core
        h1 = h_base + alpha.unsqueeze(1) * (h_core - h_base)
        h1 = F.dropout(h1, p=self.dropout, training=self.training)

        # Second layer: standard GCN with the same damping weights.
        return self.conv2(h1, edge_index, w)
# The deg_for_division helper is no longer used but left here for completeness.
def deg_for_division(edge_index: Tensor, num_nodes: int) -> Tensor:
    """
    Count, for every node, how many columns of ``edge_index`` list it as the
    source (row 0).

    Returns a float32 tensor of shape [num_nodes] on the device of
    ``edge_index``; nodes that never appear as a source get 0.
    """
    sources = edge_index[0]
    unit = torch.ones_like(sources, dtype=torch.float32)
    counts = torch.zeros(num_nodes, dtype=torch.float32, device=edge_index.device)
    # index_add_ accumulates in place and returns the same tensor.
    return counts.index_add_(0, sources, unit)
# -----------------------------------------------------------------------------
# Training and evaluation routines
# -----------------------------------------------------------------------------
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """
    Return the fraction of masked nodes whose arg-max prediction equals the
    label, as a Python float.

    ``mask`` selects the rows of ``logits`` and the entries of ``y`` to score.
    NOTE(review): an all-False mask presumably yields NaN (mean of an empty
    tensor) - confirm callers never pass one.
    """
    predicted = logits[mask].argmax(dim=1)
    hits = predicted.eq(y[mask])
    return hits.float().mean().item()
def train_model(model: nn.Module,
                data,
                epochs: int = 200,
                lr: float = 0.01,
                weight_decay: float = 5e-4) -> None:
    """
    Train ``model`` full-batch with Adam and keep the best-validation weights.

    Works for both the baseline and the anchor model: the model is called as
    ``model(data.x, data.edge_index)``. After every optimizer step the
    validation accuracy is measured; whenever it strictly exceeds the best
    value so far (starting from 0.0) a copy of the state dict is stored. The
    best copy, if any, is loaded back at the end and the model is left in
    eval mode.

    Parameters
    ----------
    model : nn.Module
        Model to train in place.
    data
        Object exposing ``x``, ``edge_index``, ``y``, ``train_mask`` and
        ``val_mask``.
    epochs : int, optional
        Number of optimisation steps (one per epoch). Default: 200.
    lr : float, optional
        Adam learning rate. Default: 0.01.
    weight_decay : float, optional
        Adam weight decay. Default: 5e-4.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val = 0.0
    best_state = None
    for _ in range(epochs):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        # Validation pass: no gradients are needed, so skip building the
        # autograd graph.
        model.eval()
        with torch.no_grad():
            logits_val = model(data.x, data.edge_index)
        val_acc = accuracy(logits_val, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            # Clone, otherwise later optimizer steps would overwrite the
            # snapshot in place.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
@torch.no_grad()
def evaluate_model(model: nn.Module, data) -> dict:
    """
    Evaluate a trained model on the train, val, and test masks.

    The model is switched to eval mode and run once on
    ``(data.x, data.edge_index)`` with gradient tracking disabled.

    Returns
    -------
    dict
        Accuracies under the keys ``"train"``, ``"val"`` and ``"test"``.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    return {
        "train": accuracy(logits, data.y, data.train_mask),
        "val": accuracy(logits, data.y, data.val_mask),
        "test": accuracy(logits, data.y, data.test_mask),
    }
# -----------------------------------------------------------------------------
# Dynamic graph perturbation utilities
# -----------------------------------------------------------------------------
def undirected_edge_set(edge_index: Tensor) -> Set[Tuple[int, int]]:
    """
    Collapse a directed edge index into a set of undirected edges.

    Each column ``(u, v)`` of ``edge_index`` becomes the tuple
    ``(min(u, v), max(u, v))``, so both directions of an edge map to the same
    entry. Self loops (``u == v``) are dropped.
    """
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()
    return {
        (min(u, v), max(u, v))
        for u, v in zip(sources, targets)
        if u != v
    }
def edge_set_to_index(edges: Set[Tuple[int, int]], num_nodes: int) -> Tensor:
    """
    Expand an undirected edge set into a directed ``edge_index`` tensor.

    Every undirected edge ``(u, v)`` contributes two adjacent columns,
    ``(u, v)`` followed by ``(v, u)``, giving a long tensor of shape
    [2, 2*|edges|]. An empty set yields an empty [2, 0] long tensor. No self
    loops are added here. ``num_nodes`` is accepted for interface symmetry
    but is not used.
    """
    if not edges:
        return torch.empty(2, 0, dtype=torch.long)

    ordered = list(edges)
    # Interleave both directions: row 0 is u, v, u', v', ...; row 1 mirrors it.
    row_src = [node for u, v in ordered for node in (u, v)]
    row_dst = [node for u, v in ordered for node in (v, u)]
    return torch.tensor([row_src, row_dst], dtype=torch.long)
def random_rewire(edges: Set[Tuple[int, int]], num_nodes: int, n_changes: int, rng: random.Random) -> Set[Tuple[int, int]]:
    """
    Return a rewired copy of an undirected edge set; the input is not mutated.

    First ``n_changes`` existing edges (capped at the number of edges
    available) are sampled without replacement and removed. Then up to the
    same number of new edges are inserted: endpoints are drawn uniformly from
    ``range(num_nodes)``, stored as ``(min, max)``, and a draw is rejected if
    it is a self loop or the edge is already present. At most
    ``10 * n_changes`` draws are made, so fewer edges than requested may be
    added on a dense graph.
    """
    pool = set(edges)  # work on a copy
    n_changes = min(n_changes, len(pool))

    # Removal phase: the sample is fully drawn before anything is removed.
    for victim in rng.sample(list(pool), n_changes):
        pool.remove(victim)

    # Insertion phase: every draw, accepted or rejected, uses up one attempt.
    inserted = 0
    for _ in range(n_changes * 10):
        if inserted >= n_changes:
            break
        u = rng.randrange(num_nodes)
        v = rng.randrange(num_nodes)
        if u == v:
            continue
        candidate = (u, v) if u < v else (v, u)
        if candidate not in pool:
            pool.add(candidate)
            inserted += 1
    return pool
def auc_over_time(acc_list: List[float]) -> float:
    """
    Compute the normalized area under an accuracy-time curve (trapezoidal rule).

    ``acc_list`` holds the accuracies at t=0,1,...,T. The area is divided by
    T so that a constant accuracy of 1.0 yields AUC=1.0.

    Edge cases
    ----------
    * Empty list: returns 0.0.
    * Single point (T=0, e.g. zero rewiring steps): there is no interval to
      integrate over, so the single accuracy itself is returned instead of
      dividing by zero.
    """
    n_points = len(acc_list)
    if n_points == 0:
        return 0.0
    if n_points == 1:
        return float(acc_list[0])

    area = 0.0
    # Each adjacent pair of points spans a unit-width trapezoid.
    for prev, cur in zip(acc_list, acc_list[1:]):
        area += (cur + prev) / 2.0
    return area / (n_points - 1)
@torch.no_grad()
def evaluate_dynamic_auc(model: nn.Module,
                         data,
                         core_mask: Tensor,
                         steps: int = 30,
                         flip_fraction: float = 0.05,
                         rng_seed: int = 1234) -> List[float]:
    """
    Evaluate a model's test accuracy over a sequence of random edge rewiring steps.

    The model is put in eval mode and never updated; all forward passes run
    with gradient tracking disabled.

    Parameters
    ----------
    model : nn.Module
        A trained model that accepts (x, edge_index) and returns logits.
    data : Data
        PyG data object with attributes x, y, test_mask, num_nodes.
        ``data.edge_index`` provides the initial adjacency.
    core_mask : Tensor[bool]
        Boolean mask indicating core nodes. It is not read by this function;
        it is kept so both variants share one call signature.
    steps : int, optional
        Number of rewiring steps to perform. The accuracy at t=0 is computed
        before any rewiring. Default: 30.
    flip_fraction : float, optional
        Fraction of the *initial* undirected edge count to remove/add at each
        step (at least one edge per step). Default: 0.05.
    rng_seed : int, optional
        Seed of the private ``random.Random`` that drives the rewiring, so
        two calls with the same seed see the same perturbations.
        Default: 1234.

    Returns
    -------
    List[float]
        A list of length ``steps+1`` containing the test accuracy at each
        iteration (including t=0).
    """
    base_edges = undirected_edge_set(data.edge_index)
    # The number of changes per step is fixed from the initial edge count.
    n_changes = max(1, int(flip_fraction * len(base_edges)))

    model.eval()
    rng = random.Random(rng_seed)
    device = data.x.device

    def _test_accuracy(edges: Set[Tuple[int, int]]) -> float:
        # Rebuild a symmetric edge_index on the same device as the features.
        ei = edge_set_to_index(edges, data.num_nodes).to(device)
        return accuracy(model(data.x, ei), data.y, data.test_mask)

    # Work on a copy so the caller's graph is left untouched.
    cur_edges = set(base_edges)
    accuracies = [_test_accuracy(cur_edges)]  # t = 0, before any rewiring
    for _ in range(steps):
        cur_edges = random_rewire(cur_edges, data.num_nodes, n_changes, rng)
        accuracies.append(_test_accuracy(cur_edges))
    return accuracies
# -----------------------------------------------------------------------------
# Main entrypoint
# -----------------------------------------------------------------------------
def main():
    """
    Command-line entry point.

    Parses the arguments, seeds ``torch`` and ``random``, loads the Planetoid
    dataset and the LRMC core, then trains and reports the baseline GCN. With
    ``--variant baseline`` only the baseline's dynamic AUC-AT is reported;
    otherwise the anchor model is trained as well and both models are
    evaluated with the same rewiring seed.
    """
    cli = argparse.ArgumentParser(description="L‑RMC anchored GCN vs. baseline with dynamic evaluation.")
    cli.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"],
                     help="Planetoid dataset to load.")
    cli.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON (for core extraction).")
    cli.add_argument("--variant", choices=["baseline", "anchor"], default="anchor", help="Which variant to run.")
    cli.add_argument("--hidden", type=int, default=64, help="Hidden dimension.")
    cli.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    cli.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    cli.add_argument("--wd", type=float, default=5e-4, help="Weight decay (L2).")
    cli.add_argument("--dropout", type=float, default=0.5, help="Dropout probability.")
    cli.add_argument("--gamma", type=float, default=1.0, help="Damping factor γ for cross‑boundary edges (anchor only).")
    cli.add_argument("--dynamic_steps", type=int, default=30, help="Number of dynamic rewiring steps for AUC evaluation.")
    cli.add_argument("--flip_fraction", type=float, default=0.05, help="Fraction of edges rewired at each step.")
    cli.add_argument("--seed", type=int, default=42, help="Random seed for PyTorch.")
    args = cli.parse_args()

    # Seed both RNG sources used by the script.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Dataset and core assignment.
    dataset = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = dataset[0]
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes
    core_mask, core_nodes = load_top1_assignment(args.seeds, data.num_nodes)
    print(f"Loaded core of size {core_nodes.numel()} from {args.seeds}.")

    def fit(model):
        # Train with the CLI hyper-parameters, then score on all three masks.
        train_model(model, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
        return evaluate_model(model, data)

    def rewiring_curve(model):
        # Same steps / fraction / seed for every model, so perturbations match.
        return evaluate_dynamic_auc(model, data, core_mask, steps=args.dynamic_steps,
                                    flip_fraction=args.flip_fraction, rng_seed=args.seed)

    # The baseline is trained first in either variant.
    baseline = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
    res_base = fit(baseline)
    print(f"Baseline GCN: train={res_base['train']:.4f} val={res_base['val']:.4f} test={res_base['test']:.4f}")

    if args.variant == "baseline":
        auc = auc_over_time(rewiring_curve(baseline))
        print(f"Baseline dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}): {auc:.4f}")
        return

    # Anchor model.
    anchor = AnchorGCN(in_dim, args.hidden, out_dim,
                       core_mask=core_mask,
                       gamma=args.gamma,
                       dropout=args.dropout)
    res_anchor = fit(anchor)
    print(f"Anchor‑GCN: train={res_anchor['train']:.4f} val={res_anchor['val']:.4f} test={res_anchor['test']:.4f}")

    # Dynamic evaluation: collect both curves first, then integrate them.
    accs_base = rewiring_curve(baseline)
    accs_anchor = rewiring_curve(anchor)
    auc_base = auc_over_time(accs_base)
    auc_anchor = auc_over_time(accs_anchor)
    print(f"Dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}):")
    print(f" Baseline : {auc_base:.4f}\n Anchor : {auc_anchor:.4f}")


if __name__ == "__main__":
    main()