CGSCORE / examples /graph /cgscore_datasets_multigpus2.py
Yaning1001's picture
Add files using upload-large-folder tool
c91d7b1 verified
# Compute CG-scores for a GCN - final optimized complete version
## Note: some precision is lost, but not much.
import torch
import numpy as np
import torch.multiprocessing as mp
import argparse
from tqdm import tqdm
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from torch.cuda.amp import autocast
from deeprobust.graph import utils
# Command-line configuration.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15)
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'])
parser.add_argument('--ptb_rate', type=float, default=0.05)
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

# Seed every RNG the script draws from: NumPy, torch's CPU generator and all
# visible CUDA devices. The sampling below uses torch.randperm on each
# worker's own GPU, so seeding only the current CUDA device is not enough
# for reproducible runs.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

# Clean graph (adjacency, features, labels) in the 'prognn' setting.
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels

# Pre-perturbed adjacency produced by the 'meta' attack at the requested
# perturbation rate, symmetrised by averaging with its transpose.
perturbed_data = PrePtbDataset(root='/tmp/', name=args.dataset, attack_method='meta', ptb_rate=args.ptb_rate)
perturbed_adj = (perturbed_data.adj + perturbed_data.adj.T) / 2
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Persist ``cg_scores`` with ``np.save`` and report the target on stdout.

    Args:
        cg_scores: Array-like object handed to ``np.save``.
        filename: Destination passed to ``np.save`` unchanged.
    """
    np.save(file=filename, arr=cg_scores)
    message = "CG-scores saved to {}".format(filename)
    print(message)
def calc_cg_score_gnn_with_sampling(A, X, labels, device, rep_num=1, unbalance_ratio=1, batch_size=1024, node_filter=None):
    """Score each edge (i, j), i < j, by the change in a kernel error when it is removed.

    For every label group a sample set is drawn: all nodes of that label
    (targets +1) plus a random subset of the other nodes (targets -1).
    From the row-normalised aggregated features ``A @ X`` of the sample a
    kernel matrix ``H = G * (pi - acos(G)) / (2 * pi)`` is built, where ``G``
    is the clamped Gram matrix, and the reference error ``y^T H^{-1} y`` is
    computed. For each edge whose source node ``i`` has the current label,
    the contribution of the edge is subtracted from the aggregated features
    of ``i`` and ``j``, the error is recomputed, and
    ``reference - new_error`` is recorded as the edge's score.

    Args:
        A: Dense square 2-D tensor; ``A[i, j] != 0`` marks an edge and its
            value is used as the aggregation weight.
        X: Node feature tensor such that ``A @ X`` is defined.
        labels: 1-D tensor with one class label per node.
        device: ``torch.device`` on which the computation runs. Its
            ``index`` attribute is used as the tqdm bar position.
        rep_num: Number of times the whole sampling/scoring pass is repeated.
        unbalance_ratio: Negative samples drawn per positive sample (capped
            by the number of available negatives).
        batch_size: Number of edges whose perturbed features are computed
            in one vectorised step. Must be positive.
        node_filter: Optional tensor/array/list of node ids. When given,
            only edges whose source node ``i`` is in it are scored;
            ``None`` scores every edge.

    Returns:
        dict with two ``(N, N)`` NumPy arrays: ``"vi"`` holds the score of
        every processed edge averaged over its evaluations (stored
        symmetrically), and ``"times"`` holds the evaluation counts.
    """
    N = A.shape[0]
    cg_scores = {"vi": np.zeros((N, N)), "times": np.zeros((N, N))}
    A, X, labels = A.to(device), X.to(device), labels.to(device)

    @torch.no_grad()
    def normalize(tensor):
        # Row-wise L2 normalisation; the epsilon guards all-zero rows.
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    @torch.no_grad()
    def kernel_error(samples, targets):
        # Returns targets^T H^{-1} targets for the kernel built from `samples`.
        with autocast():
            inner = torch.matmul(samples, samples.T).clamp(-1.0, 1.0)
            H = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
        H.fill_diagonal_(0.5)
        # Small ridge term keeps the linear solve well conditioned.
        H += 1e-6 * torch.eye(H.size(0), device=device)
        return torch.dot(targets, torch.linalg.solve(H, targets))

    # Everything below up to the repetition loop is deterministic, so it is
    # computed once instead of once per repetition.
    AX = torch.matmul(A, X)
    norm_AX = normalize(AX)
    unique_labels = torch.unique(labels)
    label_to_indices = {
        label.item(): (labels == label).nonzero(as_tuple=True)[0]
        for label in unique_labels
    }
    neg_indices_dict = {
        label.item(): (labels != label).nonzero(as_tuple=True)[0]
        for label in unique_labels
    }

    # Boolean mask of the source nodes this call is responsible for. It is
    # built once from the caller's value, so repeated passes do not re-convert
    # an already converted filter.
    if node_filter is None:
        allowed = torch.ones(N, dtype=torch.bool, device=device)
    else:
        allowed = torch.zeros(N, dtype=torch.bool, device=device)
        allowed[torch.as_tensor(node_filter, dtype=torch.long, device=device)] = True

    # All edges (i, j) with i < j, A[i, j] != 0 and an allowed source node,
    # found with one vectorised pass instead of per-element tensor reads.
    src_all, dst_all = (A != 0).nonzero(as_tuple=True)
    keep = (src_all < dst_all) & allowed[src_all]
    src_all, dst_all = src_all[keep], dst_all[keep]
    edge_label = labels[src_all]

    vi, times = cg_scores["vi"], cg_scores["times"]

    for _ in range(rep_num):
        for label_id in tqdm(unique_labels.tolist(), desc="Label groups", position=device.index):
            # Positives: every node of this label, in random order.
            curr_indices = label_to_indices[label_id]
            curr_num = len(curr_indices)
            pos_indices = curr_indices[torch.randperm(curr_num, device=device)]

            # Negatives: a random subset of the remaining nodes.
            all_neg = neg_indices_dict[label_id]
            neg_num = min(int(curr_num * unbalance_ratio), len(all_neg))
            neg_indices = all_neg[torch.randperm(len(all_neg), device=device)[:neg_num]]

            # Edges whose source node carries the current label.
            edge_sel = edge_label == label_id
            e_src, e_dst = src_all[edge_sel], dst_all[edge_sel]
            if e_src.numel() == 0:
                # Nothing to score; skip the reference solve.
                continue

            sample_nodes = torch.cat([pos_indices, neg_indices])
            sample_tensor = norm_AX[sample_nodes]  # [M, F]
            # node id -> row in sample_tensor, for O(1) lookups per edge.
            position_of = {node: row for row, node in enumerate(sample_nodes.tolist())}
            y = torch.cat([
                torch.ones(len(pos_indices), device=device),
                -torch.ones(len(neg_indices), device=device),
            ], dim=0)

            ref_error = kernel_error(sample_tensor, y)

            for start in tqdm(range(0, e_src.numel(), batch_size), desc="Edge batches", leave=False, position=device.index):
                i_idx = e_src[start:start + batch_size]
                j_idx = e_dst[start:start + batch_size]

                # Aggregated features of both endpoints with the edge removed.
                AX1_i = AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx]
                AX1_j = AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx]
                norm_AX1_i = normalize(AX1_i)
                norm_AX1_j = normalize(AX1_j)

                for b, (i, j) in enumerate(zip(i_idx.tolist(), j_idx.tolist())):
                    perturbed = sample_tensor.clone()
                    # Only endpoints that are part of the sample affect H.
                    i_pos = position_of.get(i)
                    if i_pos is not None:
                        perturbed[i_pos] = norm_AX1_i[b]
                    j_pos = position_of.get(j)
                    if j_pos is not None:
                        perturbed[j_pos] = norm_AX1_j[b]

                    err_new = kernel_error(perturbed, y)
                    score = (ref_error - err_new).item()

                    # Accumulate in both cells: each is later divided by its
                    # own `times` count, so both must hold the running sum.
                    vi[i, j] += score
                    vi[j, i] += score
                    times[i, j] += 1
                    times[j, i] += 1

    # Average over evaluations; cells that were never scored stay at zero.
    cg_scores["vi"] = vi / np.where(times > 0, times, 1)
    return cg_scores
def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict, rank=None):
    """Score the node chunk assigned to one worker and publish the result.

    The node ids ``0 .. len(labels) - 1`` are split into ``world_size``
    contiguous chunks; this worker handles chunk number ``rank`` on device
    ``cuda:{gpu_id}`` and stores its score dict under ``return_dict[gpu_id]``.

    Args:
        gpu_id: CUDA device ordinal to run on.
        world_size: Total number of workers / node chunks.
        A, X, labels: Forwarded to ``calc_cg_score_gnn_with_sampling``.
        rep_num, unbalance_ratio, batch_size: Forwarded unchanged.
        return_dict: Shared mapping that receives the result.
        rank: Position of this worker in ``0 .. world_size - 1``. Defaults
            to ``gpu_id`` for callers that do not pass it.
    """
    if rank is None:
        rank = gpu_id
    device = torch.device(f"cuda:{gpu_id}")
    # Partition by node ids rather than by label.
    node_ids = torch.arange(labels.size(0))
    node_chunks = np.array_split(node_ids.numpy(), world_size)
    # Index by rank, not by device ordinal: the two differ whenever the
    # selected GPUs are not exactly 0 .. world_size - 1.
    node_filter = torch.tensor(node_chunks[rank], device=device)
    result = calc_cg_score_gnn_with_sampling(
        A, X, labels, device,
        rep_num=rep_num,
        unbalance_ratio=unbalance_ratio,
        batch_size=batch_size,
        node_filter=node_filter
    )
    return_dict[gpu_id] = result


def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, batch_size=1024, gpu_ids=None):
    """Run ``run_worker`` in one spawned process per GPU and sum the results.

    Args:
        A, X, labels: Forwarded to every worker.
        rep_num, unbalance_ratio, batch_size: Forwarded to every worker.
        gpu_ids: CUDA device ordinals to use; ``None`` selects every device
            reported by ``torch.cuda.device_count()``.

    Returns:
        The element-wise sum of all workers' score dicts, or ``None`` when
        ``gpu_ids`` is empty.

    Raises:
        RuntimeError: If any worker process exits with a non-zero code, since
            its node chunk would otherwise be missing from the result.
    """
    if gpu_ids is None:
        gpu_ids = list(range(torch.cuda.device_count()))
    world_size = len(gpu_ids)
    mp.set_start_method("spawn", force=True)

    with mp.Manager() as manager:
        return_dict = manager.dict()
        processes = []
        for rank, gpu_id in enumerate(gpu_ids):
            p = mp.Process(
                target=run_worker,
                args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict, rank),
            )
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

        failed = [gpu_id for gpu_id, p in zip(gpu_ids, processes) if p.exitcode != 0]
        if failed:
            raise RuntimeError(f"CG-score worker(s) on GPU(s) {failed} exited abnormally")

        # Read the results while the manager process is still alive.
        final_score = None
        for res in return_dict.values():
            if final_score is None:
                final_score = {k: np.copy(v) for k, v in res.items()}
            else:
                for k in res:
                    final_score[k] += res[k]
    return final_score
if __name__ == "__main__":
    # Convert the loaded graph to torch tensors; features are made dense.
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
    features = features.to_dense()
    if utils.is_sparse_tensor(perturbed_adj):
        perturbed_adj = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
        perturbed_adj = perturbed_adj.to_dense()

    # Use up to four GPUs, but never request a device ordinal that does not
    # exist: a worker on a missing device dies and its node chunk would be
    # absent from the saved scores.
    selected_gpus = list(range(min(4, torch.cuda.device_count())))
    if not selected_gpus:
        raise RuntimeError("No CUDA device available; CG-score computation requires at least one GPU.")

    cg_scores = multi_gpu_wrapper(perturbed_adj, features, labels,
                                  rep_num=1,
                                  unbalance_ratio=3,
                                  batch_size=40280,
                                  gpu_ids=selected_gpus)
    save_cg_scores(cg_scores["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
    print("🎉 CG-score computation completed.")