|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import math |
|
|
import random |
|
|
from pathlib import Path |
|
|
from statistics import mean, pstdev |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from torch_geometric.data import Data |
|
|
from torch_geometric.nn import GCNConv, DenseGCNConv |
|
|
from torch_geometric.nn.dense import dense_diff_pool |
|
|
|
|
|
from rich import print |
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Input files (resolved relative to the current working directory by Path).
# ---------------------------------------------------------------------------
SEEDS_JSON = "../seeds_diam_1e-6.json"   # L-RMC seed clusters: JSON with a "clusters" list
CORA_CONTENT = "../cora/cora.content"    # per-paper features and class label
CORA_CITES = "../cora/cora.cites"        # citation pairs (two paper ids per line)

# ---------------------------------------------------------------------------
# Sweep grid.
# ---------------------------------------------------------------------------
LABEL_BUDGETS = [20, 10, 5, 3]           # labelled training nodes per class
K_RATIOS = [0.10, 0.20, 0.40, 0.80]      # target cluster count as a fraction of N
SEEDS = list(range(5))                   # one split + model init per seed

# ---------------------------------------------------------------------------
# Model / optimiser hyper-parameters.
# ---------------------------------------------------------------------------
HIDDEN = 64
DROPOUT = 0.5                            # used by the LRMC and gPool models
LR = 1e-2
WEIGHT_DECAY = 0.0005
EPOCHS = 300
PATIENCE = 50                            # epochs without val improvement before stopping

DIFFPOOL_AUX_WEIGHT = 0.01               # weight of DiffPool's link + entropy losses
|
|
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's CPU and CUDA generators with ``seed``."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
|
|
|
|
def to_undirected(edge_index, num_nodes):
    """Return a symmetric, de-duplicated, self-loop-free copy of ``edge_index``.

    Every input edge ``(u, v)`` with ``u != v`` appears in the output exactly
    once in each direction, ``(u, v)`` and ``(v, u)``. Message-passing layers
    aggregate along ``edge_index[0] -> edge_index[1]``, so they need both
    directions to treat the graph as undirected. Output columns are sorted
    lexicographically by (source, target).

    Args:
        edge_index: ``[2, E]`` integer tensor of (source, target) node indices.
        num_nodes: number of nodes; when not ``None`` it is used to validate
            that all indices lie in ``[0, num_nodes)``.

    Returns:
        ``[2, E']`` long tensor on the same device as ``edge_index``
        (``[2, 0]`` when no non-self-loop edges remain).

    Raises:
        ValueError: if an index falls outside ``[0, num_nodes)``.
    """
    edge_index = edge_index.to(torch.long)
    if num_nodes is not None and edge_index.numel() > 0:
        lo_idx, hi_idx = int(edge_index.min()), int(edge_index.max())
        if lo_idx < 0 or hi_idx >= num_nodes:
            raise ValueError(
                f"edge_index values must lie in [0, {num_nodes}), got [{lo_idx}, {hi_idx}]"
            )

    row, col = edge_index[0], edge_index[1]
    off_diag = row != col  # self-loops are dropped; GCN layers add their own
    row, col = row[off_diag], col[off_diag]
    if row.numel() == 0:
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    # Canonicalise each edge as (min, max), add the reverse direction, then let
    # torch.unique both de-duplicate and sort the columns.
    lo, hi = torch.minimum(row, col), torch.maximum(row, col)
    both = torch.cat([torch.stack([lo, hi]), torch.stack([hi, lo])], dim=1)
    return torch.unique(both, dim=1).contiguous()
|
|
|
|
|
def macro_f1_from_logits(logits, y, mask):
    """Return the macro-averaged F1 of ``logits.argmax(1)`` against ``y`` on ``mask``.

    Per-class F1 is averaged only over classes that occur at least once in
    ``y[mask]``. Classes with no true samples are excluded from the mean.

    Args:
        logits: ``[N, C]`` class scores.
        y: ``[N]`` integer labels.
        mask: selector over the N nodes (e.g. a boolean mask).

    Returns:
        A Python float; ``0.0`` when the mask selects no nodes.
    """
    with torch.no_grad():
        pred = logits.argmax(dim=1)
        y_ = y[mask].to(device=logits.device, dtype=torch.long)
        p_ = pred[mask]

        # Cover both the label range and every class the model can predict, so
        # a prediction above max(y) cannot index past the confusion matrix.
        C = logits.size(1)
        if y.numel() > 0:
            C = max(C, int(y.max().item()) + 1)

        # Vectorised confusion matrix: rows = true class, cols = predicted class.
        cm = torch.bincount(y_ * C + p_, minlength=C * C).view(C, C)

        eps = 1e-12
        tp = cm.diag().to(torch.float)
        fp = cm.sum(dim=0).to(torch.float) - tp
        fn = cm.sum(dim=1).to(torch.float) - tp
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)

        present = cm.sum(dim=1) > 0
        return f1[present].mean().item() if present.any() else 0.0
