|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as Fd |
|
|
from deeprobust.graph.defense import GCNJaccard, GCN |
|
|
from deeprobust.graph.defense import GCNScore |
|
|
from deeprobust.graph.utils import * |
|
|
from deeprobust.graph.data import Dataset, PrePtbDataset |
|
|
from scipy.sparse import csr_matrix |
|
|
import argparse |
|
|
import pickle |
|
|
from deeprobust.graph import utils |
|
|
from collections import defaultdict |
|
|
|
|
|
# Command-line configuration for the CG-score experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='polblogs',
                    choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'],
                    help='dataset')
# Fixed typo in the help text: "pertubation" -> "perturbation".
parser.add_argument('--ptb_rate', type=float, default=0.10, help='perturbation rate')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
# Reuse the availability result already cached in args.cuda instead of
# probing torch.cuda.is_available() a second time.
device = torch.device("cuda:0" if args.cuda else "cpu")
|
|
|
|
|
|
|
|
# Seed every RNG the pipeline draws from:
#  - numpy   (np.random.choice in calc_cg_score_gnn_with_sampling)
#  - torch CPU (torch.randperm for negative sampling) — this call was
#    missing, so --seed did not make CPU sampling reproducible
#  - torch CUDA, when a GPU is available
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Clean dataset with the ProGNN train/val/test split.
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj = data.adj
features = data.features
labels = data.labels
idx_train = data.idx_train
idx_val = data.idx_val
idx_test = data.idx_test

# Pre-computed meta-attack perturbation at the requested rate.
perturbed_data = PrePtbDataset(
    root='/tmp/',
    name=args.dataset,
    attack_method='meta',
    ptb_rate=args.ptb_rate,
)
perturbed_adj = perturbed_data.adj
|
|
|
|
|
|
|
|
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Persist CG-scores to *filename* with numpy.

    Args:
        cg_scores: array (or dict of arrays, when sub_term=True upstream)
            of CG-scores to serialize.
        filename: destination path for the .npy file.
    """
    np.save(filename, cg_scores)
    # Bug fix: the message previously printed a literal "(unknown)"
    # placeholder instead of interpolating the actual path.
    print(f"CG-scores saved to {filename}")
|
|
|
|
|
def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Load CG-scores previously written by save_cg_scores.

    allow_pickle=True is needed because the saved object may be a dict of
    arrays (sub_term=True); only load files from trusted sources.

    Args:
        filename: path of the .npy file to read.

    Returns:
        The deserialized CG-scores (ndarray, or object array wrapping a dict).
    """
    cg_scores = np.load(filename, allow_pickle=True)
    # Bug fix: interpolate the real path instead of the literal
    # "(unknown)" placeholder.
    print(f"CG-scores loaded from {filename}")
    return cg_scores
