# File size: 10,018 Bytes
# c91d7b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# compute cgscore for gcn
# author: Yaning
import torch
import numpy as np
import torch.nn.functional as Fd
from deeprobust.graph.defense import GCNJaccard, GCN
from deeprobust.graph.defense import GCNScore
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from scipy.sparse import csr_matrix
import argparse
import pickle
from deeprobust.graph import utils
from collections import defaultdict

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='polblogs', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.10, help='perturbation rate')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

# make sure you use the same data splits as you generated attacks
np.random.seed(args.seed)
# Also seed torch's CPU generator: the CG-score sampling draws negatives with
# torch.randperm (CPU generator), which torch.cuda.manual_seed alone does not
# seed, so runs would otherwise not be reproducible.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Here the random seed is to split the train/val/test data,
# we need to set the random seed to be the same as that when you generate the perturbed graph
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
# Or we can just use setting='prognn' to get the splits
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test


# Pre-perturbed graph for the same dataset, produced by the 'meta' attack at --ptb_rate.
perturbed_data = PrePtbDataset(root='/tmp/',
        name=args.dataset,
        attack_method='meta',
        ptb_rate=args.ptb_rate)

perturbed_adj = perturbed_data.adj
# perturbed_adj = adj
# perturbed_adj = adj

def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Write ``cg_scores`` to ``filename`` with ``np.save`` and print a confirmation."""
    target_path = filename
    np.save(target_path, cg_scores)
    confirmation = f"CG-scores saved to {target_path}"
    print(confirmation)

def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Load CG-scores written by ``save_cg_scores`` and print where they came from.

    ``allow_pickle=True`` lets object arrays (e.g. a saved dict of sub-terms) be
    read back. Unpickling can execute arbitrary code, so only load trusted files.
    """
    restored = np.load(filename, allow_pickle=True)
    print(f"CG-scores loaded from {filename}")
    return restored

def _row_normalize(M):
    """Scale each row of ``M`` to unit L2 norm; all-zero rows stay zero instead of NaN."""
    return torch.nn.functional.normalize(M, p=2, dim=1)


def _kernel_error(samples, y, epsilon=1e-6):
    """Return ``y^T H^{-1} y`` for the arc-cosine kernel Gram matrix H of ``samples``.

    Args:
        samples: torch.Tensor (S x F) whose rows are unit-norm (or all-zero).
        y: torch.Tensor (S,) of +1 / -1 targets, same dtype and device as ``samples``.
        epsilon: ridge added to the diagonal so H stays invertible.
    """
    # Rows are normalized, so inner products are cosines; rounding can push them
    # slightly outside [-1, 1], where acos would return NaN.
    H_inner = torch.clamp(samples @ samples.T, min=-1.0, max=1.0)
    H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
    H.fill_diagonal_(0.5)
    H = H + epsilon * torch.eye(H.size(0), dtype=H.dtype, device=H.device)
    # Solving H x = y yields y^T H^{-1} y without forming the explicit inverse.
    return y @ torch.linalg.solve(H, y)