|
|
|
|
|
def accuracy_from_logits(logits, y, mask):
    """Fraction of masked nodes whose argmax prediction equals the label.

    Returns ``0.0`` when the mask selects nothing (the denominator is clamped to 1).
    """
    with torch.no_grad():
        hits = logits.argmax(dim=1)[mask].eq(y[mask])
        num_hits = hits.sum().item()
        denominator = max(int(mask.sum().item()), 1)
        return num_hits / denominator
|
|
|
|
|
|
|
|
def load_cora_from_content_and_cites(content_path: str, cites_path: str):
    """Load Cora from the raw ``cora.content`` / ``cora.cites`` files into a ``Data``.

    ``content_path`` lines are parsed as
    ``<paper_id> <integer features ...> <class_label>``. ``cites_path`` lines are
    parsed as two whitespace-separated paper ids. Blank lines are ignored.
    Citations that reference an unknown paper id are skipped. Class labels are
    mapped to indices in sorted label order.

    Returns:
        ``Data`` with ``x`` (float ``[N, F]``), ``y`` (long ``[N]``), an
        undirected ``edge_index`` (via ``to_undirected``), and the extra
        attributes ``num_nodes`` and ``num_classes``.

    Raises:
        RuntimeError: if no citation connects two known papers.
    """
    content_lines = [
        line for line in Path(content_path).read_text(encoding="utf-8").splitlines() if line.strip()
    ]
    n = len(content_lines)

    paper_ids, features, labels_raw = [], [], []
    for line in content_lines:
        toks = line.split()
        paper_ids.append(toks[0])
        labels_raw.append(toks[-1])
        features.append([int(v) for v in toks[1:-1]])

    classes = sorted(set(labels_raw))
    cls2idx = {c: i for i, c in enumerate(classes)}
    y = torch.tensor([cls2idx[c] for c in labels_raw], dtype=torch.long)
    x = torch.tensor(features, dtype=torch.float)

    id2idx = {pid: i for i, pid in enumerate(paper_ids)}
    edges = []
    for line in Path(cites_path).read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        a, b = line.split()
        if a in id2idx and b in id2idx:
            edges.append((id2idx[a], id2idx[b]))
    if not edges:
        raise RuntimeError("No edges from cites file.")

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_index = to_undirected(edge_index, n)

    data = Data(x=x, edge_index=edge_index, y=y)
    data.num_nodes = n
    data.num_classes = len(classes)
    return data
|
|
|
|
|
def make_planetoid_style_split(y, num_classes, train_per_class=20, val_size=500, test_size=1000):
    """Draw a random Planetoid-style split and return it as three boolean node masks.

    For each class ``0..num_classes-1`` present in ``y``, up to
    ``train_per_class`` nodes are sampled for training. The remaining nodes are
    shuffled; the first ``val_size`` become validation and the next
    ``test_size`` become test (fewer if not enough nodes remain). All
    randomness comes from torch's global RNG, so call ``set_seed`` first for
    reproducible splits.

    Returns:
        ``(train_mask, val_mask, test_mask)`` as ``[N]`` bool tensors.
    """
    num_nodes = y.size(0)
    node_ids = torch.arange(num_nodes)
    train_mask, val_mask, test_mask = (torch.zeros(num_nodes, dtype=torch.bool) for _ in range(3))

    for label in range(num_classes):
        members = node_ids[y == label]
        count = members.numel()
        if count == 0:
            continue
        shuffled = members[torch.randperm(count)]
        train_mask[shuffled[: min(train_per_class, count)]] = True

    pool = node_ids[~train_mask]
    pool = pool[torch.randperm(pool.numel())]
    n_val = min(val_size, pool.numel())
    val_mask[pool[:n_val]] = True

    leftover = pool[n_val:]
    test_mask[leftover[: min(test_size, leftover.numel())]] = True
    return train_mask, val_mask, test_mask