|
|
|
|
|
def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
    """
    Calculate CG-score for each edge in a graph with node labels and random sampling.

    For every class, a one-vs-rest binary problem is formed from the
    row-normalized aggregated features A @ X. An arccos-based kernel H is
    built over the sampled nodes, and an edge's score is the change in the
    quadratic form y^T H^{-1} y when that edge is removed from A.

    Args:
        A: torch.Tensor
            Adjacency matrix of the graph (size: N x N). Dense — rows/cols
            are indexed and cloned per edge below.
        X: torch.Tensor
            Node features matrix (size: N x F).
        labels: torch.Tensor
            Node labels (size: N).
        device: torch.device
            Device to perform calculations.
        rep_num: int
            Number of repetitions for Monte Carlo sampling.
        unbalance_ratio: float
            Ratio of unbalanced data (1:unbalance_ratio) for negative sampling.
        sub_term: bool
            If True, return the full dict of sub-terms; otherwise only "vi".

    Returns:
        dict of N x N arrays ("vi", "ab", "a2", "b2", "times") when
        sub_term is True; otherwise the N x N "vi" score matrix.
        NOTE(review): only "vi" and "times" are ever written below; the
        "ab"/"a2"/"b2" slots stay zero — presumably reserved sub-terms.
    """
    N = A.shape[0]
    # Score accumulators; "times" counts how often each edge was scored so
    # the final averaging can divide per-edge.
    cg_scores = {
        "vi": np.zeros((N, N)),
        "ab": np.zeros((N, N)),
        "a2": np.zeros((N, N)),
        "b2": np.zeros((N, N)),
        "times": np.zeros((N, N)),
    }

    with torch.no_grad():
        for _ in range(rep_num):

            # One-hop aggregation, then L2-normalize each node's row.
            # NOTE(review): a zero-degree node makes torch.norm(...) == 0
            # and this division produces NaNs — confirm inputs have no
            # isolated nodes (self-loops from adjacency normalization
            # upstream would prevent this).
            AX = torch.matmul(A, X).to(device)
            norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)

            # Group normalized rows and node indices by class label.
            dataset = defaultdict(list)
            data_idx = defaultdict(list)
            for i, label in enumerate(labels):
                dataset[label.item()].append(norm_AX[i].unsqueeze(0))
                data_idx[label.item()].append(i)

            # Stack each class's rows into one tensor; indices -> long tensor.
            for label, data_list in dataset.items():
                dataset[label] = torch.cat(data_list, dim=0)
                data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)

            # One-vs-rest pass per class.
            for curr_label, curr_samples in dataset.items():
                curr_indices = data_idx[curr_label]
                curr_num = len(curr_samples)

                # Full random permutation of the current class (choice with
                # replace=False over all curr_num elements).
                chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
                chosen_curr_samples = curr_samples[chosen_curr_idx]
                chosen_curr_indices = curr_indices[chosen_curr_idx]

                # All other classes pooled as negatives.
                neg_samples = torch.cat(
                    [dataset[l] for l in dataset if l != curr_label], dim=0
                )
                neg_indices = torch.cat(
                    [data_idx[l] for l in data_idx if l != curr_label], dim=0
                )
                # Cap negatives at curr_num * unbalance_ratio.
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
                chosen_neg_samples = neg_samples[
                    torch.randperm(len(neg_samples))[:neg_num]
                ]

                # Positives labeled +1, negatives -1.
                combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
                y = torch.cat(
                    [torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
                ).to(device)

                # Gram matrix of unit rows; freed eagerly to limit peak memory.
                H_inner = torch.matmul(combined_samples, combined_samples.T)
                del combined_samples

                # Clamp to the valid domain of acos against float error.
                H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)

                # Arccos-kernel transform of the cosine similarities.
                H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                del H_inner

                # Diagonal (self-similarity 1) maps to exactly 0.5.
                H.fill_diagonal_(0.5)

                # Ridge term so the inverse is numerically stable.
                epsilon = 1e-6
                H = H + epsilon * torch.eye(H.size(0), device=H.device)

                invH = torch.inverse(H)
                del H
                # Baseline quadratic form with all edges present.
                original_error = y @ (invH @ y)

                # Score each existing edge (i, j), j > i, incident to a
                # current-class node. NOTE(review): i is a 0-dim tensor;
                # range(i + 1, N) and numpy indexing rely on its __index__.
                for i in chosen_curr_indices:
                    print("the node index:", i)
                    for j in range(i + 1, N):

                        if A[i, j] == 0:
                            continue

                        # Remove the single undirected edge (i, j).
                        A1 = A.clone()
                        A1[i, j] = A1[j, i] = 0

                        # Recompute normalized aggregation on the edited graph.
                        AX1 = torch.matmul(A1, X).to(device)
                        norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)

                        # NOTE(review): negatives are re-drawn with a fresh
                        # randperm here, so the comparison against
                        # original_error also carries sampling noise —
                        # confirm this is intended rather than reusing the
                        # baseline negative subset.
                        curr_samples_A1 = norm_AX1[chosen_curr_indices]
                        neg_samples_A1 = norm_AX1[neg_indices]
                        chosen_neg_samples_A1 = neg_samples_A1[
                            torch.randperm(len(neg_samples_A1))[:neg_num]
                        ]
                        combined_samples_A1 = torch.cat(
                            [curr_samples_A1, chosen_neg_samples_A1], dim=0
                        )
                        H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)

                        del combined_samples_A1

                        # Same kernel construction as the baseline above.
                        H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)

                        H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
                        del H_inner_A1
                        H_A1.fill_diagonal_(0.5)

                        epsilon = 1e-6
                        H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)

                        invH_A1 = torch.inverse(H_A1)
                        del H_A1

                        error_A1 = y @ (invH_A1 @ y)

                        print("i:", i)
                        print("j:", j)
                        print("current score:", (original_error - error_A1).item())

                        # Edge score = drop in the quadratic form when the
                        # edge is removed; kept symmetric in [i, j]/[j, i].
                        cg_scores["vi"][i, j] += (original_error - error_A1).item()
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average accumulated scores per edge; untouched edges divide by 1.
    for key, values in cg_scores.items():
        if key == "times":
            continue
        cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)

    return cg_scores if sub_term else cg_scores["vi"]
|
|
|
|
|
def is_symmetric_sparse(adj):
    """Return True when the sparse matrix equals its own transpose."""
    # A symmetric matrix has no entries where it differs from its transpose.
    mismatches = adj != adj.transpose()
    return mismatches.nnz == 0
|
|
|
|
|
def make_symmetric_sparse(adj):
    """Symmetrize a sparse adjacency matrix by averaging it with its transpose."""
    transposed = adj.transpose()
    # (A + A^T) / 2 leaves symmetric entries unchanged and halves
    # one-directional ones.
    return (adj + transposed) / 2
|
|
|
|
|
# Symmetrize the attacked adjacency before any tensor conversion.
perturbed_adj = make_symmetric_sparse(perturbed_adj)

# Convert scipy matrices / arrays to torch tensors on the chosen device.
# Idiom fix: isinstance instead of `type(...) is not torch.Tensor`.
if not isinstance(perturbed_adj, torch.Tensor):
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
else:
    features = features.to(device)
    perturbed_adj = perturbed_adj.to(device)
    labels = labels.to(device)

# Normalized adjacency; deeprobust picks the sparse or dense code path.
if utils.is_sparse_tensor(perturbed_adj):
    adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
else:
    adj_norm = utils.normalize_adj_tensor(perturbed_adj)

# CG-score computation below works on dense tensors.
features = features.to_dense()
perturbed_adj = adj_norm.to_dense()

calc_cg_score = calc_cg_score_gnn_with_sampling(
    perturbed_adj, features, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
)
# Bug fix: the output filename was hard-coded to "cg_scores_polblogs_0.10.npy"
# regardless of --dataset/--ptb_rate; derive it from args instead (identical
# for the default polblogs / 0.10 configuration).
save_cg_scores(calc_cg_score, filename=f"cg_scores_{args.dataset}_{args.ptb_rate:.2f}.npy")
|
|
|
|
|
|