|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse, json |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from torch import Tensor |
|
|
from torch_scatter import scatter_add, scatter_mean |
|
|
from torch_sparse import coalesce, spspmm |
|
|
|
|
|
from torch_geometric.datasets import Planetoid |
|
|
from torch_geometric.nn import GCNConv |
|
|
from torch_geometric.utils import subgraph, degree |
|
|
|
|
|
from rich import print |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """
    Return the weighted graph A + scale * I as ``(edge_index, edge_weight)``.

    Args:
        edge_index: [2, E] edge list of A.
        edge_weight: [E] weights of A, or None for unit weights.
        num_nodes: number of nodes N; one self-loop is added per node.
        scale: weight given to every added self-loop.

    Returns:
        When ``scale == 0`` the graph is returned untouched (not coalesced).
        ``edge_weight`` is passed through as-is, or a ones vector is created
        when it was None. Otherwise, self-loops of weight ``scale`` are
        appended and the result is coalesced with ``op='add'``, so existing
        self-loops and duplicate edges have their weights summed.
    """
    dev = edge_index.device
    weights = torch.ones(edge_index.size(1), device=dev) if edge_weight is None else edge_weight

    if scale == 0.0:
        # Nothing to add: hand the graph back (with unit weights if none were given).
        return edge_index, weights

    # Diagonal (i, i) for every node; both rows are 0..N-1.
    diag = torch.arange(num_nodes, device=dev).repeat(2, 1)
    diag_w = torch.full((num_nodes,), float(scale), device=dev)

    merged_index = torch.cat((edge_index, diag), dim=1)
    merged_w = torch.cat((weights, diag_w), dim=0)
    merged_index, merged_w = coalesce(merged_index, merged_w, num_nodes, num_nodes, op='add')
    return merged_index, merged_w