|
|
|
|
|
|
|
|
def load_lrmc_partition(path: str, num_nodes: int):
    """Read an L-RMC seed JSON file and return a per-node cluster assignment.

    The file must contain ``{"clusters": [{"cluster_id": ..., "members": [...]}, ...]}``.
    Member ids outside ``[0, num_nodes)`` are ignored. If a node is listed in
    several clusters, the last listing wins.

    Returns:
        ``(cluster_id, K)``: ``[num_nodes]`` long tensor and ``max(cluster_id) + 1``.

    Raises:
        RuntimeError: if any node is left without a (non-negative) cluster id.
    """
    payload = json.loads(Path(path).read_text())
    assignment = torch.full((num_nodes,), -1, dtype=torch.long)

    for cluster in payload["clusters"]:
        label = int(cluster["cluster_id"])
        for member in cluster["members"]:
            node = int(member)
            if 0 <= node < num_nodes:
                assignment[node] = label

    uncovered = int((assignment < 0).sum().item())
    if uncovered:
        raise RuntimeError(f"{uncovered} nodes not covered by seeds.")

    return assignment, int(assignment.max().item() + 1)
|
|
|
|
|
def pool_by_partition_weighted(x, edge_index, cluster_id, K):
    """Mean-pool node features by cluster and build the weighted inter-cluster graph.

    Args:
        x: ``[N, F]`` node features.
        edge_index: ``[2, E]`` node-level edges.
        cluster_id: ``[N]`` long tensor with values in ``[0, K)``.
        K: number of clusters.

    Returns:
        ``(x_pooled, edge_index_pooled, edge_weight)``:

        - ``x_pooled``: ``[K, F]`` mean of member features (empty clusters stay zero).
        - ``edge_index_pooled``: one column per distinct directed cluster pair
          ``(a, b)`` with ``a != b`` induced by ``edge_index``.
        - ``edge_weight``: number of node-level edges mapping onto each pair,
          in ``x``'s dtype.

    Raises:
        ValueError: on a malformed ``x``/``cluster_id`` shape or a cluster id outside ``[0, K)``.
    """
    if x.dim() != 2:
        raise ValueError(f"Expected x to have shape [N, F], got {x.shape}")
    if cluster_id.shape != (x.shape[0],):
        raise ValueError(f"Expected cluster_id to have shape [{x.shape[0]}], got {cluster_id.shape}")
    if cluster_id.numel() > 0:
        cmin, cmax = int(cluster_id.min()), int(cluster_id.max())
        # Out-of-range ids would otherwise surface as an opaque index_add_/bincount failure.
        if cmin < 0 or cmax >= K:
            raise ValueError(f"cluster_id values must lie in [0, {K}), got [{cmin}, {cmax}]")

    cluster_id = cluster_id.to(x.device)
    sums = torch.zeros((K, x.size(1)), device=x.device, dtype=x.dtype)
    sums.index_add_(0, cluster_id, x)
    # clamp_min(1) keeps empty clusters at zero instead of dividing by zero.
    counts = torch.bincount(cluster_id, minlength=K).clamp_min(1).unsqueeze(1).to(x.dtype)
    x_pooled = sums / counts

    if edge_index.numel() == 0:
        empty_ei = torch.empty((2, 0), dtype=torch.long, device=x.device)
        return x_pooled, empty_ei, torch.empty(0, dtype=x.dtype, device=x.device)

    edge_index = edge_index.to(cluster_id.device)
    pairs = torch.stack([cluster_id[edge_index[0]], cluster_id[edge_index[1]]], dim=1)
    uniq, w = torch.unique(pairs, dim=0, return_counts=True)
    inter = uniq[:, 0] != uniq[:, 1]  # intra-cluster edges are absorbed by pooling
    edge_index_pooled = uniq[inter].t().contiguous()
    edge_weight = w[inter].to(x.dtype)
    return x_pooled, edge_index_pooled, edge_weight
