# File size: 8,419 Bytes
# 4113c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# compute cgscore for gcn
# author: Yaning
import torch
import numpy as np
import torch.nn.functional as Fd
from deeprobust.graph.defense import GCNJaccard, GCN
from deeprobust.graph.defense import GCNScore
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from scipy.sparse import csr_matrix
import argparse
import pickle
from deeprobust.graph import utils
from collections import defaultdict
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Command-line arguments, device selection, seeding and dataset loading.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05,  help='pertubation rate')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
# Keep the historical preference for GPU index 1, but fall back to GPU 0 on
# single-GPU hosts (where "cuda:1" does not exist) and to the CPU otherwise.
if args.cuda:
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# make sure you use the same data splits as you generated attacks
np.random.seed(args.seed)
# Seed the torch CPU generator as well: the score computation draws its
# permutations with torch.randperm, which np.random.seed does not control.
torch.manual_seed(args.seed)
if args.cuda:
    # Seed every CUDA device, not only the current one, because the selected
    # device above is not necessarily the current device.
    torch.cuda.manual_seed_all(args.seed)

# Here the random seed is to split the train/val/test data,
# we need to set the random seed to be the same as that when you generate the perturbed graph
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
# Or we can just use setting='prognn' to get the splits
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test


# Pre-perturbed adjacency produced with the 'meta' attack method at the requested rate.
perturbed_data = PrePtbDataset(root='/tmp/',
        name=args.dataset,
        attack_method='meta',
        ptb_rate=args.ptb_rate)

perturbed_adj = perturbed_data.adj
# perturbed_adj = adj

def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Write ``cg_scores`` to ``filename`` with ``np.save`` and report the path.

    Args:
        cg_scores: Object to persist (whatever ``np.save`` accepts).
        filename: Destination path passed to ``np.save``.
    """
    destination = filename
    np.save(destination, cg_scores)
    # The message echoes the path exactly as it was given by the caller.
    print(f"CG-scores saved to {filename}")

def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Read CG-scores back from ``filename`` and return what ``np.load`` yields.

    Args:
        filename: Path of the file to read.

    Returns:
        The object returned by ``np.load(filename, allow_pickle=True)``.
    """
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, so this
    # should only be pointed at files from a trusted source.
    loaded = np.load(filename, allow_pickle=True)
    print(f"CG-scores loaded from {filename}")
    return loaded


import torch
import numpy as np
from collections import defaultdict
from tqdm import tqdm


