clique / src /old /train_gcn_lrmc_pool.py
qingy2024's picture
Upload folder using huggingface_hub
f74dd01 verified
# sweep_lowlab_compress.py
# Low-label and compression sweeps for L-RMC, DiffPool, and gPool on Cora.
import json
import math
import random
from pathlib import Path
from statistics import mean, pstdev
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, DenseGCNConv
from torch_geometric.nn.dense import dense_diff_pool
from rich import print
# ============ Paths ============
SEEDS_JSON = "../seeds_diam_1e-6.json" # your L-RMC export
CORA_CONTENT = "../cora/cora.content"
CORA_CITES = "../cora/cora.cites"
# ============ Sweep settings ============
LABEL_BUDGETS = [20, 10, 5, 3] # train_per_class
K_RATIOS = [0.10, 0.20, 0.40, 0.80] # K / N target
SEEDS = [0, 1, 2, 3, 4] # random seeds per cell
# ============ Train hyperparams ============
HIDDEN = 64
DROPOUT = 0.5
LR = 0.01
WEIGHT_DECAY = 5e-4
EPOCHS = 300
PATIENCE = 50
# DiffPool extras
DIFFPOOL_AUX_WEIGHT = 1e-2 # link + entropy regularizers
# ============ Utils ============
def set_seed(seed: int):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def to_undirected(edge_index, num_nodes):
# Unique undirected edges without self loops
edges = edge_index.t().tolist()
uniq = set()
out = []
for u, v in edges:
if u == v:
continue
a, b = (u, v) if u < v else (v, u)
key = (a, b)
if key not in uniq:
uniq.add(key)
out.append([a, b])
if not out:
return torch.empty((2, 0), dtype=torch.long)
return torch.tensor(out, dtype=torch.long).t().contiguous()
def macro_f1_from_logits(logits, y, mask):
with torch.no_grad():
pred = logits.argmax(dim=1)
y_ = y[mask]
p_ = pred[mask]
C = int(y.max().item() + 1)
cm = torch.zeros((C, C), dtype=torch.long, device=logits.device)
for t, q in zip(y_, p_):
cm[t, q] += 1
eps = 1e-12
tp = cm.diag().to(torch.float)
fp = cm.sum(dim=0).to(torch.float) - tp
fn = cm.sum(dim=1).to(torch.float) - tp
precision = tp / (tp + fp + eps)
recall = tp / (tp + fn + eps)
f1 = 2 * precision * recall / (precision + recall + eps)
present = cm.sum(dim=1) > 0
return f1[present].mean().item() if present.any() else 0.0
def accuracy_from_logits(logits, y, mask):
with torch.no_grad():
pred = logits.argmax(dim=1)
correct = (pred[mask] == y[mask]).sum().item()
total = int(mask.sum().item())
return correct / max(total, 1)
# ============ Data ============
def load_cora_from_content_and_cites(content_path: str, cites_path: str):
lines = Path(content_path).read_text().strip().splitlines()
n = len(lines)
paper_ids, features, labels_raw = [], [], []
for line in lines:
toks = line.strip().split()
paper_ids.append(toks[0])
labels_raw.append(toks[-1])
features.append([int(x) for x in toks[1:-1]])
classes = sorted(set(labels_raw))
cls2idx = {c: i for i, c in enumerate(classes)}
y = torch.tensor([cls2idx[c] for c in labels_raw], dtype=torch.long)
x = torch.tensor(features, dtype=torch.float)
id2idx = {pid: i for i, pid in enumerate(paper_ids)}
edges = []
for line in Path(cites_path).read_text().strip().splitlines():
a, b = line.strip().split()
if a in id2idx and b in id2idx:
edges.append((id2idx[a], id2idx[b]))
if not edges:
raise RuntimeError("No edges from cites file.")
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_index = to_undirected(edge_index, n)
data = Data(x=x, edge_index=edge_index, y=y)
data.num_nodes = n
data.num_classes = len(classes)
return data
def make_planetoid_style_split(y, num_classes, train_per_class=20, val_size=500, test_size=1000):
N = y.size(0)
all_idx = torch.arange(N)
train_mask = torch.zeros(N, dtype=torch.bool)
val_mask = torch.zeros(N, dtype=torch.bool)
test_mask = torch.zeros(N, dtype=torch.bool)
for c in range(num_classes):
idx_c = all_idx[(y == c)]
if idx_c.numel() == 0:
continue
sel = idx_c[torch.randperm(idx_c.numel())[: min(train_per_class, idx_c.numel())]]
train_mask[sel] = True
remaining = all_idx[~train_mask]
remaining = remaining[torch.randperm(remaining.numel())]
val_k = min(val_size, remaining.numel())
val_mask[remaining[:val_k]] = True
rem2 = remaining[val_k:]
test_k = min(test_size, rem2.numel())
test_mask[rem2[:test_k]] = True
return train_mask, val_mask, test_mask
# ============ L-RMC seeds and pooling ============
def load_lrmc_partition(path: str, num_nodes: int):
obj = json.loads(Path(path).read_text())
clusters = obj["clusters"]
cid_of_node = {}
for c in clusters:
cid = int(c["cluster_id"])
for u in c["members"]:
cid_of_node[int(u)] = cid
cluster_id = torch.full((num_nodes,), -1, dtype=torch.long)
for u, cid in cid_of_node.items():
if 0 <= u < num_nodes:
cluster_id[u] = cid
if (cluster_id < 0).any():
miss = int((cluster_id < 0).sum().item())
raise RuntimeError(f"{miss} nodes not covered by seeds.")
K = int(cluster_id.max().item() + 1)
return cluster_id, K
def pool_by_partition_weighted(x, edge_index, cluster_id, K):
if x.dim() != 2:
raise ValueError(f"Expected x to have shape [N, F], got {x.shape}")
if cluster_id.shape != (x.shape[0],):
raise ValueError(f"Expected cluster_id to have shape [{x.shape[0]}], got {cluster_id.shape}")
sums = torch.zeros((K, x.size(1)), device=x.device, dtype=x.dtype)
sums.index_add_(0, cluster_id, x)
counts = torch.bincount(cluster_id, minlength=K).clamp_min(1).to(x.device).unsqueeze(1).to(x.dtype)
x_pooled = sums / counts
cu = cluster_id[edge_index[0]]
cv = cluster_id[edge_index[1]]
pairs = torch.stack([cu, cv], dim=1)
uniq, w = torch.unique(pairs, dim=0, return_counts=True)
mask = uniq[:, 0] != uniq[:, 1]
edge_index_pooled = uniq[mask].t().contiguous()
edge_weight = w[mask].to(torch.float)
return x_pooled, edge_index_pooled, edge_weight
def compress_partition_to_K(cluster_id, K_target, edge_index):
cid = cluster_id.clone()
K_now = int(cid.max().item() + 1)
if K_now <= K_target:
return cid, K_now
sizes = torch.bincount(cid, minlength=K_now)
kept = set(int(k) for k in torch.topk(sizes, K_target).indices.tolist())
# inter-cluster weights
cu = cid[edge_index[0]].tolist()
cv = cid[edge_index[1]].tolist()
w = {}
for a, b in zip(cu, cv):
if a == b:
continue
w[(a, b)] = w.get((a, b), 0) + 1
w[(b, a)] = w.get((b, a), 0) + 1
mapping = {}
largest_kept = max(kept, key=lambda k: sizes[k].item())
for c in range(K_now):
if c in kept:
mapping[c] = c
else:
candidates = [(w.get((c, k), 0), k) for k in kept]
mapping[c] = max(candidates)[1] if candidates else largest_kept
for i in range(cid.numel()):
cid[i] = mapping[int(cid[i].item())]
kept_sorted = sorted(set(int(x) for x in cid.tolist()))
remap = {old: new for new, old in enumerate(kept_sorted)}
for i in range(cid.numel()):
cid[i] = remap[int(cid[i].item())]
return cid, len(kept_sorted)
# ============ Models ============
class LrmcSeededPoolGCN(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, cluster_id, K, dropout=0.5):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden_dim, add_self_loops=True, normalize=True)
self.conv2 = GCNConv(hidden_dim, out_dim, add_self_loops=True, normalize=True)
self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)
self.score = nn.Linear(hidden_dim, 1, bias=False)
self.dropout = dropout
self.register_buffer("cluster_id", cluster_id)
self.K = K
def forward(self, x, edge_index):
if x.dim() != 2:
raise ValueError(f"Expected x to have shape [N, F], got {x.shape}")
x1 = F.relu(self.conv1(x, edge_index))
if x1.shape[1] != HIDDEN:
raise ValueError(f"Expected x1 to have shape [N, {HIDDEN}], got {x1.shape}")
x1 = F.dropout(x1, p=self.dropout, training=self.training)
gate = torch.tanh(self.score(x1)) # Remove .unsqueeze(-1)
if gate.shape != (x1.shape[0], 1):
raise ValueError(f"Expected gate to have shape [{x1.shape[0]}, 1], got {gate.shape}")
x1_g = x1 * gate
if x1_g.shape != x1.shape:
raise ValueError(f"Expected x1_g to have shape {x1.shape}, got {x1_g.shape}")
x_p, ei_p, ew_p = pool_by_partition_weighted(x1_g, edge_index, self.cluster_id, self.K)
x_p = self.conv2(x_p, ei_p, edge_weight=ew_p)
up = x_p[self.cluster_id]
skip = self.lin_skip(x1)
logits = up + skip
return logits, 0.0
class TopKPoolBroadcastGCN(nn.Module):
# gPool-style: learn scores, keep K, assign dropped to nearest kept by degree, weighted pooled GCN + skip.
def __init__(self, in_dim, hidden_dim, out_dim, K_target, dropout=0.5):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden_dim, add_self_loops=True, normalize=True)
self.conv2 = GCNConv(hidden_dim, out_dim, add_self_loops=True, normalize=True)
self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)
self.score = nn.Linear(hidden_dim, 1, bias=False)
self.dropout = dropout
self.K_target = K_target
@staticmethod
def _degrees(edge_index, N):
return torch.bincount(edge_index[0], minlength=N).to(torch.long)
def forward(self, x, edge_index):
N = x.size(0)
x1 = F.relu(self.conv1(x, edge_index))
x1 = F.dropout(x1, p=self.dropout, training=self.training)
raw = self.score(x1).squeeze(-1)
gate = torch.tanh(raw).unsqueeze(-1)
x1_g = x1 * gate
K = min(self.K_target, N)
kept = torch.topk(raw, K, sorted=True).indices
keep_mask = torch.zeros(N, dtype=torch.bool, device=x.device); keep_mask[kept] = True
deg = self._degrees(edge_index, N).to(x.device)
u_list, v_list = edge_index[0].tolist(), edge_index[1].tolist()
neigh = [[] for _ in range(N)]
for a, b in zip(u_list, v_list):
neigh[a].append(b); neigh[b].append(a)
cluster_id = torch.full((N,), -1, dtype=torch.long, device=x.device)
cluster_id[kept] = torch.arange(kept.numel(), device=x.device, dtype=torch.long)
best_global_kept = kept[torch.argmax(deg[kept])].item() if kept.numel() > 0 else 0
for u in range(N):
if keep_mask[u]:
continue
cand = [w for w in neigh[u] if keep_mask[w]]
cluster_id[u] = cluster_id[max(cand, key=lambda z: int(deg[z].item()))] if cand else cluster_id[best_global_kept]
Kc = int(cluster_id.max().item() + 1)
x_p, ei_p, ew_p = pool_by_partition_weighted(x1_g, edge_index, cluster_id, Kc)
x_p = self.conv2(x_p, ei_p, edge_weight=ew_p)
up = x_p[cluster_id]
skip = self.lin_skip(x1)
logits = up + skip
return logits, 0.0
class DiffPoolGCNNode(nn.Module):
# One DiffPool layer with K clusters, skip head to nodes.
def __init__(self, in_dim, hidden_dim, out_dim, K_clusters, dropout=0.5):
super().__init__()
self.dropout = dropout
self.K = K_clusters
self.gnn_embed1 = DenseGCNConv(in_dim, hidden_dim)
self.gnn_embed2 = DenseGCNConv(hidden_dim, hidden_dim)
self.gnn_assign1 = DenseGCNConv(in_dim, hidden_dim)
self.gnn_assign2 = DenseGCNConv(hidden_dim, K_clusters)
self.gnn_post1 = DenseGCNConv(hidden_dim, hidden_dim)
self.gnn_post2 = DenseGCNConv(hidden_dim, out_dim)
self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)
def forward(self, x, edge_index):
N, device = x.size(0), x.device
adj_dense = torch.zeros((N, N), device=device)
adj_dense[edge_index[0], edge_index[1]] = 1.0
idx = torch.arange(N, device=device)
adj_dense[idx, idx] = 1.0
x = x.unsqueeze(0) # [1, N, F]
adj = adj_dense.unsqueeze(0) # [1, N, N]
mask = torch.ones((1, N), device=device)
z = F.relu(self.gnn_embed1(x, adj, mask))
z = F.dropout(z, p=self.dropout, training=self.training)
z = F.relu(self.gnn_embed2(z, adj, mask))
s = F.relu(self.gnn_assign1(x, adj, mask))
s = F.dropout(s, p=self.dropout, training=self.training)
s = self.gnn_assign2(s, adj, mask).softmax(dim=-1) # [1, N, K]
x_pool, adj_pool, link_loss, ent_loss = dense_diff_pool(z, adj, s, mask)
h = F.relu(self.gnn_post1(x_pool, adj_pool))
h = F.dropout(h, p=self.dropout, training=self.training)
h = self.gnn_post2(h, adj_pool) # [1, K, C]
skip = self.lin_skip(z.squeeze(0)) # [N, C]
logits_nodes = torch.matmul(s.squeeze(0), h.squeeze(0)) + skip
aux_loss = link_loss + ent_loss
return logits_nodes, aux_loss
# ============ Train ============
def train_one(model, data, train_mask, val_mask, test_mask, device, aux_weight=0.0):
model = model.to(device)
data = data.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
best_state = None
best_val = -math.inf
bad = 0
for epoch in range(1, EPOCHS + 1):
model.train()
opt.zero_grad()
logits, aux_loss = model(data.x, data.edge_index)
loss = F.cross_entropy(logits[train_mask], data.y[train_mask])
if aux_weight > 0.0:
loss = loss + aux_weight * aux_loss
loss.backward()
opt.step()
model.eval()
with torch.no_grad():
logits, _ = model(data.x, data.edge_index)
val_metric = accuracy_from_logits(logits, data.y, val_mask)
if val_metric > best_val:
best_val = val_metric
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
bad = 0
else:
bad += 1
if bad >= PATIENCE:
break
if best_state is not None:
model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
model.eval()
with torch.no_grad():
logits, _ = model(data.x, data.edge_index)
test_acc = accuracy_from_logits(logits, data.y, test_mask)
test_f1 = macro_f1_from_logits(logits, data.y, test_mask)
return test_acc, test_f1
# ============ Sweep runner ============
def run_sweeps():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = load_cora_from_content_and_cites(CORA_CONTENT, CORA_CITES)
N = data.num_nodes
cluster_id_full, K_full = load_lrmc_partition(SEEDS_JSON, data.num_nodes)
print(f"Loaded Cora: N={data.num_nodes}, E={data.edge_index.size(1)}, F={data.num_features}, C={data.num_classes}")
print(f"L-RMC base K = {K_full} (K/N = {K_full/N:.3f})")
print("\nResults averaged over seeds:", SEEDS)
print("tpc, K/N, K, Method, acc_mean, acc_std, f1_mean, f1_std")
for tpc in LABEL_BUDGETS:
for ratio in K_RATIOS:
K_target = max(1, int(ratio * N))
accs = { "LRMC": [], "gPool": [], "DiffPool": [] }
f1s = { "LRMC": [], "gPool": [], "DiffPool": [] }
for s in SEEDS:
set_seed(s)
train_mask, val_mask, test_mask = make_planetoid_style_split(
data.y, data.num_classes, train_per_class=tpc, val_size=500, test_size=1000
)
# Equal K across methods
cid_eq, K_eq = compress_partition_to_K(cluster_id_full, K_target, data.edge_index)
# L-RMC
lrmc_model = LrmcSeededPoolGCN(
in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
cluster_id=cid_eq.to(data.x.device), K=K_eq, dropout=DROPOUT,
)
a, f = train_one(lrmc_model, data, train_mask, val_mask, test_mask, device)
accs["LRMC"].append(a); f1s["LRMC"].append(f)
# gPool
g_model = TopKPoolBroadcastGCN(
in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
K_target=K_eq, dropout=DROPOUT,
)
a, f = train_one(g_model, data, train_mask, val_mask, test_mask, device)
accs["gPool"].append(a); f1s["gPool"].append(f)
# DiffPool
d_model = DiffPoolGCNNode(
in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
K_clusters=K_eq, dropout=0.3, # a little lower dropout helps DiffPool
)
a, f = train_one(d_model, data, train_mask, val_mask, test_mask, device,
aux_weight=DIFFPOOL_AUX_WEIGHT)
accs["DiffPool"].append(a); f1s["DiffPool"].append(f)
def ms(x): # mean, std
return mean(x), (0.0 if len(x) < 2 else pstdev(x))
for name in ["LRMC", "gPool", "DiffPool"]:
am, asd = ms(accs[name])
fm, fsd = ms(f1s[name])
print(f"{tpc:3d}, {ratio:0.2f}, {K_eq:4d}, {name:7s}, "
f"{am:.3f}, {asd:.3f}, {fm:.3f}, {fsd:.3f}")
if __name__ == "__main__":
run_sweeps()