|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.multiprocessing as mp |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from deeprobust.graph.utils import * |
|
|
from deeprobust.graph.data import Dataset, PrePtbDataset |
|
|
from torch.cuda.amp import autocast |
|
|
from deeprobust.graph import utils |
|
|
|
|
|
|
|
|
# ---- Command-line configuration --------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15)
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'])
parser.add_argument('--ptb_rate', type=float, default=0.05)  # perturbation rate of the pre-attacked graph
args = parser.parse_args()

# ---- Device selection -------------------------------------------------------
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

# ---- Reproducibility --------------------------------------------------------
# Only numpy and the CUDA RNG are seeded here; torch's CPU RNG (used by
# torch.randperm on CPU) is NOT seeded — runs on CPU are not reproducible.
np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# ---- Data loading -----------------------------------------------------------
# NOTE(review): these statements run at import time; with the "spawn" start
# method used below, every worker process re-imports this module and therefore
# re-loads both datasets. Consider moving them under the __main__ guard.
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels

# Load the meta-attack-perturbed adjacency and symmetrize it by averaging
# with its transpose.
perturbed_data = PrePtbDataset(root='/tmp/', name=args.dataset, attack_method='meta', ptb_rate=args.ptb_rate)
perturbed_adj = (perturbed_data.adj + perturbed_data.adj.T) / 2
|
|
|
|
|
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Persist an array of CG-scores to disk in NumPy ``.npy`` format.

    Args:
        cg_scores: array-like of scores; passed straight to ``np.save``.
        filename: destination path (default ``"cg_scores.npy"``).
    """
    np.save(filename, cg_scores)
    # Bug fix: the message previously printed a hard-coded "(unknown)"
    # placeholder instead of the actual destination path.
    print(f"CG-scores saved to {filename}")
|
|
|
|
|
def calc_cg_score_gnn_with_sampling(A, X, labels, device, rep_num=1, unbalance_ratio=1, batch_size=1024, node_filter=None):
    """Estimate an edge-wise CG-score for every existing edge of the graph.

    For each label group an (approximately balanced) one-vs-rest sample is
    drawn and an NTK-style arc-cosine kernel H is built from the row-normalized
    aggregated features A @ X.  The score of an edge (i, j) is the change in
    the quadratic form y^T H^{-1} y when that edge's contribution is removed
    from the aggregation of its two endpoints.

    Args:
        A: dense adjacency tensor, shape (N, N).
        X: dense node-feature tensor, shape (N, F).
        labels: 1-D integer class-label tensor of length N.
        device: torch.device on which all kernels are computed.
        rep_num: number of random sampling repetitions to average over.
        unbalance_ratio: cap on |negatives| / |positives| per label group.
        batch_size: number of edges whose perturbed features are computed
            per vectorized batch.
        node_filter: optional 1-D tensor of node ids; only edges whose
            smaller-id endpoint is in this set are scored (multi-GPU sharding).

    Returns:
        dict with keys "vi" (N x N array of visit-count-averaged scores) and
        "times" (N x N array of how many times each pair was scored).
    """
    N = A.shape[0]
    cg_scores = {"vi": np.zeros((N, N)), "times": np.zeros((N, N))}
    A, X, labels = A.to(device), X.to(device), labels.to(device)

    @torch.no_grad()
    def normalize(tensor):
        # Row-wise L2 normalization; the epsilon guards all-zero rows.
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    # Bug fix: normalize node_filter ONCE, before the repetition loop.
    # Previously the tensor was converted to a set inside the loop, so a
    # second repetition called .tolist() on a set and crashed.
    if node_filter is not None:
        node_filter = set(node_filter.tolist())
    else:
        node_filter = set(range(labels.size(0)))

    for _ in range(rep_num):
        AX = torch.matmul(A, X)
        norm_AX = normalize(AX)

        # Group node indices / normalized features by class label.
        unique_labels = torch.unique(labels)
        label_to_indices = {label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels}
        dataset = {label: norm_AX[idx] for label, idx in label_to_indices.items()}

        # Pre-build the one-vs-rest negative pool for every label.
        neg_samples_dict = {}
        neg_indices_dict = {}
        for label in unique_labels:
            label = label.item()
            mask = labels != label
            neg_samples_dict[label] = norm_AX[mask]
            neg_indices_dict[label] = mask.nonzero(as_tuple=True)[0]

        for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
            label_id = int(curr_label)
            curr_samples = dataset[label_id]
            curr_indices = label_to_indices[label_id]
            curr_num = len(curr_samples)

            # Shuffle the positive class.
            chosen_curr_idx = torch.randperm(curr_num, device=device)
            pos_samples = curr_samples[chosen_curr_idx]
            pos_indices = curr_indices[chosen_curr_idx]

            # Subsample negatives, capped at unbalance_ratio * |positives|.
            neg_samples = neg_samples_dict[label_id]
            neg_indices = neg_indices_dict[label_id]
            neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
            rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
            neg_samples = neg_samples[rand_idx]
            neg_indices = neg_indices[rand_idx]

            sample_idx = pos_indices.tolist() + neg_indices.tolist()
            # Perf: O(1) position lookup replaces the repeated O(n)
            # sample_idx.index(...) calls in the per-edge loop below.
            # Node ids in sample_idx are unique (disjoint pos/neg pools).
            pos_in_sample = {node: p for p, node in enumerate(sample_idx)}
            sample_tensor = norm_AX[sample_idx]
            y = torch.cat([
                torch.ones(len(pos_samples)),
                -torch.ones(len(neg_samples))
            ], dim=0).to(device)

            # Reference error on the unperturbed kernel.
            with autocast():
                H_inner = torch.matmul(sample_tensor, sample_tensor.T).clamp(-1.0, 1.0)
                # Arc-cosine (NTK-like) kernel of the normalized features.
                H_base = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                H_base.fill_diagonal_(0.5)
                # Small diagonal jitter keeps the solve well-conditioned.
                H_base += 1e-6 * torch.eye(H_base.size(0), device=device)
                ref_error = torch.dot(y, torch.linalg.solve(H_base, y))

            # Every existing edge (i, j) with i < j whose endpoint i is in
            # this worker's shard.
            edge_batch = [(i.item(), j)
                          for i in pos_indices if i.item() in node_filter
                          for j in range(i.item() + 1, N) if A[i, j] != 0]

            for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
                batch = edge_batch[k:k + batch_size]
                if not batch: continue
                i_idx, j_idx = zip(*batch)
                i_idx = torch.tensor(i_idx, device=device)
                j_idx = torch.tensor(j_idx, device=device)

                # Aggregated features with the edge's contribution removed
                # from each endpoint, then re-normalized.
                AX1_i = AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx]
                AX1_j = AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx]
                norm_AX1_i = normalize(AX1_i)
                norm_AX1_j = normalize(AX1_j)

                for b, (i, j) in enumerate(batch):
                    sample_tensor_copy = sample_tensor.clone()
                    # Swap in the edge-removed rows for whichever endpoints
                    # are part of the current sample (either may be absent).
                    i_pos = pos_in_sample.get(i)
                    if i_pos is not None:
                        sample_tensor_copy[i_pos] = norm_AX1_i[b]
                    j_pos = pos_in_sample.get(j)
                    if j_pos is not None:
                        sample_tensor_copy[j_pos] = norm_AX1_j[b]

                    with autocast():
                        H_inner = torch.matmul(sample_tensor_copy, sample_tensor_copy.T).clamp(-1.0, 1.0)
                        H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                        H.fill_diagonal_(0.5)
                        H += 1e-6 * torch.eye(H.size(0), device=device)
                        sol = torch.linalg.solve(H, y)
                        err_new = torch.dot(y, sol)

                    score = (ref_error - err_new).item()
                    cg_scores["vi"][i, j] += score
                    # Bug fix: the reverse direction previously used plain
                    # assignment ("= score"), which discarded earlier
                    # accumulations for (j, i) and broke the visit-count
                    # averaging below whenever rep_num > 1.
                    cg_scores["vi"][j, i] += score
                    cg_scores["times"][i, j] += 1
                    cg_scores["times"][j, i] += 1

    # Average accumulated scores by visit count; untouched pairs divide by 1.
    for key in ["vi"]:
        cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)

    return cg_scores
|
|
|
|
|
def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict):
    """Worker entry point: score the node shard assigned to one GPU.

    Partitions all node ids into ``world_size`` contiguous chunks, takes the
    chunk matching ``gpu_id``, runs the sampled CG-score computation on that
    shard, and stores the result dict under ``return_dict[gpu_id]``.
    """
    worker_device = torch.device(f"cuda:{gpu_id}")

    # Contiguous, near-equal split of [0, N) across all workers.
    all_nodes = torch.arange(labels.size(0)).numpy()
    shard = np.array_split(all_nodes, world_size)[gpu_id]
    my_nodes = torch.tensor(shard, device=worker_device)

    return_dict[gpu_id] = calc_cg_score_gnn_with_sampling(
        A, X, labels, worker_device,
        rep_num=rep_num,
        unbalance_ratio=unbalance_ratio,
        batch_size=batch_size,
        node_filter=my_nodes,
    )
|
|
|
|
|
def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, batch_size=1024, gpu_ids=None):
    """Fan the CG-score computation out over several GPUs and merge results.

    Spawns one process per GPU id; each worker scores a disjoint shard of
    source nodes. The per-worker score matrices are summed element-wise,
    which acts as a merge because the shards do not overlap.

    Args:
        A, X, labels: dense tensors forwarded unchanged to every worker.
        rep_num, unbalance_ratio, batch_size: forwarded to the per-worker
            CG-score computation.
        gpu_ids: GPUs to use; defaults to all visible CUDA devices.

    Returns:
        dict mapping score keys ("vi", "times") to merged N x N arrays,
        or None if no worker produced a result.
    """
    if gpu_ids is None:
        gpu_ids = list(range(torch.cuda.device_count()))
    world_size = len(gpu_ids)

    # CUDA tensors require the "spawn" start method.
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    return_dict = manager.dict()

    workers = []
    for gpu_id in gpu_ids:
        proc = mp.Process(
            target=run_worker,
            args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict),
        )
        proc.start()
        workers.append(proc)
    for proc in workers:
        proc.join()

    merged = None
    for partial in return_dict.values():
        if merged is None:
            merged = {key: np.copy(val) for key, val in partial.items()}
        else:
            for key in partial:
                merged[key] += partial[key]
    return merged
|
|
|
|
|
if __name__ == "__main__":
    # Convert the scipy/numpy inputs loaded at module level into torch
    # tensors (deeprobust helper), then densify the features.
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
    features = features.to_dense()
    # GCN-style normalization of the perturbed adjacency while still sparse,
    # then densify so workers can index rows/columns directly.
    # NOTE(review): if to_tensor ever returns a dense adjacency, the
    # unconditional .to_dense() below may fail on older torch — confirm.
    if utils.is_sparse_tensor(perturbed_adj):
        perturbed_adj = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
    perturbed_adj = perturbed_adj.to_dense()

    # NOTE(review): GPU ids and batch size are hard-coded for a specific
    # 4-GPU machine — adjust per deployment.
    selected_gpus = [0,1,2,3]
    cg_scores = multi_gpu_wrapper(perturbed_adj, features, labels,
                                  rep_num=1,
                                  unbalance_ratio=3,
                                  batch_size=40280,
                                  gpu_ids=selected_gpus)

    # Persist only the averaged score matrix, named after dataset + ptb rate.
    save_cg_scores(cg_scores["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
    print("🎉 CG-score computation completed.")
|
|
|