# NOTE: extracted from a web file viewer (reported size: 13,287 bytes, revision bf620c6).
# The viewer's line-number gutter (1-332) was dropped; it is not part of the program.
# gcn_lrmc_node_classify.py
# Node classification with GCN + L-RMC (static pooling + unpool + skip)
# Usage:
# python gcn_lrmc_node_classify.py --dataset Cora --lrmc_json /path/to/lrmc_seeds.json
# Options:
# --use_a2 true|false (default true; use A^2 before pooling as in Graph U-Nets)
# --epochs 200 --lr 0.005 --hidden 64 --cluster_hidden 64 --dropout 0.5
import argparse, json, os
import numpy as np, torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, BatchNorm
from torch_geometric.utils import coalesce, to_undirected, remove_self_loops
from torch_geometric.utils import add_self_loops
from torch_scatter import scatter_mean
from torch_sparse import spspmm
# -----------------------------
# L-RMC assignment utilities
# -----------------------------
def load_lrmc_assignment(json_path, num_nodes):
    """
    Build a single hard assignment: node -> cluster_id in [0, K-1].

    Clusters are read from the "clusters" list of the JSON file and visited in
    descending "score" order (missing score counts as 0.0), so a node that
    appears in several clusters stays with the highest-scoring one. Member ids
    outside [0, num_nodes) are ignored, and a node listed more than once inside
    the same cluster is recorded only once. Nodes left unassigned become
    singleton clusters, appended in increasing node order.

    Args:
        json_path: path to the L-RMC seeds JSON file.
        num_nodes: number of nodes in the graph.

    Returns:
        assignment: LongTensor [num_nodes] with cluster ids
        clusters: list of lists (members per cluster), index == cluster id
    """
    with open(json_path, "r", encoding="utf-8") as f:
        seeds = json.load(f)

    # Highest score first so conflicts resolve in favour of better clusters.
    clusters_raw = sorted(
        seeds.get("clusters", []),
        key=lambda c: float(c.get("score", 0.0)),
        reverse=True,
    )

    chosen_cluster_for_node = [-1] * num_nodes
    clusters = []
    for c in clusters_raw:
        # `or []` keeps the original tolerance for a null/empty member list.
        members = c.get("members") or []
        cid = len(clusters)  # id this cluster gets if it claims any node
        new_members = []
        for u in members:
            # Mark the node immediately so a repeated id within the same
            # cluster is not appended twice.
            if 0 <= u < num_nodes and chosen_cluster_for_node[u] == -1:
                chosen_cluster_for_node[u] = cid
                new_members.append(u)
        if new_members:
            clusters.append(new_members)

    # Any node still unassigned gets its own singleton cluster.
    for u in range(num_nodes):
        if chosen_cluster_for_node[u] == -1:
            chosen_cluster_for_node[u] = len(clusters)
            clusters.append([u])

    # Ids are contiguous by construction, so no remapping is needed.
    assignment = torch.tensor(chosen_cluster_for_node, dtype=torch.long)
    return assignment, clusters
def lrmc_stats(assignment, clusters, edge_index):
    """
    Print summary statistics of a node -> cluster assignment.

    Reports the node count, cluster count, mean/median cluster size and the
    percentage of singleton clusters, followed by the fraction of edges whose
    two endpoints share a cluster.

    Args:
        assignment: 1-D tensor of cluster ids, indexed by node id.
        clusters: list of member lists, one per cluster.
        edge_index: tensor indexed as edge_index[0] (sources) and
            edge_index[1] (targets); its second dimension is the edge count.
    """
    N = assignment.numel()
    K = int(assignment.max()) + 1
    sizes = [len(c) for c in clusters]
    sing = sum(1 for s in sizes if s == 1)
    print(f"[L-RMC] N={N} K={K} mean|C|={np.mean(sizes):.2f} "
          f"median|C|={np.median(sizes):.0f} singleton%={100*sing/K:.1f}%")

    # How many edges are intra-cluster?
    num_edges = edge_index.size(1)
    if num_edges == 0:
        # The ratio is undefined on an edgeless graph; avoid dividing by zero.
        print("[L-RMC] intra-cluster edge ratio = n/a (graph has no edges)")
        return
    same = (assignment[edge_index[0]] == assignment[edge_index[1]]).sum().item()
    print(f"[L-RMC] intra-cluster edge ratio = {same/num_edges:.3f}")