|
|
|
|
|
def compress_partition_to_K(cluster_id, K_target, edge_index):
    """Coarsen a node partition to at most ``K_target`` clusters.

    The ``K_target`` largest clusters (by member count) are kept. Every other
    cluster is merged into the kept cluster it shares the most edges with; ties
    go to the higher kept id. A cluster with no edge to any kept cluster is
    merged into the largest kept cluster. Surviving ids are finally relabelled
    to ``0..K-1`` in increasing order of their original id.

    Args:
        cluster_id: ``[N]`` long tensor of non-negative cluster ids.
        K_target: maximum number of clusters to keep (must be >= 1).
        edge_index: ``[2, E]`` node-level edges; direction is ignored.

    Returns:
        ``(cid, K)``: a new assignment tensor and its number of clusters. If the
        input already has ``max(cluster_id) + 1 <= K_target`` ids, a clone is
        returned unchanged.

    Raises:
        ValueError: if ``K_target < 1`` while compression is needed.
    """
    cid = cluster_id.clone()
    K_now = int(cid.max().item() + 1)
    if K_now <= K_target:
        return cid, K_now
    if K_target < 1:
        raise ValueError(f"K_target must be >= 1, got {K_target}")

    sizes = torch.bincount(cid, minlength=K_now)
    kept = set(int(k) for k in torch.topk(sizes, K_target).indices.tolist())
    largest_kept = max(sorted(kept), key=lambda k: int(sizes[k]))

    # Symmetrised cross-cluster edge counts; edge_index may live on another device.
    edge_index = edge_index.to(cid.device)
    w = {}
    for a, b in zip(cid[edge_index[0]].tolist(), cid[edge_index[1]].tolist()):
        if a == b:
            continue
        w[(a, b)] = w.get((a, b), 0) + 1
        w[(b, a)] = w.get((b, a), 0) + 1

    # Best kept neighbour per dropped cluster, as (edge count, kept id) so that
    # tuple comparison breaks count ties toward the higher kept id.
    best = {}
    for (a, b), count in w.items():
        if a in kept or b not in kept:
            continue
        candidate = (count, b)
        if a not in best or candidate > best[a]:
            best[a] = candidate

    mapping = [c if c in kept else best.get(c, (0, largest_kept))[1] for c in range(K_now)]
    lookup = torch.tensor(mapping, dtype=torch.long, device=cid.device)
    merged = lookup[cid]

    # Compact the surviving ids to 0..K-1 (sorted by original id).
    uniq, compact = torch.unique(merged, sorted=True, return_inverse=True)
    return compact, int(uniq.numel())