def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
):
    """Score every edge by how much dropping it changes ``y^T H^-1 y``.

    For each label, a one-vs-rest sample set is built from the row-normalised
    aggregated features ``A @ X`` (all nodes of the label as +1, a random
    subset of the remaining nodes as -1).  For every edge ``(i, j)`` with
    ``i < j`` whose endpoint ``i`` carries the current label, the aggregated
    features of ``i`` and ``j`` are recomputed without that edge's
    contribution, and the edge score is the reference quadratic form minus
    the one obtained with the modified features.

    Args:
        A: Dense 2-D adjacency tensor of shape ``(N, N)``; it is indexed as
            ``A[i, j]`` and multiplied with ``X``.
        X: Dense 2-D feature tensor with one row per node.
        labels: 1-D tensor with one label per node.
        device: Torch device on which the computation runs.
        rep_num: Number of times the whole sampling procedure is repeated;
            scores are averaged over the repetitions.
        unbalance_ratio: Negatives drawn per label group are
            ``min(int(group_size * unbalance_ratio), number_of_negatives)``.
        sub_term: If true, return the whole score dictionary; otherwise only
            the ``"vi"`` matrix.
        batch_size: Number of edges evaluated per batched inversion.

    Returns:
        The ``(N, N)`` NumPy array of averaged edge scores (``"vi"``), or,
        when ``sub_term`` is true, the dictionary with keys ``"vi"``,
        ``"ab"``, ``"a2"``, ``"b2"`` and ``"times"``.  Only ``"vi"`` and
        ``"times"`` are populated by this function; the other entries stay
        all-zero.
    """

    N = A.shape[0]
    cg_scores = {key: np.zeros((N, N)) for key in ("vi", "ab", "a2", "b2", "times")}

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    @torch.no_grad()
    def normalize(tensor):
        # Row-wise L2 normalisation; the epsilon guards all-zero rows.
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    @torch.no_grad()
    def quadratic_error(samples, targets):
        # Shared by the reference and the batched evaluation so both use an
        # identically constructed (and identically regularised) kernel;
        # otherwise the difference of the two would carry a constant bias.
        # `samples` is (M, D) or (B, M, D); the result is a scalar or (B,).
        inner = torch.matmul(samples, samples.transpose(-1, -2))
        inner = torch.clamp(inner, min=-1.0, max=1.0)
        H = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
        # fill_ (not copy_) because the diagonal is set to a scalar.
        H.diagonal(dim1=-2, dim2=-1).fill_(0.5)
        # Small ridge term to keep the inversion well conditioned.
        H = H + 1e-6 * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
        invH = torch.inverse(H)
        return torch.matmul(torch.matmul(invH, targets), targets)

    # Everything below depends only on A, X and labels, so it is computed once
    # rather than once per repetition.
    AX = torch.matmul(A, X)
    norm_AX = normalize(AX)

    # Step 1: group node indices by label.
    unique_labels = torch.unique(labels).tolist()
    label_to_indices = {
        label: (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
    }
    # Step 2: for every label, the indices of all nodes with a different label.
    neg_indices_dict = {
        label: (labels != label).nonzero(as_tuple=True)[0] for label in unique_labels
    }

    for _ in range(rep_num):
        for curr_label in tqdm(unique_labels, desc="Label groups"):
            curr_indices = label_to_indices[curr_label]
            curr_num = len(curr_indices)
            chosen_curr_indices = curr_indices[torch.randperm(curr_num, device=device)]

            neg_indices = neg_indices_dict[curr_label]
            neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
            rand_idx = torch.randperm(len(neg_indices), device=device)[:neg_num]
            chosen_neg_indices = neg_indices[rand_idx]

            # Sample set: positives first, then negatives.
            sample_idx = torch.cat([chosen_curr_indices, chosen_neg_indices], dim=0)
            sample_base = norm_AX[sample_idx]
            # Targets share dtype/device with the features so the matrix
            # products below never mix precisions.
            y = torch.cat(
                [
                    torch.ones(curr_num, dtype=norm_AX.dtype, device=device),
                    -torch.ones(neg_num, dtype=norm_AX.dtype, device=device),
                ],
                dim=0,
            )

            # Reference error with the unmodified graph.
            original_error = quadratic_error(sample_base, y)

            # position[v] is the row of node v inside the sample set, or -1
            # when v was not sampled.
            position = torch.full((N,), -1, dtype=torch.long, device=device)
            position[sample_idx] = torch.arange(sample_idx.numel(), device=device)

            # Step 3: collect candidate edges (i, j), i in the current group,
            # j > i, A[i, j] != 0 -- in one vectorised pass.
            rows, cols = torch.nonzero(A[chosen_curr_indices], as_tuple=True)
            src = chosen_curr_indices[rows]
            upper = cols > src
            src, dst = src[upper], cols[upper]
            num_edges = int(src.numel())

            # Step 4: evaluate the edges batch by batch.
            for start in tqdm(range(0, num_edges, batch_size), desc="Edge batches", leave=False):
                i_idx = src[start : start + batch_size]
                j_idx = dst[start : start + batch_size]
                B = int(i_idx.numel())

                # Aggregated features of both endpoints without the edge.
                new_i = normalize(AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx])
                new_j = normalize(AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx])

                # One copy of the sample rows per edge; only the sampled rows
                # are replicated, not the embeddings of all N nodes.
                sample_batch = sample_base.unsqueeze(0).repeat(B, 1, 1)  # [B, M, D]
                batch_ids = torch.arange(B, device=device)
                for nodes, replacement in ((i_idx, new_i), (j_idx, new_j)):
                    pos = position[nodes]
                    present = pos >= 0  # endpoints outside the sample set do not affect H
                    sample_batch[batch_ids[present], pos[present]] = replacement[present]

                error_A1 = quadratic_error(sample_batch, y)
                scores = (original_error - error_A1).detach().cpu().numpy()

                ii = i_idx.cpu().numpy()
                jj = j_idx.cpu().numpy()
                # Accumulate symmetrically so that the later division by
                # "times" yields a symmetric average in both triangles.
                np.add.at(cg_scores["vi"], (ii, jj), scores)
                np.add.at(cg_scores["vi"], (jj, ii), scores)
                np.add.at(cg_scores["times"], (ii, jj), 1)
                np.add.at(cg_scores["times"], (jj, ii), 1)

    # Average over the number of evaluations; entries never evaluated keep
    # their zero value (divided by 1).
    visits = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / visits

    return cg_scores if sub_term else cg_scores["vi"]



def is_symmetric_sparse(adj):
    """
    Return True when ``adj`` equals its own transpose.

    The element-wise inequality between ``adj`` and ``adj.transpose()`` is
    formed and the matrix is symmetric exactly when that result stores no
    non-zero entries.
    """
    mismatches = adj != adj.transpose()
    stored_mismatches = mismatches.nnz  # count of positions that differ
    return stored_mismatches == 0

def make_symmetric_sparse(adj):
    """
    Return the symmetrised matrix ``(adj + adj^T) / 2``.

    Entries present in only one direction therefore end up with half their
    original value in both directions.
    """
    transposed = adj.transpose()
    doubled = adj + transposed
    return doubled / 2

# ---------------------------------------------------------------------------
# Driver: symmetrise and normalise the perturbed adjacency, compute the
# CG-scores and store them on disk.
# ---------------------------------------------------------------------------
perturbed_adj = make_symmetric_sparse(perturbed_adj)

if not isinstance(perturbed_adj, torch.Tensor):
    # Argument and result order are swapped consistently (features first),
    # so each name receives its own converted object.
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
else:
    features = features.to(device)
    perturbed_adj = perturbed_adj.to(device)
    labels = labels.to(device)

if utils.is_sparse_tensor(perturbed_adj):
    adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
else:
    adj_norm = utils.normalize_adj_tensor(perturbed_adj)

# The score routine indexes individual entries (A[i, j], X[j]), so both
# inputs are handed over as dense tensors.
features = features.to_dense()
perturbed_adj = adj_norm.to_dense()


calc_cg_score =  calc_cg_score_gnn_with_sampling(perturbed_adj, features, labels, device, rep_num=1, unbalance_ratio=3, sub_term=False, batch_size=512)
# Name the output after the actual run configuration so that runs with a
# different dataset or perturbation rate do not overwrite each other; the
# default arguments still produce "pubmed_0.05.npy".
save_cg_scores(calc_cg_score, filename=f"{args.dataset}_{args.ptb_rate}.npy")
# print("completed")