|
|
|
|
|
|
|
|
def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """
    Compute the (binary) support of the k-th adjacency power A^k.

    Entry (i, j) is present iff there is a walk of exactly ``k`` steps from i
    to j in the graph given by ``edge_index``. Diagonal entries (i == j) are
    removed. The result is a coalesced [2, E_k] edge_index (row-major sorted,
    no duplicates) without weights.

    Args:
        edge_index: [2, E] long tensor of (row, col) pairs describing A.
        num_nodes: number of nodes N (A is N x N).
        k: power to compute; must be >= 1. ``k=1`` returns A without self-loops.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    size = (num_nodes, num_nodes)
    device = edge_index.device

    def _binary(index: Tensor) -> Tensor:
        # Only the sparsity pattern matters. Resetting values to 1 after each
        # product keeps walk counts from growing with k.
        values = torch.ones(index.size(1), device=device)
        return torch.sparse_coo_tensor(index, values, size).coalesce()

    adj = _binary(edge_index)
    power = adj
    # Keep the diagonal in intermediate powers: walks through i -> ... -> i are
    # needed to get the correct support of higher powers. Drop it only at the end.
    for _ in range(k - 1):
        power = _binary(torch.sparse.mm(power, adj).coalesce().indices())

    idx = power.indices()
    return idx[:, idx[0] != idx[1]]
|
|
|
|
|
|
|
|
def build_cluster_graph(edge_index: Tensor,
                        num_nodes: int,
                        node2cluster: Tensor,
                        weight_per_edge: Optional[Tensor] = None,
                        num_clusters: Optional[int] = None
                        ) -> Tuple[Tensor, Tensor, int]:
    """
    Build the cluster graph A_c = S^T A S, using summed edge multiplicities as weights.

    Every node edge (u, v) is mapped to the cluster edge
    (node2cluster[u], node2cluster[v]). Coinciding cluster edges are merged
    with their weights added. Edges inside one cluster become that cluster's
    self-loop.

    Args:
        edge_index: [2, E] node-level edge list.
        num_nodes: not used by this implementation; kept for signature compatibility.
        node2cluster: [N] long tensor mapping each node -> cluster id.
        weight_per_edge: optional [E] weights; unit weights when None.
        num_clusters: number of clusters K; inferred as max(node2cluster) + 1 when None.

    Returns:
        (cluster edge_index [2, E_c], cluster edge weights [E_c], K)
    """
    if num_clusters is None:
        n_clusters = int(node2cluster.max().item()) + 1
    else:
        n_clusters = num_clusters

    # Indexing with the whole [2, E] edge list maps both endpoints in one step.
    cluster_edges = node2cluster[edge_index]

    if weight_per_edge is None:
        cluster_w = torch.ones(cluster_edges.size(1), device=cluster_edges.device)
    else:
        cluster_w = weight_per_edge

    cluster_edges, cluster_w = coalesce(cluster_edges, cluster_w, n_clusters, n_clusters, op='add')
    return cluster_edges, cluster_w, n_clusters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pick_top1_cluster(obj: dict) -> List[int]: |
|
|
""" |
|
|
From LRMC JSON with structure: {"clusters":[{"seed_nodes":[...], "score":float, ...}, ...]} |
|
|
choose the cluster with max (score, size) and return its members as 0-indexed. |
|
|
""" |
|
|
clusters = obj.get("clusters", []) |
|
|
if not clusters: |
|
|
return [] |
|
|
|
|
|
best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", [])))) |
|
|
return list(map(lambda x: x - 1, best.get("seed_nodes", []))) |
|
|
|
|
|
|
|
|
def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Create a hard assignment for the top-1 LRMC cluster:
      - cluster 0 = top-1 LRMC set
      - nodes outside are singletons (1..K-1), numbered in ascending node order

    Args:
        seeds_json: path to the LRMC seeds JSON file (read as UTF-8). The top
            cluster is chosen by ``_pick_top1_cluster``.
        n_nodes: number of nodes N in the graph.

    Returns:
        node2cluster: [N] long
        cluster_scores: [K,1] with 1.0 for the top cluster, 0.0 for singletons
            (all zeros when the top cluster is empty)
        core_nodes: [|C|] long, sorted unique original indices of the nodes in
            the top-1 LRMC cluster

    Raises:
        ValueError: if any (0-indexed) member id of the top cluster lies
            outside ``[0, n_nodes)``.
    """
    obj = json.loads(Path(seeds_json).read_text(encoding="utf-8"))
    C_star = torch.tensor(sorted(set(_pick_top1_cluster(obj))), dtype=torch.long)

    # Validate before indexing. Otherwise a negative id would silently wrap
    # around to the last nodes, and an id >= n_nodes would raise an opaque
    # IndexError.
    if C_star.numel() > 0:
        lo, hi = int(C_star.min()), int(C_star.max())
        if lo < 0 or hi >= n_nodes:
            raise ValueError(
                f"Top-1 LRMC cluster has node ids in [{lo}, {hi}] (0-indexed), "
                f"outside the valid range [0, {n_nodes}). "
                f"Check that seed ids in the JSON are 1-indexed."
            )

    in_core = torch.zeros(n_nodes, dtype=torch.bool)
    in_core[C_star] = True
    # Non-core nodes in ascending order; each becomes its own singleton cluster.
    outside = (~in_core).nonzero(as_tuple=True)[0]

    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    node2cluster[C_star] = 0
    node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long)
    # Every node is either in C_star or in `outside`, so no entry stays -1.

    K = 1 + outside.numel()
    cluster_scores = torch.zeros(K, 1, dtype=torch.float32)
    if C_star.numel() > 0:
        cluster_scores[0, 0] = 1.0
    return node2cluster, cluster_scores, C_star
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GCN2(nn.Module):
    """Plain 2-layer GCN baseline: GCNConv -> ReLU -> dropout -> GCNConv, returning raw logits."""

    def __init__(self, in_dim, hid, out_dim, dropout_p: float = 0.5):
        super().__init__()
        self.dropout_p = dropout_p
        # The layers are created in this order so parameter initialisation
        # consumes the global RNG exactly as before.
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid, out_dim)

    def forward(self, x, edge_index):
        """Return [N, out_dim] logits for node features ``x`` on graph ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        hidden = F.dropout(hidden, p=self.dropout_p, training=self.training)
        return self.conv2(hidden, edge_index)
|
|
|
|
|
|
|
|
class OneClusterPool(nn.Module):
    """
    Node-GCN -> pool to one-cluster + singletons -> Cluster-GCN -> broadcast + skip -> Node-GCN -> classifier

    This version includes L-RMC stability tricks:
    1. Backbone-invariance regularizer (loss computed in forward).
    2. Boundary damping on node graph.

    Args:
        in_dim, hid, out_dim: input, hidden and output feature sizes.
        node2cluster: [N] tensor mapping each node to a cluster id in 0..K-1.
        core_nodes: 1-D tensor of node ids forming the top-1 cluster (may be empty).
        edge_index_node: [2, E] node-graph edge list. The processed graph is
            cached at construction time.
        num_nodes: number of nodes N.
        self_loop_scale: lambda in A + lambda*I for the node graph (0 adds no loops).
        use_a2_for_clusters: build the cluster graph from the k=2 adjacency
            power of ``edge_index_node`` instead of the raw edges.
        lrmc_gamma: multiplier for edges with exactly one endpoint in
            ``core_nodes``. Applied only when < 1.0 and the core is non-empty.
        dropout_p: dropout applied after the first node-level GCN layer.

    ``node2cluster`` and ``core_nodes`` are stored as non-persistent buffers on
    the graph's device. ``model.to(device)`` therefore moves them along with the
    graph buffers and parameters, while ``state_dict()`` keys stay unchanged.
    """

    def __init__(self,
                 in_dim: int,
                 hid: int,
                 out_dim: int,
                 node2cluster: Tensor,
                 core_nodes: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2_for_clusters: bool = False,
                 lrmc_gamma: float = 1.0,
                 dropout_p: float = 0.5):
        super().__init__()
        device = edge_index_node.device
        # Index tensors live with the graph. As buffers they follow later .to()
        # calls; plain attributes would be left behind on the original device.
        self.register_buffer("n2c", node2cluster.long().to(device), persistent=False)
        self.register_buffer("core_nodes", core_nodes.long().to(device), persistent=False)
        self.K = int(self.n2c.max().item()) + 1
        self.lrmc_gamma = lrmc_gamma
        self.dropout_p = dropout_p

        # --- Node graph: optional A + lambda*I, unit weights otherwise ---
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes, scale=self_loop_scale)

        # --- Boundary damping: down-weight edges crossing the core boundary ---
        if self.lrmc_gamma < 1.0 and self.core_nodes.numel() > 0:
            is_core = torch.zeros(num_nodes, dtype=torch.bool, device=ei_node.device)
            is_core[self.core_nodes] = True
            # Exactly one endpoint in the core; self-loops are never cross-boundary.
            cross_boundary_mask = is_core[ei_node[0]] != is_core[ei_node[1]]
            ew_node[cross_boundary_mask] *= self.lrmc_gamma

        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)

        # --- Cluster graph S^T A S, built from the raw (undamped, loop-free) edges ---
        ei_for_c = adjacency_power(edge_index_node, num_nodes, k=2) if use_a2_for_clusters else edge_index_node
        edge_index_c, edge_weight_c, K = build_cluster_graph(ei_for_c, num_nodes, self.n2c,
                                                             num_clusters=self.K)
        self.register_buffer("edge_index_c", edge_index_c)
        self.register_buffer("edge_weight_c", edge_weight_c)
        self.K = K

        # Layer creation order is kept fixed so parameter init consumes the RNG identically.
        # The first layer adds no self-loops itself; node self-loops come only
        # from self_loop_scale.
        self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True)
        # Second node layer uses GCNConv defaults.
        self.gcn_node2 = GCNConv(hid * 2, out_dim)

    def forward(self, x: Tensor, edge_index_node: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Return ``(logits, lrmc_inv_loss)``.

        ``logits`` has one ``out_dim`` row per node. ``lrmc_inv_loss`` is the
        mean squared deviation of the core nodes' first-layer embeddings from
        their mean, or None when ``core_nodes`` is empty.

        ``edge_index_node`` is accepted for interface compatibility but ignored.
        The graph cached in ``__init__`` (with self-loops and damping) is used.
        """
        h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node))
        h1 = F.dropout(h1, p=self.dropout_p, training=self.training)

        # Backbone-invariance regularizer: pull core embeddings toward their mean.
        lrmc_inv_loss = None
        if self.core_nodes.numel() > 0:
            core_embeddings = h1[self.core_nodes]
            avg_embedding = core_embeddings.mean(dim=0, keepdim=True)
            lrmc_inv_loss = F.mse_loss(core_embeddings, avg_embedding.expand_as(core_embeddings), reduction='mean')

        # Pool nodes into clusters (mean), run the cluster GCN, broadcast back.
        z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K)
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))
        hb = z2[self.n2c]

        # Skip connection: concatenate node-level and cluster-level features.
        hcat = torch.cat([h1, hb], dim=1)
        out = self.gcn_node2(hcat, self.edge_index_node, self.edge_weight_node)
        return out, lrmc_inv_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Return the fraction of nodes selected by ``mask`` whose arg-max class in ``logits`` equals ``y``."""
    selected_logits = logits[mask]
    selected_labels = y[mask]
    correct = torch.argmax(selected_logits, dim=1).eq(selected_labels)
    return correct.float().mean().item()
|
|
|
|
|
|
|
|
def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, lrmc_inv_weight: float = 0.0):
    """
    Train ``model`` with Adam, keep the best-validation weights, and report their accuracy.

    The model is called as ``model(data.x, data.edge_index)``. It may return
    either plain logits or a ``(logits, aux_loss)`` tuple (as OneClusterPool
    does). When ``aux_loss`` is not None and ``lrmc_inv_weight > 0``,
    ``lrmc_inv_weight * aux_loss`` is added to the cross-entropy on
    ``data.train_mask``. After each epoch the model is evaluated without
    autograd. The state_dict with the strictly highest validation accuracy
    seen so far is snapshotted and restored at the end.

    Args:
        model: module to train.
        data: object exposing ``x``, ``y``, ``edge_index``, ``train_mask``,
            ``val_mask`` and ``test_mask``.
        epochs, lr, wd: number of epochs, Adam learning rate and weight decay.
        lrmc_inv_weight: weight of the auxiliary invariance loss (0 disables).

    Returns:
        dict with ``"val"`` and ``"test"`` accuracy of the restored model.
    """

    def _split_output(res):
        # Models return plain logits or (logits, aux_loss); normalise to a pair.
        if isinstance(res, tuple):
            logits, aux = res
            return logits, aux
        return res, None

    def _evaluate() -> Tensor:
        # Inference only: no autograd graph is needed for accuracy.
        model.eval()
        with torch.no_grad():
            logits, _ = _split_output(model(data.x, data.edge_index))
        return logits

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_val, best_state = 0.0, None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)

        logits, inv_loss = _split_output(model(data.x, data.edge_index))
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        if inv_loss is not None and lrmc_inv_weight > 0:
            loss = loss + lrmc_inv_weight * inv_loss

        loss.backward()
        opt.step()

        logits_eval = _evaluate()
        val_acc = accuracy(logits_eval, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if ep % 20 == 0:
            tr = accuracy(logits_eval, data.y, data.train_mask)
            te = accuracy(logits_eval, data.y, data.test_mask)
            lrmc_loss_str = f" inv_l={inv_loss.item():.4f}" if inv_loss is not None else ""
            print(f"[{ep:04d}] loss={loss.item():.4f}{lrmc_loss_str} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    logits_final = _evaluate()
    return {
        "val": accuracy(logits_final, data.y, data.val_mask),
        "test": accuracy(logits_final, data.y, data.test_mask)
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_lrmc_stability_score(core_nodes: Tensor, edge_index: Tensor,
                                 num_nodes: int, epsilon: float = 1e-6,
                                 unique_undirected: bool = True) -> float:
    """
    Score how uniform the internal degrees of the subgraph induced by ``core_nodes`` are.

    score = |C| / (sum over edges (i, j) of the induced subgraph of (d_i - d_j)^2 + epsilon),

    where d_i counts every occurrence of node i as either endpoint of a kept
    edge. Larger values mean more homogeneous internal degrees.

    Args:
        core_nodes: 1-D long tensor of node ids in the cluster. Presumably
            unique (load_top1_assignment produces sorted unique ids); confirm
            for other callers.
        edge_index: [2, E] edge list of the full graph.
        num_nodes: number of nodes in the full graph.
        epsilon: added to the denominator to avoid division by zero.
        unique_undirected: if True, every edge is canonicalised to
            (min, max), self-loops are dropped and duplicates are removed. Each
            undirected edge then counts once, whether it is stored in one or
            both directions. If False, edges are used exactly as stored.

    Returns:
        0.0 for an empty cluster. ``|C| / epsilon`` when the induced subgraph
        has no (remaining) edges. Otherwise the score above.
    """
    Cn = core_nodes.numel()
    if Cn == 0:
        return 0.0

    # Edges with both endpoints in the core, relabelled to 0..Cn-1.
    sub_edge_index, _ = subgraph(core_nodes, edge_index,
                                 relabel_nodes=True,
                                 num_nodes=num_nodes)
    if sub_edge_index.numel() == 0:
        return float(Cn) / epsilon

    if unique_undirected:
        # Canonicalise orientation. Keeping only i < j as stored would silently
        # drop edges that appear only as (j, i).
        lo = torch.minimum(sub_edge_index[0], sub_edge_index[1])
        hi = torch.maximum(sub_edge_index[0], sub_edge_index[1])
        keep = lo < hi  # also drops self-loops
        if not bool(keep.any()):
            return float(Cn) / epsilon
        sub_edge_index = torch.unique(torch.stack([lo[keep], hi[keep]], dim=0), dim=1)

    degC = degree(sub_edge_index[0], num_nodes=Cn, dtype=torch.float) + \
        degree(sub_edge_index[1], num_nodes=Cn, dtype=torch.float)

    di = degC[sub_edge_index[0]]
    dj = degC[sub_edge_index[1]]
    quad = ((di - dj) ** 2).sum().item()

    return float(Cn) / (quad + epsilon)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    CLI entry point: train and evaluate a plain 2-layer GCN (``--variant baseline``)
    or the top-1 L-RMC one-cluster pooling model (``--variant pool``) on a
    Planetoid dataset, printing validation/test accuracy.

    ``--seeds`` is only read by the pool variant, so it is only required (and
    checked after parsing) when ``--variant pool`` is selected.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    ap.add_argument("--seeds", default=None,
                    help="Path to LRMC seeds JSON (single large graph). Required for --variant pool.")
    ap.add_argument("--variant", choices=["baseline", "pool"], default="pool",
                    help="baseline=plain GCN; pool=top-1 LRMC one-cluster pooling (with new L-RMC tricks)")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=0.01)
    ap.add_argument("--wd", type=float, default=5e-4)
    ap.add_argument("--dropout", type=float, default=0.5, help="Dropout rate for GCN layers.")
    ap.add_argument("--self_loop_scale", type=float, default=0.0, help="λ for A+λI on node graph (0 disables)")
    ap.add_argument("--use_a2", action="store_true", help="Use A^2 to build the cluster graph (recommended for pool)")
    ap.add_argument("--lrmc_inv_weight", type=float, default=0.0,
                    help="Weight for the backbone-invariance regularizer (0 disables).")
    ap.add_argument("--lrmc_gamma", type=float, default=1.0,
                    help="Damping factor for cross-boundary edges (1.0 means no damping).")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # Same exit behaviour (usage + status 2) as a missing required argument,
    # but only for the variant that actually needs the file.
    if args.variant == "pool" and args.seeds is None:
        ap.error("--seeds is required when --variant pool is used")

    torch.manual_seed(args.seed)

    # --- Data ---
    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    in_dim, out_dim, n = ds.num_node_features, ds.num_classes, data.num_nodes

    if args.variant == "baseline":
        model = GCN2(in_dim, args.hidden, out_dim, dropout_p=args.dropout)
        res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
        print(f"Baseline GCN: val={res['val']:.4f} test={res['test']:.4f}")
        return

    # --- Top-1 LRMC assignment and its stability score ---
    node2cluster, _, core_nodes = load_top1_assignment(args.seeds, n)
    lrmc_score = compute_lrmc_stability_score(core_nodes, data.edge_index, n)
    print(f"LRMC Top-1 Cluster Size: {core_nodes.numel()} nodes. Stability Score (S_L): {lrmc_score:.4f}")

    # --- Pooling model ---
    model = OneClusterPool(in_dim=in_dim,
                           hid=args.hidden,
                           out_dim=out_dim,
                           node2cluster=node2cluster,
                           core_nodes=core_nodes,
                           edge_index_node=data.edge_index,
                           num_nodes=n,
                           self_loop_scale=args.self_loop_scale,
                           use_a2_for_clusters=args.use_a2,
                           lrmc_gamma=args.lrmc_gamma,
                           dropout_p=args.dropout)

    res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd,
                         lrmc_inv_weight=args.lrmc_inv_weight)
    print(f"L-RMC (top-1 pool with tricks): val={res['val']:.4f} test={res['test']:.4f}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
|
|
|