|
|
|
|
|
|
|
|
class LrmcSeededPoolGCN(nn.Module):
    """Two-layer GCN that pools hidden features over a fixed (L-RMC) node partition.

    ``conv1`` runs on the node graph. Hidden features are gated by
    ``tanh(score(x1))``, mean-pooled per cluster and passed through ``conv2``
    on the weighted cluster graph. The result is broadcast back to member
    nodes and added to a linear skip from the hidden features.
    ``forward`` returns ``(logits, 0.0)``, where ``0.0`` is the auxiliary-loss slot.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, cluster_id, K, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim, add_self_loops=True, normalize=True)
        self.conv2 = GCNConv(hidden_dim, out_dim, add_self_loops=True, normalize=True)
        self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)
        self.score = nn.Linear(hidden_dim, 1, bias=False)
        self.dropout = dropout
        self.hidden_dim = hidden_dim
        # Buffer, so the partition follows the module across .to(device) and state_dict.
        self.register_buffer("cluster_id", cluster_id)
        self.K = K

    def forward(self, x, edge_index):
        if x.dim() != 2:
            raise ValueError(f"Expected x to have shape [N, F], got {x.shape}")
        x1 = F.relu(self.conv1(x, edge_index))
        # Check against this module's configured width, not a global constant.
        if x1.shape[1] != self.hidden_dim:
            raise ValueError(f"Expected x1 to have shape [N, {self.hidden_dim}], got {x1.shape}")
        x1 = F.dropout(x1, p=self.dropout, training=self.training)

        gate = torch.tanh(self.score(x1))
        if gate.shape != (x1.shape[0], 1):
            raise ValueError(f"Expected gate to have shape [{x1.shape[0]}, 1], got {gate.shape}")
        x1_g = x1 * gate
        if x1_g.shape != x1.shape:
            raise ValueError(f"Expected x1_g to have shape {x1.shape}, got {x1_g.shape}")

        x_p, ei_p, ew_p = pool_by_partition_weighted(x1_g, edge_index, self.cluster_id, self.K)
        x_p = self.conv2(x_p, ei_p, edge_weight=ew_p)

        # Unpool: each node receives its cluster's output, plus a node-level skip.
        up = x_p[self.cluster_id]
        skip = self.lin_skip(x1)
        logits = up + skip
        return logits, 0.0
|
|
|
|
|
class TopKPoolBroadcastGCN(nn.Module):
    """gPool-style baseline: keep the top-K scored nodes as cluster centres, pool, broadcast back.

    On every forward pass, nodes are scored from their (dropout-perturbed)
    hidden features. The ``K_target`` best become cluster centres. Every other
    node joins the cluster of its highest-degree kept neighbour, or of the
    highest-degree kept node overall when it has no kept neighbour. Pooling,
    ``conv2`` and broadcasting then mirror ``LrmcSeededPoolGCN``.
    ``forward`` returns ``(logits, 0.0)``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, K_target, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim, add_self_loops=True, normalize=True)
        self.conv2 = GCNConv(hidden_dim, out_dim, add_self_loops=True, normalize=True)
        self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)
        self.score = nn.Linear(hidden_dim, 1, bias=False)
        self.dropout = dropout
        self.K_target = K_target

    @staticmethod
    def _degrees(edge_index, N):
        """Count how often each node appears in ``edge_index[0]``."""
        return torch.bincount(edge_index[0], minlength=N).to(torch.long)

    @classmethod
    def _assign_clusters(cls, edge_index, kept, N):
        """Map every node to the index (in ``kept`` order) of its anchor kept node.

        Works on plain Python lists so no per-node device synchronisation happens.
        Ties in degree are broken by first occurrence, as with ``max``.
        """
        deg = cls._degrees(edge_index, N).tolist()
        cluster_of = [-1] * N
        is_kept = [False] * N
        for position, node in enumerate(kept):
            cluster_of[node] = position
            is_kept[node] = True

        neigh = [[] for _ in range(N)]
        for a, b in zip(edge_index[0].tolist(), edge_index[1].tolist()):
            neigh[a].append(b)
            neigh[b].append(a)

        best_global_kept = max(kept, key=deg.__getitem__) if kept else 0
        for u in range(N):
            if is_kept[u]:
                continue
            cand = [w for w in neigh[u] if is_kept[w]]
            anchor = max(cand, key=deg.__getitem__) if cand else best_global_kept
            cluster_of[u] = cluster_of[anchor]
        return torch.tensor(cluster_of, dtype=torch.long)

    def forward(self, x, edge_index):
        N = x.size(0)
        x1 = F.relu(self.conv1(x, edge_index))
        x1 = F.dropout(x1, p=self.dropout, training=self.training)

        raw = self.score(x1).squeeze(-1)
        gate = torch.tanh(raw).unsqueeze(-1)
        x1_g = x1 * gate

        K = min(self.K_target, N)
        kept = torch.topk(raw, K, sorted=True).indices
        cluster_id = self._assign_clusters(edge_index, kept.tolist(), N).to(x.device)
        Kc = int(cluster_id.max().item() + 1)

        x_p, ei_p, ew_p = pool_by_partition_weighted(x1_g, edge_index, cluster_id, Kc)
        x_p = self.conv2(x_p, ei_p, edge_weight=ew_p)
        up = x_p[cluster_id]
        skip = self.lin_skip(x1)
        logits = up + skip
        return logits, 0.0