# -----------------------------
# Graph helpers
# -----------------------------
def compute_A2_union(edge_index, num_nodes, device):
    """
    Return the edge set of A OR A^2 (binary), undirected and coalesced.

    Args:
        edge_index: edge list of the node graph.
        num_nodes: number of nodes; used as every dimension of the sparse product.
        device: device on which the unit edge weights are allocated.

    Returns:
        The symmetrised, coalesced input edges when the graph has no edges,
        otherwise those edges merged with the non-self-loop entries of A @ A.
    """
    # Symmetrise the coalesced input first; no edge weights are carried.
    base = coalesce(edge_index, num_nodes=num_nodes)
    base = to_undirected(base, num_nodes=num_nodes)

    num_edges = base.size(1)
    if num_edges == 0:
        return base  # empty graph

    # Unit weights: only the sparsity pattern of the product is kept below.
    ones = torch.ones(num_edges, device=device)
    # spspmm computes (m x k) @ (k x n); here m = k = n = num_nodes.
    two_hop, _ = spspmm(base, ones, base, ones, num_nodes, num_nodes, num_nodes)
    # Drop self-loops from A^2 (optional; GCNConv adds its own self-loops later).
    two_hop, _ = remove_self_loops(two_hop)

    # Union with A; the coalesce call is relied on to drop duplicate pairs.
    merged = coalesce(torch.cat([base, two_hop], dim=1), num_nodes=num_nodes)
    return to_undirected(merged, num_nodes=num_nodes)
def build_cluster_edges(edge_index_aug, assignment, num_clusters):
    """
    Lift node-level edges to the cluster graph: (u, v) -> (c(u), c(v)).

    Args:
        edge_index_aug: node-level edge list; row 0 holds sources, row 1 targets.
        assignment: tensor mapping node id -> cluster id.
        num_clusters: number of clusters, passed as the node count of the
            cluster graph.

    Returns:
        Cluster-level edge list, coalesced and made undirected. Edges whose
        endpoints share a cluster are not filtered out here, so they map to
        (c, c) pairs.
    """
    src_nodes, dst_nodes = edge_index_aug[0], edge_index_aug[1]
    cluster_pairs = torch.stack((assignment[src_nodes], assignment[dst_nodes]), dim=0)
    merged = coalesce(cluster_pairs, num_nodes=num_clusters)
    return to_undirected(merged, num_nodes=num_clusters)
# -----------------------------
# Model
# -----------------------------
class Gate(nn.Module):
    """
    Gated residual fusion of encoder features with broadcast cluster features.

    Computes ``h_enc + g * proj(h_cluster)`` where
    ``g = sigmoid(MLP([h_enc ; h_cluster]))`` has width ``d_enc``.

    Args:
        d_enc: width of the encoder features (also the output width).
        d_c: width of the broadcast cluster features.
    """

    def __init__(self, d_enc, d_c):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_enc + d_c, d_enc, bias=True),
            nn.ReLU(),
            nn.Linear(d_enc, d_enc, bias=True),
            nn.Sigmoid(),
        )
        # The gate has width d_enc, so cluster features of another width must
        # be projected before the element-wise product. When the widths match
        # this is a parameter-free identity, so state_dict keys are unchanged.
        if d_c == d_enc:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(d_c, d_enc, bias=False)

    def forward(self, h_enc, h_cluster_broadcast):
        """Return ``h_enc`` plus the gated (and, if needed, projected) cluster features."""
        g = self.mlp(torch.cat([h_enc, h_cluster_broadcast], dim=-1))
        return h_enc + g * self.proj(h_cluster_broadcast)  # residual gated add
