# Source: CGSCORE / examples / graph / cgscore_datasets_multigpus.py
# Uploaded by Yaning1001 ("Add files using upload-large-folder tool"), commit c91d7b1 (verified).
# compute cgscore for gcn
# author: Yaning
import torch
import numpy as np
import torch.nn.functional as Fd
from deeprobust.graph.defense import GCNJaccard, GCN
from deeprobust.graph.defense import GCNScore
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from scipy.sparse import csr_matrix
import argparse
import pickle
from deeprobust.graph import utils
import torch.multiprocessing as mp
from collections import defaultdict
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
# NOTE(review): hard-codes the second GPU. Within this file, only the tensor branch
# of the __main__ block reads this module-level device.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# NOTE: with the "spawn" start method, every worker process re-imports this module,
# so all module-level code (argument parsing, seeding, dataset loading) runs again
# in each worker.
# Make sure you use the same data splits as when you generated the attacks.
np.random.seed(args.seed)
# torch.manual_seed seeds the CPU generator and every CUDA device's generator.
# torch.cuda.manual_seed alone seeds only the current device, which would leave
# workers running on other GPUs with torch.randperm streams not controlled by --seed.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)
# Here the random seed is used to split the train/val/test data; it must match the
# seed used when the perturbed graph was generated:
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
# Or use setting='prognn' to get the splits (as done below).
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
perturbed_data = PrePtbDataset(root='/tmp/',
                               name=args.dataset,
                               attack_method='meta',
                               ptb_rate=args.ptb_rate)
perturbed_adj = perturbed_data.adj
# To score the clean graph instead: perturbed_adj = adj
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Write ``cg_scores`` to ``filename`` in NumPy ``.npy`` format and report the path.

    ``np.save`` appends a ``.npy`` suffix when a string ``filename`` lacks one.
    """
    np.save(file=filename, arr=cg_scores)
    print("CG-scores saved to {}".format(filename))
def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Read back and return an array stored with ``np.save`` (e.g. by ``save_cg_scores``).

    ``allow_pickle=True`` lets object arrays round-trip. Unpickling can execute
    arbitrary code, so only load files from a trusted source.
    """
    loaded = np.load(filename, allow_pickle=True)
    print("CG-scores loaded from {}".format(filename))
    return loaded