|
|
|
|
|
class DiffPoolGCNNode(nn.Module):
    """Node-level DiffPool baseline: one soft-clustering step followed by unpooling.

    Dense GCN layers compute node embeddings ``z`` and assignment logits with
    ``K_clusters`` columns. ``dense_diff_pool`` coarsens the graph, and two
    dense GCN layers run on the pooled graph. Node logits are
    ``softmax(assignment) @ h`` plus a linear skip from ``z``.
    ``forward`` returns ``(logits, link_loss + ent_loss)``.

    An ``N x N`` dense adjacency is built on every forward, so memory is O(N^2).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, K_clusters, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.K = K_clusters
        self.gnn_embed1 = DenseGCNConv(in_dim, hidden_dim)
        self.gnn_embed2 = DenseGCNConv(hidden_dim, hidden_dim)
        self.gnn_assign1 = DenseGCNConv(in_dim, hidden_dim)
        self.gnn_assign2 = DenseGCNConv(hidden_dim, K_clusters)
        self.gnn_post1 = DenseGCNConv(hidden_dim, hidden_dim)
        self.gnn_post2 = DenseGCNConv(hidden_dim, out_dim)
        self.lin_skip = nn.Linear(hidden_dim, out_dim, bias=True)

    def forward(self, x, edge_index):
        N, device = x.size(0), x.device
        # Dense 0/1 adjacency with an explicit diagonal.
        adj_dense = torch.zeros((N, N), device=device, dtype=x.dtype)
        adj_dense[edge_index[0], edge_index[1]] = 1.0
        idx = torch.arange(N, device=device)
        adj_dense[idx, idx] = 1.0

        # Dense layers take a leading batch dimension; here the batch is one graph.
        x = x.unsqueeze(0)
        adj = adj_dense.unsqueeze(0)
        mask = torch.ones((1, N), device=device)

        z = F.relu(self.gnn_embed1(x, adj, mask))
        z = F.dropout(z, p=self.dropout, training=self.training)
        z = F.relu(self.gnn_embed2(z, adj, mask))

        s = F.relu(self.gnn_assign1(x, adj, mask))
        s = F.dropout(s, p=self.dropout, training=self.training)
        s_logits = self.gnn_assign2(s, adj, mask)

        # dense_diff_pool applies softmax to its assignment argument itself, so it
        # must receive raw logits; pre-softmaxing would flatten the assignment twice.
        x_pool, adj_pool, link_loss, ent_loss = dense_diff_pool(z, adj, s_logits, mask)
        # Unpool with the same normalised assignment dense_diff_pool used internally.
        s = s_logits.softmax(dim=-1)

        h = F.relu(self.gnn_post1(x_pool, adj_pool))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.gnn_post2(h, adj_pool)

        skip = self.lin_skip(z.squeeze(0))
        logits_nodes = torch.matmul(s.squeeze(0), h.squeeze(0)) + skip
        aux_loss = link_loss + ent_loss
        return logits_nodes, aux_loss
|
|
|
|
|
|
|
|
def train_one(model, data, train_mask, val_mask, test_mask, device, aux_weight=0.0):
    """Train ``model`` with Adam and validation-based early stopping; return test metrics.

    Each epoch takes one full-batch step on the cross-entropy over
    ``train_mask``, adding ``aux_weight * aux_loss`` when ``aux_weight > 0``.
    It then measures validation accuracy in eval mode. The state with the best
    validation accuracy is snapshotted on CPU and restored before the final
    evaluation. Training stops after ``PATIENCE`` consecutive non-improving
    epochs or after ``EPOCHS`` epochs.

    ``model(x, edge_index)`` must return ``(logits, aux_loss)``.

    NOTE(review): ``data.to(device)`` is called on the caller's object; whether
    that moves the caller's tensors in place depends on the Data implementation.

    Returns:
        ``(test_accuracy, test_macro_f1)`` as Python floats.
    """
    model = model.to(device)
    data = data.to(device)
    train_mask, val_mask, test_mask = (m.to(device) for m in (train_mask, val_mask, test_mask))

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    best_val, best_state, stale_epochs = -math.inf, None, 0

    for _ in range(EPOCHS):
        # --- one optimisation step ---
        model.train()
        optimizer.zero_grad()
        out, aux = model(data.x, data.edge_index)
        objective = F.cross_entropy(out[train_mask], data.y[train_mask])
        if aux_weight > 0.0:
            objective = objective + aux_weight * aux
        objective.backward()
        optimizer.step()

        # --- validation ---
        model.eval()
        with torch.no_grad():
            out, _ = model(data.x, data.edge_index)
            val_acc = accuracy_from_logits(out, data.y, val_mask)

        if val_acc > best_val:
            best_val = val_acc
            best_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs >= PATIENCE:
            break

    if best_state is not None:
        model.load_state_dict({name: t.to(device) for name, t in best_state.items()})

    model.eval()
    with torch.no_grad():
        final_logits, _ = model(data.x, data.edge_index)
        test_acc = accuracy_from_logits(final_logits, data.y, test_mask)
        test_f1 = macro_f1_from_logits(final_logits, data.y, test_mask)
    return test_acc, test_f1
|
|
|
|
|
|
|
|
def run_sweeps():
    """Run the label-budget x pooling-ratio x seed sweep on Cora and print a CSV-style table.

    For each ``(train-per-class budget, K/N ratio)`` pair, the L-RMC partition
    is compressed to ``K_target = max(1, int(ratio * N))`` clusters. Three
    models are trained per seed with the same cluster budget ``K_eq``:
    LRMC-seeded pooling, a gPool-style top-K baseline and DiffPool. Mean and
    population std of test accuracy and macro-F1 over ``SEEDS`` are printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = load_cora_from_content_and_cites(CORA_CONTENT, CORA_CITES)
    N = data.num_nodes
    cluster_id_full, K_full = load_lrmc_partition(SEEDS_JSON, data.num_nodes)

    # train_one() calls data.to(device) on this shared object. If that moves the
    # tensors in place (as PyG Data.to appears to -- confirm for your version),
    # later CPU-side indexing in the split/partition helpers would mix devices on
    # CUDA. Give those helpers stable CPU copies instead.
    y_cpu = data.y.cpu()
    edge_index_cpu = data.edge_index.cpu()

    print(f"Loaded Cora: N={data.num_nodes}, E={data.edge_index.size(1)}, F={data.num_features}, C={data.num_classes}")
    print(f"L-RMC base K = {K_full} (K/N = {K_full/N:.3f})")
    print("\nResults averaged over seeds:", SEEDS)
    print("tpc, K/N, K, Method, acc_mean, acc_std, f1_mean, f1_std")

    def ms(x):
        # Mean and population std; the std of a single run is reported as 0.0.
        return mean(x), (0.0 if len(x) < 2 else pstdev(x))

    # The compressed partition depends only on K_target (no randomness), so build
    # it once per ratio instead of once per (budget, ratio, seed).
    partitions = {
        ratio: compress_partition_to_K(cluster_id_full, max(1, int(ratio * N)), edge_index_cpu)
        for ratio in K_RATIOS
    }
    methods = ("LRMC", "gPool", "DiffPool")

    for tpc in LABEL_BUDGETS:
        for ratio in K_RATIOS:
            cid_eq, K_eq = partitions[ratio]
            accs = {name: [] for name in methods}
            f1s = {name: [] for name in methods}

            for s in SEEDS:
                set_seed(s)
                train_mask, val_mask, test_mask = make_planetoid_style_split(
                    y_cpu, data.num_classes, train_per_class=tpc, val_size=500, test_size=1000
                )

                # Models are built and trained one after another so each draws
                # its initial weights from the RNG state left by the previous run.
                lrmc_model = LrmcSeededPoolGCN(
                    in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
                    cluster_id=cid_eq, K=K_eq, dropout=DROPOUT,
                )
                a, f = train_one(lrmc_model, data, train_mask, val_mask, test_mask, device)
                accs["LRMC"].append(a)
                f1s["LRMC"].append(f)

                g_model = TopKPoolBroadcastGCN(
                    in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
                    K_target=K_eq, dropout=DROPOUT,
                )
                a, f = train_one(g_model, data, train_mask, val_mask, test_mask, device)
                accs["gPool"].append(a)
                f1s["gPool"].append(f)

                d_model = DiffPoolGCNNode(
                    in_dim=data.num_features, hidden_dim=HIDDEN, out_dim=data.num_classes,
                    K_clusters=K_eq, dropout=0.3,
                )
                a, f = train_one(d_model, data, train_mask, val_mask, test_mask, device,
                                 aux_weight=DIFFPOOL_AUX_WEIGHT)
                accs["DiffPool"].append(a)
                f1s["DiffPool"].append(f)

            for name in methods:
                am, asd = ms(accs[name])
                fm, fsd = ms(f1s[name])
                print(f"{tpc:3d}, {ratio:0.2f}, {K_eq:4d}, {name:7s}, "
                      f"{am:.3f}, {asd:.3f}, {fm:.3f}, {fsd:.3f}")
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark sweep only when executed as a script, not when imported.
    run_sweeps()
|
|
|