# compute cgscore for gcn - Final Optimized Complete Version
# NOTE: AMP (autocast) introduces a small precision loss, but it is minor.
import torch
import numpy as np
import torch.multiprocessing as mp
import argparse
from tqdm import tqdm
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from torch.cuda.amp import autocast
from deeprobust.graph import utils

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15)
parser.add_argument('--dataset', type=str, default='pubmed',
                    choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'])
parser.add_argument('--ptb_rate', type=float, default=0.05)
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Clean dataset plus its pre-computed meta-attack perturbed adjacency.
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels

perturbed_data = PrePtbDataset(root='/tmp/', name=args.dataset,
                               attack_method='meta', ptb_rate=args.ptb_rate)
# Symmetrize the perturbed adjacency matrix.
perturbed_adj = (perturbed_data.adj + perturbed_data.adj.T) / 2


def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Persist a CG-score array to *filename* as a .npy file."""
    np.save(filename, cg_scores)
    # BUGFIX: message previously printed a literal placeholder instead of
    # interpolating the actual filename.
    print(f"CG-scores saved to {filename}")


def calc_cg_score_gnn_with_sampling(A, X, labels, device, rep_num=1,
                                    unbalance_ratio=1, batch_size=1024,
                                    node_filter=None):
    """Compute per-edge CG-scores for a GCN via kernel-regression error deltas.

    For each existing edge (i, j) whose endpoint i lies in ``node_filter``,
    the edge is virtually removed from the aggregated features AX and the
    change in the kernel-ridge fitting error (against a one-vs-rest label
    vector y) is recorded as the edge's score.

    Args:
        A: dense [N, N] (normalized) adjacency tensor.
        X: dense [N, F] node-feature tensor.
        labels: [N] integer label tensor.
        device: torch device to run on.
        rep_num: number of sampling repetitions to average over.
        unbalance_ratio: negative:positive sampling ratio per label group.
        batch_size: number of edges whose perturbed features are computed at once.
        node_filter: optional 1-D tensor of node ids this worker is responsible
            for; defaults to all nodes.

    Returns:
        dict with "vi" (averaged [N, N] score matrix) and "times"
        (per-entry visit counts).
    """
    N = A.shape[0]
    cg_scores = {"vi": np.zeros((N, N)), "times": np.zeros((N, N))}
    A, X, labels = A.to(device), X.to(device), labels.to(device)

    @torch.no_grad()
    def normalize(tensor):
        # Row-wise L2 normalization; epsilon guards against zero rows.
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    # BUGFIX: convert node_filter to a set ONCE, outside the repetition loop.
    # It was previously converted inside the loop, so a second repetition
    # called .tolist() on a set and raised AttributeError.
    if node_filter is not None:
        node_filter = set(node_filter.tolist())
    else:
        node_filter = set(range(labels.size(0)))

    for _ in range(rep_num):
        AX = torch.matmul(A, X)
        norm_AX = normalize(AX)
        unique_labels = torch.unique(labels)
        label_to_indices = {label.item(): (labels == label).nonzero(as_tuple=True)[0]
                            for label in unique_labels}
        dataset = {label: norm_AX[idx] for label, idx in label_to_indices.items()}

        # Pre-compute, per label, the complement (negative) samples/indices.
        neg_samples_dict = {}
        neg_indices_dict = {}
        for label in unique_labels:
            label = label.item()
            mask = labels != label
            neg_samples_dict[label] = norm_AX[mask]
            neg_indices_dict[label] = mask.nonzero(as_tuple=True)[0]

        for curr_label in tqdm(unique_labels, desc="Label groups",
                               position=device.index):
            label_id = int(curr_label)
            curr_samples = dataset[label_id]
            curr_indices = label_to_indices[label_id]
            curr_num = len(curr_samples)

            # Shuffle positives; subsample negatives at the requested ratio.
            chosen_curr_idx = torch.randperm(curr_num, device=device)
            pos_samples = curr_samples[chosen_curr_idx]
            pos_indices = curr_indices[chosen_curr_idx]

            neg_samples = neg_samples_dict[label_id]
            neg_indices = neg_indices_dict[label_id]
            neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
            rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
            neg_samples = neg_samples[rand_idx]
            neg_indices = neg_indices[rand_idx]

            sample_idx = pos_indices.tolist() + neg_indices.tolist()
            # Node id -> row position, built once per label group (replaces an
            # O(M) list.index() search per edge in the hot loop below).
            pos_of_node = {node: pos for pos, node in enumerate(sample_idx)}
            sample_tensor = norm_AX[sample_idx]  # [M, F]

            # One-vs-rest targets: +1 for the current label, -1 otherwise.
            y = torch.cat([
                torch.ones(len(pos_samples)),
                -torch.ones(len(neg_samples))
            ], dim=0).to(device)

            with autocast():
                # Arc-cosine (NTK-style) kernel with a small ridge term for
                # numerical stability of the solve.
                H_inner = torch.matmul(sample_tensor, sample_tensor.T).clamp(-1.0, 1.0)
                H_base = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                H_base.fill_diagonal_(0.5)
                H_base += 1e-6 * torch.eye(H_base.size(0), device=device)
                ref_error = torch.dot(y, torch.linalg.solve(H_base, y))

            # Upper-triangular existing edges whose endpoint i this worker owns.
            edge_batch = [(i.item(), j)
                          for i in pos_indices if i.item() in node_filter
                          for j in range(i.item() + 1, N) if A[i, j] != 0]

            for k in tqdm(range(0, len(edge_batch), batch_size),
                          desc="Edge batches", leave=False,
                          position=device.index):
                batch = edge_batch[k:k + batch_size]
                if not batch:
                    continue
                i_idx, j_idx = zip(*batch)
                i_idx = torch.tensor(i_idx, device=device)
                j_idx = torch.tensor(j_idx, device=device)

                # Aggregated features after virtually removing edge (i, j)
                # from each endpoint's neighborhood sum.
                AX1_i = AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx]
                AX1_j = AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx]
                norm_AX1_i = normalize(AX1_i)
                norm_AX1_j = normalize(AX1_j)

                for b, (i, j) in enumerate(batch):
                    sample_tensor_copy = sample_tensor.clone()
                    # Substitute perturbed rows only for endpoints that were
                    # actually sampled into this label group's kernel.
                    if i in pos_of_node:
                        sample_tensor_copy[pos_of_node[i]] = norm_AX1_i[b]
                    if j in pos_of_node:
                        sample_tensor_copy[pos_of_node[j]] = norm_AX1_j[b]

                    with autocast():
                        H_inner = torch.matmul(sample_tensor_copy,
                                               sample_tensor_copy.T).clamp(-1.0, 1.0)
                        H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                        H.fill_diagonal_(0.5)
                        H += 1e-6 * torch.eye(H.size(0), device=device)
                        err_new = torch.dot(y, torch.linalg.solve(H, y))

                    score = (ref_error - err_new).item()
                    # BUGFIX: accumulate symmetrically. The (j, i) entry
                    # previously used plain assignment, overwriting earlier
                    # repetitions' contributions instead of adding to them.
                    cg_scores["vi"][i, j] += score
                    cg_scores["vi"][j, i] += score
                    cg_scores["times"][i, j] += 1
                    cg_scores["times"][j, i] += 1

    # Average accumulated scores by visit count; untouched entries keep a
    # divisor of 1 to avoid division by zero.
    for key in ["vi"]:
        cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0,
                                                   cg_scores["times"], 1)
    return cg_scores


def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio,
               batch_size, return_dict):
    """Per-GPU worker: score the node shard assigned to this rank.

    Nodes are partitioned by id (rather than by label) so every worker gets a
    roughly equal share of edges regardless of class sizes.
    """
    device = torch.device(f"cuda:{gpu_id}")
    node_ids = torch.arange(labels.size(0))
    node_chunks = np.array_split(node_ids.numpy(), world_size)
    node_filter = torch.tensor(node_chunks[gpu_id], device=device)
    result = calc_cg_score_gnn_with_sampling(
        A, X, labels, device,
        rep_num=rep_num,
        unbalance_ratio=unbalance_ratio,
        batch_size=batch_size,
        node_filter=node_filter
    )
    return_dict[gpu_id] = result


def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1,
                      batch_size=1024, gpu_ids=None):
    """Fan CG-score computation out over multiple GPUs and merge the results.

    Each worker owns a disjoint node shard, so summing the per-worker score
    matrices yields the full (already averaged) score matrix.
    """
    if gpu_ids is None:
        gpu_ids = list(range(torch.cuda.device_count()))
    world_size = len(gpu_ids)

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    return_dict = manager.dict()
    processes = []
    for i, gpu_id in enumerate(gpu_ids):
        p = mp.Process(target=run_worker,
                       args=(gpu_id, world_size, A, X, labels, rep_num,
                             unbalance_ratio, batch_size, return_dict))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Element-wise sum across workers; shards are disjoint so no entry is
    # contributed by more than one worker.
    final_score = None
    for res in return_dict.values():
        if final_score is None:
            final_score = {k: np.copy(v) for k, v in res.items()}
        else:
            for k in res:
                final_score[k] += res[k]
    return final_score


if __name__ == "__main__":
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj,
                                                      labels)
    features = features.to_dense()
    # BUGFIX: the dense branch previously skipped normalization entirely and
    # .to_dense() was called unconditionally (invalid on a dense tensor).
    if utils.is_sparse_tensor(perturbed_adj):
        perturbed_adj = utils.normalize_adj_tensor(perturbed_adj,
                                                   sparse=True).to_dense()
    else:
        perturbed_adj = utils.normalize_adj_tensor(perturbed_adj)

    selected_gpus = [0, 1, 2, 3]
    cg_scores = multi_gpu_wrapper(perturbed_adj, features, labels,
                                  rep_num=1, unbalance_ratio=3,
                                  batch_size=40280, gpu_ids=selected_gpus)
    save_cg_scores(cg_scores["vi"],
                   filename=f"{args.dataset}_{args.ptb_rate}.npy")
    print("🎉 CG-score computation completed.")