def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, label_filter=None
):
    """Score every edge incident to the selected label groups by its effect on a kernel error.

    For each label group, the rows of ``normalize(A @ X)`` with that label are the
    positives (target +1). They are stacked with a random sample of rows from other
    labels, the negatives (target -1). The number of negatives is
    ``min(int(curr_num * unbalance_ratio), available)``. With ``K`` the pairwise
    cosine-similarity matrix of the stacked rows, the kernel is
    ``H = K * (pi - arccos K) / (2 * pi)``. Its diagonal is set to 0.5 plus a 1e-6
    ridge, and the baseline error is ``y^T H^{-1} y``.

    Each edge ``(i, j)`` with ``i`` in the group, ``j > i`` and ``A[i, j] != 0`` is
    then removed from rows ``i`` and ``j`` of ``A @ X``, and the error is recomputed.
    The edge's score is ``baseline - perturbed``.

    Args:
        A: adjacency tensor of shape (N, N). Assumed dense (the __main__ block passes
            ``to_dense()`` output); sparse tensors are not expected here.
        X: node features; ``A @ X`` must be defined, so presumably shape (N, d).
        labels: per-node label tensor used with boolean masks of length N.
        device: torch device to compute on.
        rep_num: number of independent resampling rounds; scores are averaged.
        unbalance_ratio: negatives drawn per positive within each group.
        sub_term: unused; kept for interface compatibility.
        batch_size: number of edges whose perturbed kernels are inverted together.
            Memory for the batched kernels grows as batch_size * S**2, where S is
            the number of sampled rows.
        label_filter: optional collection of label ids; only those groups are processed.

    Returns:
        dict of N x N numpy arrays:
        - ``"vi"``: averaged edge scores, filled symmetrically.
        - ``"times"``: how often each entry was scored.
        - ``"ab"``, ``"a2"``, ``"b2"``: allocated but never filled here, so they
          stay zero.
    """
    N = A.shape[0]
    cg_scores = {
        "vi": np.zeros((N, N)),
        "ab": np.zeros((N, N)),
        "a2": np.zeros((N, N)),
        "b2": np.zeros((N, N)),
        "times": np.zeros((N, N)),
    }
    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    def normalize(tensor):
        # Row-wise L2 normalisation; the epsilon keeps all-zero rows finite.
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    def kernel(inner):
        # Build H from cosine similarities. Works on (S, S) and batched (B, S, S) inputs.
        # The clamp keeps acos inside its domain despite rounding.
        inner = torch.clamp(inner, min=-1.0, max=1.0)
        H = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
        # Fixed 0.5 diagonal plus a 1e-6 ridge so H is invertible. Baseline and
        # perturbed kernels must be regularised identically; otherwise the ridge
        # itself leaks into every score difference.
        diag = H.diagonal(dim1=-2, dim2=-1)
        diag.fill_(0.5)
        diag.add_(1e-6)
        return H

    with torch.no_grad():
        # Neither AX nor the label bookkeeping depends on the resampling round.
        AX = torch.matmul(A, X)
        norm_AX = normalize(AX)
        unique_labels = torch.unique(labels)
        if label_filter is not None:
            unique_labels = [label for label in unique_labels if label.item() in label_filter]
        label_to_indices = {
            label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
        }
        neg_indices_dict = {
            label_id: (labels != label_id).nonzero(as_tuple=True)[0] for label_id in label_to_indices
        }

        for _ in range(rep_num):
            for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
                label_id = int(curr_label)
                curr_indices = label_to_indices[label_id]
                curr_num = len(curr_indices)
                perm = torch.randperm(curr_num, device=device)
                chosen_curr_indices = curr_indices[perm]

                neg_indices = neg_indices_dict[label_id]
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
                rand_idx = torch.randperm(len(neg_indices), device=device)[:neg_num]
                chosen_neg_indices = neg_indices[rand_idx]

                # Candidate edges (i, j): i in this group, j > i, A[i, j] != 0.
                # They are ordered by the permuted i, then ascending j.
                edge_batch = []
                for idx_i in chosen_curr_indices.tolist():
                    cols = torch.nonzero(A[idx_i, idx_i + 1:], as_tuple=True)[0] + (idx_i + 1)
                    edge_batch.extend((idx_i, j) for j in cols.tolist())
                if not edge_batch:
                    continue

                sample_idx = chosen_curr_indices.tolist() + chosen_neg_indices.tolist()
                # The positive and negative sets have different labels, so each
                # node appears at most once in sample_idx.
                position = {node: p for p, node in enumerate(sample_idx)}
                combined_samples = norm_AX[sample_idx]
                y = torch.cat([torch.ones(curr_num), -torch.ones(neg_num)], dim=0).to(
                    device=device, dtype=combined_samples.dtype
                )
                base_H = kernel(torch.matmul(combined_samples, combined_samples.T))
                original_error = y @ (torch.inverse(base_H) @ y)

                for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
                    batch = edge_batch[k: k + batch_size]
                    B = len(batch)
                    # One private copy of the S sampled rows per edge. Rows of nodes
                    # outside the sample never enter the kernel, so they are not copied.
                    sample_batch = combined_samples.unsqueeze(0).repeat(B, 1, 1)
                    for b, (i, j) in enumerate(batch):
                        # Remove the edge's contribution from both endpoints' aggregates.
                        for u, v in ((i, j), (j, i)):
                            p = position.get(u)
                            if p is not None:
                                row = AX[u] - A[u, v] * X[v]
                                sample_batch[b, p] = row / (torch.norm(row) + 1e-8)
                    H = kernel(torch.matmul(sample_batch, sample_batch.transpose(1, 2)))
                    invH = torch.inverse(H)
                    y_expanded = y.unsqueeze(0).expand(B, -1)
                    error_A1 = torch.einsum("bi,bij,bj->b", y_expanded, invH, y_expanded)
                    scores = (original_error - error_A1).tolist()
                    for (i, j), score in zip(batch, scores):
                        # Accumulate both orientations so that averaging over
                        # rep_num rounds stays correct and the matrix stays symmetric.
                        cg_scores["vi"][i, j] += score
                        cg_scores["vi"][j, i] += score
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average accumulated values; entries never scored keep their zero.
    safe_times = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / safe_times
    return cg_scores
def split_labels_for_rank(unique_labels, world_size, rank):
    """Return the label ids (as ints) assigned to worker ``rank`` of ``world_size``.

    Labels are split into ``world_size`` contiguous, near-equal chunks with
    ``np.array_split``; worker ``rank`` takes chunk ``rank % world_size``.
    """
    label_chunks = np.array_split(unique_labels, world_size)
    return [int(label) for label in label_chunks[rank % world_size]]


def merge_rank_results(results):
    """Sum per-worker CG-score dicts key by key.

    Args:
        results: mapping of gpu_id to the dict returned by one worker.

    Returns:
        A new dict holding the element-wise sum of every valid worker's arrays, or
        None if no worker result is a dict. Non-dict results, unknown keys and
        non-array or incompatible values are reported and skipped.
    """
    final_score = None
    for gpu_id, rank_result in results.items():
        if not isinstance(rank_result, dict):
            print(f"[FATAL] GPU {gpu_id} result is not a dict: {type(rank_result)}")
            continue
        if final_score is None:
            # Copy so accumulation never mutates a worker's arrays.
            final_score = {k: np.copy(v) for k, v in rank_result.items()}
            continue
        for key, value in rank_result.items():
            if key not in final_score:
                print(f"[WARN] key '{key}' not in final_score. Skipping.")
                continue
            if not (isinstance(final_score[key], np.ndarray) and isinstance(value, np.ndarray)):
                print(f"[WARN] Skipped merging key '{key}' due to type mismatch.")
                continue
            try:
                final_score[key] += value
            except (ValueError, TypeError) as e:
                print(f"[ERROR] Failed merging key '{key}': {e}")
    return final_score