class GCN_LRMC_NodeClassifier(nn.Module):
    """
    Node classifier with a static pool / unpool stage driven by L-RMC clusters.

    Encoder: GCN -> GCN on original graph (second layer added residually)
    Pool:    mean of encoder features per L-RMC cluster
    Coarse:  GCN -> GCN on cluster graph (second layer added residually)
    Unpool:  broadcast cluster features back to nodes
    Decoder: GCN on [encoder skip ; broadcast] -> GCN producing logits

    Args:
        in_dim: input feature width.
        hidden_dim: width of the node-level hidden layers.
        cluster_hidden_dim: width of the cluster-level hidden layers.
        out_dim: number of output logits per node.
        edge_index: edges of the original graph (used by encoder and decoder).
        assignment: [N] tensor mapping node id -> cluster id.
        cluster_edge_index: edges of the cluster graph.
        dropout: dropout probability applied before every GCN layer.
    """

    def __init__(self, in_dim, hidden_dim, cluster_hidden_dim, out_dim,
                 edge_index, assignment, cluster_edge_index, dropout=0.5):
        super().__init__()
        self.edge_index = edge_index                  # original graph edges
        self.assignment = assignment                  # [N]
        self.cluster_edge_index = cluster_edge_index  # edges on cluster graph
        self.num_clusters = int(assignment.max().item() + 1)
        self.dropout = dropout

        # Encoder on node graph
        self.enc1 = GCNConv(in_dim, hidden_dim, improved=True)
        self.enc2 = GCNConv(hidden_dim, hidden_dim, improved=True)

        # GCN(s) on cluster graph
        self.cgc1 = GCNConv(hidden_dim, cluster_hidden_dim, improved=True)
        self.cgc2 = GCNConv(cluster_hidden_dim, cluster_hidden_dim, improved=True)

        # Decoder on node graph (combine skip from encoder + broadcast from cluster)
        dec_in = hidden_dim + cluster_hidden_dim
        self.dec1 = GCNConv(dec_in, hidden_dim, improved=True)
        self.cls = GCNConv(hidden_dim, out_dim, improved=True)  # final logits

        self.bn_e1 = BatchNorm(hidden_dim)
        self.bn_e2 = BatchNorm(hidden_dim)
        self.bn_c1 = BatchNorm(cluster_hidden_dim)
        self.bn_c2 = BatchNorm(cluster_hidden_dim)
        # NOTE: bn_d1 and gate are constructed but not used by forward(); they
        # are kept so the parameter set and state_dict keys stay the same.
        self.bn_d1 = BatchNorm(hidden_dim)
        self.gate = Gate(hidden_dim, cluster_hidden_dim)

    def forward(self, x):
        """Return per-node logits of shape [N, out_dim] for node features ``x``."""
        # Encoder on original graph
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.relu(self.bn_e1(self.enc1(h, self.edge_index)))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h2 = F.relu(self.bn_e2(self.enc2(h, self.edge_index)))
        h_enc = h + h2  # residual; also the skip connection for the decoder

        # Pool: aggregate encoder features to clusters (mean) -> [K, hidden_dim]
        cluster_x = scatter_mean(h_enc, self.assignment, dim=0, dim_size=self.num_clusters)

        # Coarse GCN(s) on cluster graph
        hc = F.dropout(cluster_x, p=self.dropout, training=self.training)
        # Feed the dropped-out features (not the raw cluster_x) into cgc1 so
        # the dropout above actually takes effect during training.
        hc = F.relu(self.bn_c1(self.cgc1(hc, self.cluster_edge_index)))
        hc = F.dropout(hc, p=self.dropout, training=self.training)
        hc2 = F.relu(self.bn_c2(self.cgc2(hc, self.cluster_edge_index)))
        hc = hc + hc2

        # Unpool: broadcast coarse features back to nodes via assignment
        hc_broadcast = hc[self.assignment]  # [N, cluster_hidden_dim]

        # Decoder on original graph. Skip and broadcast features are
        # concatenated; a gated residual via self.gate is an alternative
        # that is currently not wired in.
        h_dec_in = torch.cat([h_enc, hc_broadcast], dim=1)  # [N, hidden_dim + cluster_hidden_dim]
        h = F.dropout(h_dec_in, p=self.dropout, training=self.training)
        h = F.relu(self.dec1(h, self.edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        out = self.cls(h, self.edge_index)  # logits [N, C]
        return out
# -----------------------------
# Train / Eval
# -----------------------------
@torch.no_grad()
def evaluate(model, data):
    """
    Compute accuracy on the train, validation and test splits.

    Puts ``model`` in eval mode, runs it once on ``data.x`` and compares the
    arg-max prediction with ``data.y`` under each split mask.

    Returns:
        (train_acc, val_acc, test_acc) as Python floats; a split whose mask
        selects no node yields 0.0.
    """
    model.eval()
    predictions = model(data.x).argmax(dim=-1)
    labels = data.y

    def split_accuracy(mask):
        # Non-boolean masks are converted to bool before indexing.
        selector = mask.bool() if mask.dtype != torch.bool else mask
        if selector.sum() == 0:
            return 0.0
        hits = predictions[selector] == labels[selector]
        return hits.float().mean().item()

    return tuple(
        split_accuracy(mask)
        for mask in (data.train_mask, data.val_mask, data.test_mask)
    )
def train_loop(model, data, epochs=200, lr=5e-3, weight_decay=5e-4, patience=100):
    """
    Train ``model`` full-batch with Adam and early stopping on validation accuracy.

    Each epoch takes one optimisation step on the cross-entropy of the training
    nodes, then evaluates all three splits. The parameters with the best
    validation accuracy seen so far are snapshotted on CPU and restored at the
    end, when such a snapshot exists.

    Args:
        model: module called as ``model(data.x)`` returning per-node logits.
        data: object exposing ``x``, ``y`` and the split masks.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: stop after this many consecutive epochs without a strictly
            better validation accuracy.

    Returns:
        Test accuracy of the model after the optional reload.
    """
    opt = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val = 0.0
    best_test = 0.0
    snapshot = None
    stale_epochs = 0

    for epoch in range(1, epochs + 1):
        # One full-batch gradient step on the training nodes.
        model.train()
        opt.zero_grad()
        train_logits = model(data.x)[data.train_mask]
        loss = F.cross_entropy(train_logits, data.y[data.train_mask])
        loss.backward()
        opt.step()

        train_acc, val_acc, test_acc = evaluate(model, data)
        improved = val_acc > best_val
        if improved:
            best_val, best_test = val_acc, test_acc
            # Clone to CPU so later updates cannot alter the snapshot.
            snapshot = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        stale_epochs = 0 if improved else stale_epochs + 1

        print(f"Epoch {epoch:03d} | loss={loss.item():.4f} | "
              f"train={train_acc*100:.2f}% val={val_acc*100:.2f}% test={test_acc*100:.2f}% test@best={best_test*100:.2f}%")
        if stale_epochs >= patience:
            print(f"Early stopping at epoch {epoch} (no val improvement for {patience})")
            break

    if snapshot is not None:
        model.load_state_dict(snapshot)
    train_acc, val_acc, test_acc = evaluate(model, data)
    print(f"\nFinal (reloaded best): train={train_acc*100:.2f}% val={val_acc*100:.2f}% test={test_acc*100:.2f}%")
    return test_acc
# -----------------------------
# Main
# -----------------------------
def main():
    """
    Command-line entry point: load a Planetoid dataset and an L-RMC clustering,
    build the cluster graph, then train and evaluate the GCN + L-RMC classifier.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "Citeseer", "Pubmed"])
    parser.add_argument("--lrmc_json", type=str, required=True)
    parser.add_argument("--use_a2", type=str, default="true", help="Use A^2 before pooling (true/false)")
    parser.add_argument("--hidden", type=int, default=64)
    parser.add_argument("--cluster_hidden", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--seed", type=int, default=42)
    # Previously fixed at 100 in the train_loop call; the default keeps that value.
    parser.add_argument("--patience", type=int, default=100,
                        help="Early-stopping patience in epochs")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = Planetoid(root=os.path.join("data", args.dataset), name=args.dataset)
    data = dataset[0].to(device)
    num_nodes = data.num_nodes
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes

    # Load L-RMC assignment
    assignment, clusters = load_lrmc_assignment(args.lrmc_json, num_nodes)
    assignment = assignment.to(device)
    num_clusters = int(assignment.max().item() + 1)
    print(f"[L-RMC] Loaded clusters: K={num_clusters} (N={num_nodes})")
    lrmc_stats(assignment, clusters, data.edge_index)

    # Build augmented node edge_index (A or A^2 ∪ A), then cluster edges
    use_a2 = args.use_a2.lower() in ("1", "true", "yes", "y")
    if use_a2:
        edge_index_aug = compute_A2_union(data.edge_index, num_nodes, device)
        print("[L-RMC] Using A^2 ∪ A before pooling (connectivity augmentation).")
    else:
        edge_index_aug = to_undirected(coalesce(data.edge_index, num_nodes=num_nodes), num_nodes=num_nodes)
        print("[L-RMC] Using original A for pooling.")
    cluster_edge_index = build_cluster_edges(edge_index_aug, assignment, num_clusters)

    # Build model
    model = GCN_LRMC_NodeClassifier(
        in_dim=in_dim,
        hidden_dim=args.hidden,
        cluster_hidden_dim=args.cluster_hidden,
        out_dim=out_dim,
        edge_index=data.edge_index,             # original graph for enc/dec
        assignment=assignment,                  # node -> cluster
        cluster_edge_index=cluster_edge_index,  # cluster graph for coarse GCN
        dropout=args.dropout,
    ).to(device)

    # Train / evaluate (train_loop prints the per-epoch and final metrics)
    train_loop(model, data, epochs=args.epochs, lr=args.lr,
               weight_decay=args.weight_decay, patience=args.patience)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
# End of file.