def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, verbose=False
):
    """
    Calculate CG-score for each edge in a graph with node labels and random sampling.

    For every label group, all of its nodes (targets +1) plus a random subset of
    nodes from the other groups (targets -1) are drawn once. Then, for every edge
    (i, j) with j > i where i is in the group, the score is
    ``err(A) - err(A without (i, j))``. Here ``err = y^T H^{-1} y``, computed over
    the row-normalized rows of ``A @ X`` for the drawn nodes. The same draw is used
    for both terms, so the score reflects only the edge removal.

    Args:
        A: torch.Tensor
            Adjacency matrix of the graph (size: N x N). It is sliced and indexed
            element-wise, so a dense tensor is expected.
        X: torch.Tensor
            Node features matrix (size: N x F), same dtype as A.
        labels: torch.Tensor
            Node labels (size: N).
        device: torch.device
            Device to perform calculations.
        rep_num: int
            Number of repetitions for Monte Carlo sampling.
        unbalance_ratio: float
            Ratio of unbalanced data (1:unbalance_ratio).
        sub_term: bool
            If True, return the whole dict of accumulators.
        verbose: bool
            If True, print per-node / per-edge progress.

    Returns:
        np.ndarray (N x N) of CG-scores when ``sub_term`` is False. Otherwise a dict
        with "vi" (the scores), "times" (how often each edge was scored), and
        "ab", "a2", "b2", which this implementation never fills in (they stay zero).
    """
    N = A.shape[0]
    cg_scores = {key: np.zeros((N, N)) for key in ("vi", "ab", "a2", "b2", "times")}

    with torch.no_grad():
        A = A.to(device)
        X = X.to(device)
        # A and X do not change across repetitions: propagate and normalize once.
        AX = torch.matmul(A, X)
        norm_AX = _row_normalize(AX)

        # Group node indices by label, keeping labels in first-appearance order.
        groups = {}
        for node, label in enumerate(labels.tolist()):
            groups.setdefault(label, []).append(node)
        group_idx = {
            label: torch.tensor(nodes, dtype=torch.long, device=device)
            for label, nodes in groups.items()
        }

        for _ in range(rep_num):
            for curr_label, curr_idx in group_idx.items():
                curr_num = len(curr_idx)
                # Every node of the current label is a positive, in random order.
                perm = torch.from_numpy(np.random.permutation(curr_num)).to(device)
                pos_idx = curr_idx[perm]

                others = [idx for label, idx in group_idx.items() if label != curr_label]
                # With a single label there are no negatives at all.
                neg_pool = torch.cat(others) if others else curr_idx.new_empty(0)
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_pool))
                # Draw the negatives ONCE and reuse them for every edge below.
                neg_idx = neg_pool[torch.randperm(len(neg_pool))[:neg_num].to(device)]

                sample_idx = torch.cat([pos_idx, neg_idx])
                y = torch.cat([
                    torch.ones(curr_num, dtype=norm_AX.dtype, device=device),
                    -torch.ones(neg_num, dtype=norm_AX.dtype, device=device),
                ])
                original_error = _kernel_error(norm_AX[sample_idx], y)

                for i in pos_idx.tolist():
                    if verbose:
                        print("the node index:", i)
                    # Upper-triangular traversal: each undirected edge is scored
                    # once, from its lower-indexed endpoint.
                    nbrs = torch.nonzero(A[i, i + 1:], as_tuple=True)[0] + (i + 1)
                    for j in nbrs.tolist():
                        # Zeroing A[i, j] and A[j, i] only changes rows i and j of A @ X.
                        new_row_i = _row_normalize((AX[i] - A[i, j] * X[j]).unsqueeze(0))[0]
                        new_row_j = _row_normalize((AX[j] - A[j, i] * X[i]).unsqueeze(0))[0]
                        samples_A1 = norm_AX[sample_idx]  # advanced indexing -> a copy
                        samples_A1[sample_idx == i] = new_row_i
                        samples_A1[sample_idx == j] = new_row_j
                        error_A1 = _kernel_error(samples_A1, y)

                        score = (original_error - error_A1).item()
                        if verbose:
                            print("i:", i)
                            print("j:", j)
                            print("current score:", score)
                        cg_scores["vi"][i, j] += score
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]  # Symmetric
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average accumulated values over how many times each edge was scored.
    counts = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in ("vi", "ab", "a2", "b2"):
        cg_scores[key] = cg_scores[key] / counts

    return cg_scores if sub_term else cg_scores["vi"]

def is_symmetric_sparse(adj):
    """
    Return True when the sparse matrix ``adj`` equals its own transpose.
    """
    mismatches = adj != adj.transpose()
    # The element-wise comparison stores only positions that differ, so a result
    # with no stored entries means the matrix is symmetric.
    return mismatches.nnz == 0

def make_symmetric_sparse(adj):
    """
    Return the symmetric part ``(adj + adj^T) / 2`` of a sparse matrix.

    An entry present in only one direction ends up at half its weight in both.
    """
    transposed = adj.transpose()
    summed = adj + transposed
    return summed / 2

# Symmetrize the perturbed adjacency as (A + A^T) / 2 before scoring.
perturbed_adj = make_symmetric_sparse(perturbed_adj)

if not isinstance(perturbed_adj, torch.Tensor):
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
else:
    features = features.to(device)
    perturbed_adj = perturbed_adj.to(device)
    labels = labels.to(device)

if utils.is_sparse_tensor(perturbed_adj):
    adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
else:
    adj_norm = utils.normalize_adj_tensor(perturbed_adj)

# The CG-score routine indexes the adjacency element-wise, so pass dense tensors.
features = features.to_dense()
perturbed_adj = adj_norm.to_dense()


calc_cg_score = calc_cg_score_gnn_with_sampling(perturbed_adj, features, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False)
# Name the output after the dataset / ptb_rate actually scored. A fixed
# "polblogs_0.10" name would mislabel, and overwrite, the results of any other run.
save_cg_scores(calc_cg_score, filename=f"cg_scores_{args.dataset}_{args.ptb_rate:.2f}.npy")