def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict,
               local_rank=None):
    """Compute CG-scores for this worker's share of labels on GPU ``gpu_id``.

    The result dict is stored in ``return_dict[gpu_id]``. ``local_rank`` (this
    worker's position in the launched GPU list) selects the label chunk. When it
    is omitted, ``gpu_id`` is used, as before. That is only correct when the GPU
    ids are ``0..world_size-1``.
    """
    device = torch.device(f"cuda:{gpu_id}")
    # Make this GPU current so library handles are created here, not on GPU 0.
    torch.cuda.set_device(device)
    rank = gpu_id if local_rank is None else local_rank
    unique_labels = torch.unique(labels).tolist()
    label_filter = split_labels_for_rank(unique_labels, world_size, rank)
    result = calc_cg_score_gnn_with_sampling(
        A, X, labels, device,
        rep_num=rep_num,
        unbalance_ratio=unbalance_ratio,
        sub_term=sub_term,
        batch_size=batch_size,
        label_filter=label_filter,
    )
    return_dict[gpu_id] = result


def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, gpu_ids=None):
    """Run ``calc_cg_score_gnn_with_sampling`` across GPUs, one process per GPU, and merge.

    Labels are partitioned across workers by their position in ``gpu_ids``, so
    every label is processed by exactly one worker. Each worker's summed arrays
    are combined with ``merge_rank_results``.

    Returns:
        The merged score dict, or None if no worker produced a result.
    """
    if gpu_ids is None:
        gpu_ids = list(range(torch.cuda.device_count()))
    world_size = len(gpu_ids)
    # The context manager shuts the manager's server process down on exit.
    with mp.Manager() as manager:
        return_dict = manager.dict()
        processes = []
        for local_rank, gpu_id in enumerate(gpu_ids):
            p = mp.Process(
                target=run_worker,
                args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size,
                      return_dict, local_rank),
            )
            p.start()
            processes.append(p)
        for gpu_id, p in zip(gpu_ids, processes):
            p.join()
            if p.exitcode != 0:
                print(f"[FATAL] Worker for GPU {gpu_id} exited with code {p.exitcode}; "
                      f"its labels are missing from the merged scores.")
        # Copy results out of the proxy before the manager shuts down.
        results = dict(return_dict.items())
    return merge_rank_results(results)
def is_symmetric_sparse(adj):
    """Return True when the sparse matrix ``adj`` equals its own transpose.

    Comparing two sparse matrices with ``!=`` yields a sparse matrix whose stored
    entries mark mismatches, so a symmetric input stores none of them.
    """
    mismatches = adj != adj.T
    return mismatches.nnz == 0
def make_symmetric_sparse(adj):
    """Return the symmetric part ``(adj + adj^T) / 2`` of a sparse matrix.

    Entries present in only one direction end up at half their weight in both
    directions.
    """
    summed = adj + adj.T
    return summed / 2
if __name__ == "__main__":
    # Workers initialise CUDA, which does not support the "fork" start method.
    mp.set_start_method("spawn", force=True)
    print("cuda:", torch.cuda.is_available())
    # GPUs to use; adjust to your machine. Ids beyond the visible device count are
    # dropped. Otherwise their worker would crash and its labels would be missing
    # from the merged scores.
    requested_gpus = [0, 1, 2, 3]
    num_visible = torch.cuda.device_count()
    selected_gpus = [gpu for gpu in requested_gpus if gpu < num_visible]
    if not selected_gpus:
        raise SystemExit(
            f"No usable GPU among {requested_gpus}: {num_visible} CUDA device(s) visible."
        )
    if len(selected_gpus) < len(requested_gpus):
        print(f"[WARN] Using GPUs {selected_gpus}; only {num_visible} CUDA device(s) visible.")
    # Symmetrise the sparse perturbed adjacency.
    perturbed_adj = make_symmetric_sparse(perturbed_adj)
    # Convert inputs to torch tensors.
    if not isinstance(perturbed_adj, torch.Tensor):
        features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
    else:
        features = features.to(device)
        perturbed_adj = perturbed_adj.to(device)
        labels = labels.to(device)
    # Normalise the adjacency (sparse or dense path as appropriate).
    if utils.is_sparse_tensor(perturbed_adj):
        adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
    else:
        adj_norm = utils.normalize_adj_tensor(perturbed_adj)
    features = features.to_dense()
    perturbed_adj = adj_norm.to_dense()
    # Compute CG-scores in parallel across the selected GPUs.
    calc_cg_score = multi_gpu_wrapper(
        perturbed_adj, features, labels,
        rep_num=1,
        unbalance_ratio=1,
        sub_term=False,
        batch_size=1024,
        gpu_ids=selected_gpus,
    )
    if calc_cg_score is None:
        raise SystemExit("No worker returned CG-scores; nothing was saved.")
    save_cg_scores(calc_cg_score["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
    print(" CG-score computation completed.")