File size: 11,134 Bytes
c91d7b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# compute cgscore for gcn
# author: Yaning
import torch
import numpy as np
import torch.nn.functional as Fd
from deeprobust.graph.defense import GCNJaccard, GCN
from deeprobust.graph.defense import GCNScore
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, PrePtbDataset
from scipy.sparse import csr_matrix
import argparse
import pickle
from deeprobust.graph import utils
import torch.multiprocessing as mp
from collections import defaultdict
from tqdm import tqdm

# Command-line options for the CG-score run.
# NOTE(review): this module-level setup also runs when the module is imported;
# with the "spawn" start method each worker process presumably re-imports the
# main module and repeats the dataset loading below - confirm if start-up cost matters.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05,  help='perturbation rate')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
# NOTE(review): the CUDA index is hard-coded to 1; on a single-GPU machine any
# tensor moved to this device will fail - confirm this is the intended card.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# make sure you use the same data splits as you generated attacks
np.random.seed(args.seed)
# Seed PyTorch as well: torch.manual_seed covers the CPU generator and every
# CUDA device, so the torch.randperm sampling used for the CG-score is
# reproducible regardless of which device it runs on.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Here the random seed is to split the train/val/test data,
# we need to set the random seed to be the same as that when you generate the perturbed graph
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
# Or we can just use setting='prognn' to get the splits
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test


# Pre-perturbed graph ('meta' attack) at the requested perturbation rate.
perturbed_data = PrePtbDataset(root='/tmp/',
        name=args.dataset,
        attack_method='meta',
        ptb_rate=args.ptb_rate)

perturbed_adj = perturbed_data.adj
# To score the clean graph instead, use: perturbed_adj = adj

def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Write ``cg_scores`` to ``filename`` with ``np.save`` and print a confirmation."""
    target = filename
    np.save(target, cg_scores)
    print(f"CG-scores saved to {target}")

def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Read CG-scores back from ``filename`` with ``np.load`` and return them.

    A confirmation line is printed once the file has been read.
    """
    # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code - only load files from a trusted source.
    loaded_scores = np.load(filename, allow_pickle=True)
    print(f"CG-scores loaded from {filename}")
    return loaded_scores

def _cg_kernel_matrix(samples, jitter=1e-6):
    """Build the regularised kernel matrix used for the CG-score.

    ``samples`` holds row vectors in its last two dimensions; any leading
    dimension is treated as a batch. Pairwise inner products are clamped to
    [-1, 1], mapped through ``h * (pi - arccos(h)) / (2 * pi)``, and the
    diagonal is set to ``0.5 + jitter``. The same routine serves the baseline
    and the per-edge matrices so both are regularised identically.
    """
    inner = torch.matmul(samples, samples.transpose(-1, -2))
    inner = torch.clamp(inner, min=-1.0, max=1.0)
    kernel = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
    diagonal = kernel.diagonal(dim1=-2, dim2=-1)  # view: writes land in `kernel`
    diagonal.fill_(0.5)
    diagonal.add_(jitter)
    return kernel


def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, label_filter=None
):
    """Score the edges of ``A`` by how much removing each one changes ``y^T H^-1 y``.

    For every label (optionally restricted by ``label_filter``) the nodes of
    that label plus a random sample of other nodes form a +1/-1 problem on the
    row-normalised ``A @ X``. For each edge ``(i, j)`` with ``j > i``,
    ``A[i, j] != 0`` and ``i`` carrying the current label, the contribution of
    that edge is subtracted from rows ``i`` and ``j`` of ``A @ X`` and the
    quadratic form is re-evaluated; the score is the baseline value minus the
    re-evaluated one.

    Args:
        A: adjacency tensor, indexed as ``A[i, j]`` (dense 2-D indexing is used).
        X: feature tensor multiplied as ``A @ X``.
        labels: per-node labels (presumably 1-D of length ``A.shape[0]`` - confirm).
        device: ``torch.device`` on which the computation runs.
        rep_num: number of sampling repetitions to accumulate.
        unbalance_ratio: other-label sample count as a multiple of the
            current label's node count (capped at what is available).
        sub_term: accepted for interface compatibility; currently unused.
        batch_size: number of edges evaluated per batched inversion.
        label_filter: optional collection of label values to process.

    Returns:
        dict of ``(N, N)`` NumPy arrays with keys ``"vi"``, ``"ab"``, ``"a2"``,
        ``"b2"`` and ``"times"``. ``"vi"`` is the symmetric mean score per
        edge, ``"times"`` the number of evaluations; the remaining keys are
        never written here and stay zero.
    """
    N = A.shape[0]
    cg_scores = {
        "vi": np.zeros((N, N)),
        "ab": np.zeros((N, N)),
        "a2": np.zeros((N, N)),
        "b2": np.zeros((N, N)),
        "times": np.zeros((N, N)),
    }

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    @torch.no_grad()
    def normalize(tensor):
        return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)

    # Independent of the repetition, so computed once.
    AX = torch.matmul(A, X)
    norm_AX = normalize(AX)

    for _ in range(rep_num):
        unique_labels = torch.unique(labels)
        if label_filter is not None:
            unique_labels = [label for label in unique_labels if label.item() in label_filter]

        label_to_indices = {
            label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
        }

        # Only node indices are kept per label; feature rows are gathered from
        # norm_AX on demand instead of being copied per label.
        neg_indices_dict = {}
        for label in unique_labels:
            print("label:", label)
            label = label.item()
            neg_indices_dict[label] = (labels != label).nonzero(as_tuple=True)[0]

        for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
            print("curr_label:", curr_label)
            label_id = int(curr_label)
            curr_indices = label_to_indices[label_id]
            curr_num = len(curr_indices)

            chosen_curr_indices = curr_indices[torch.randperm(curr_num, device=device)]

            neg_indices = neg_indices_dict[label_id]
            neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
            rand_idx = torch.randperm(len(neg_indices), device=device)[:neg_num]
            chosen_neg_indices = neg_indices[rand_idx]

            # Rows of the +1 / -1 problem, current-label nodes first.
            sample_idx = torch.cat([chosen_curr_indices, chosen_neg_indices], dim=0)
            combined_samples = norm_AX[sample_idx]
            y = torch.cat(
                [
                    torch.ones(curr_num, dtype=combined_samples.dtype, device=device),
                    -torch.ones(neg_num, dtype=combined_samples.dtype, device=device),
                ],
                dim=0,
            )

            invH = torch.inverse(_cg_kernel_matrix(combined_samples))
            original_error = y @ (invH @ y)

            # Upper-triangle edges (j > i) leaving each current-label node,
            # found with one nonzero() per row instead of per-element reads.
            edge_batch = []
            for idx_i in chosen_curr_indices.tolist():
                cols = torch.nonzero(A[idx_i, idx_i + 1:], as_tuple=True)[0] + (idx_i + 1)
                edge_batch.extend((idx_i, j) for j in cols.tolist())

            # Node id -> row inside the sample set (ids are unique there).
            sample_positions = {node: pos for pos, node in enumerate(sample_idx.tolist())}

            for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
                batch = edge_batch[k: k + batch_size]
                B = len(batch)

                # One copy of the sampled rows per edge; only the rows of the
                # edge's endpoints that belong to the sample set can change.
                sample_batch = combined_samples.unsqueeze(0).repeat(B, 1, 1)
                for b, (i, j) in enumerate(batch):
                    pos_i = sample_positions.get(i)
                    if pos_i is not None:
                        AX1_i = AX[i] - A[i, j] * X[j]
                        sample_batch[b, pos_i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
                    pos_j = sample_positions.get(j)
                    if pos_j is not None:
                        AX1_j = AX[j] - A[j, i] * X[i]
                        sample_batch[b, pos_j] = AX1_j / (torch.norm(AX1_j) + 1e-8)

                invH = torch.inverse(_cg_kernel_matrix(sample_batch))
                y_expanded = y.unsqueeze(0).expand(B, -1)
                error_A1 = torch.einsum("bi,bij,bj->b", y_expanded, invH, y_expanded)

                # Single device-to-host transfer for the whole batch.
                deltas = (original_error - error_A1).tolist()
                for (i, j), score in zip(batch, deltas):
                    # Accumulate symmetrically so the mean below is correct
                    # for both (i, j) and (j, i).
                    cg_scores["vi"][i, j] += score
                    cg_scores["vi"][j, i] += score
                    cg_scores["times"][i, j] += 1
                    cg_scores["times"][j, i] += 1

    # Turn accumulated sums into means; entries never evaluated keep divisor 1.
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)

    return cg_scores


def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict, rank=None):
    """Compute CG-scores for one share of the labels on one GPU.

    The unique labels are split into ``world_size`` chunks with
    ``np.array_split``; this worker processes chunk ``rank`` on device
    ``cuda:{gpu_id}`` and stores the result in ``return_dict[gpu_id]``.

    Args:
        gpu_id: CUDA device index to run on; also the key used in ``return_dict``.
        world_size: total number of workers / label chunks.
        A, X, labels: tensors forwarded to ``calc_cg_score_gnn_with_sampling``.
        rep_num, unbalance_ratio, sub_term, batch_size: forwarded unchanged.
        return_dict: shared mapping that receives this worker's result.
        rank: index of the label chunk to process. When omitted it falls back
            to ``gpu_id % world_size``, which is only a one-to-one assignment
            when the GPU ids are ``0 .. world_size - 1``.
    """
    device = torch.device(f"cuda:{gpu_id}")
    unique_labels = torch.unique(labels).tolist()
    label_chunks = np.array_split(unique_labels, world_size)
    if rank is None:
        rank = gpu_id % world_size
    label_filter = [int(l) for l in label_chunks[rank]]

    result = calc_cg_score_gnn_with_sampling(
        A, X, labels, device,
        rep_num=rep_num,
        unbalance_ratio=unbalance_ratio,
        sub_term=sub_term,
        batch_size=batch_size,
        label_filter=label_filter
    )
    return_dict[gpu_id] = result



def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, gpu_ids=None):
    """Run the CG-score computation in one process per GPU and merge the results.

    Args:
        A, X, labels: tensors handed to every worker.
        rep_num, unbalance_ratio, sub_term, batch_size: forwarded to the workers.
        gpu_ids: CUDA device indices to use; defaults to all visible devices.

    Returns:
        dict of NumPy arrays obtained by summing the per-worker result
        dictionaries key by key, or ``None`` when no worker produced a result.
    """
    if gpu_ids is None:
        gpu_ids = list(range(torch.cuda.device_count()))
    world_size = len(gpu_ids)

    # The manager owns a helper process; the context manager shuts it down.
    with mp.Manager() as manager:
        return_dict = manager.dict()
        processes = []

        for local_rank, gpu_id in enumerate(gpu_ids):
            # local_rank selects the label chunk so that arbitrary gpu_ids
            # (e.g. [1, 3]) still cover every chunk exactly once.
            p = mp.Process(
                target=run_worker,
                args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict, local_rank)
            )
            p.start()
            processes.append((gpu_id, p))

        for gpu_id, p in processes:
            p.join()
            if p.exitcode != 0:
                print(f"[FATAL] Worker for GPU {gpu_id} exited with code {p.exitcode}; its labels are missing from the result.")

        # Copy out of the proxy before the manager goes away.
        results = dict(return_dict.items())

    # Accumulator for the merged scores.
    final_score = None

    for gpu_id, rank_result in results.items():
        if not isinstance(rank_result, dict):
            print(f"[FATAL] GPU {gpu_id} result is not a dict: {type(rank_result)}")
            continue

        if final_score is None:
            # Copy the arrays so later in-place sums do not alias the first result.
            final_score = {k: np.copy(v) for k, v in rank_result.items()}
        else:
            for key in rank_result:
                if key not in final_score:
                    print(f"[WARN] key '{key}' not in final_score. Skipping.")
                    continue
                try:
                    if isinstance(final_score[key], np.ndarray) and isinstance(rank_result[key], np.ndarray):
                        final_score[key] += rank_result[key]
                    else:
                        print(f"[WARN] Skipped merging key '{key}' due to type mismatch.")
                except (TypeError, ValueError) as e:
                    # Shape or dtype mismatch between workers: report and keep going.
                    print(f"[ERROR] Failed merging key '{key}': {e}")

    return final_score



def is_symmetric_sparse(adj):
    """
    Return True when a sparse matrix equals its own transpose.
    """
    # Entries where the matrix and its transpose disagree are stored as
    # non-zeros of the comparison result; none stored means symmetric.
    mismatches = adj != adj.transpose()
    return mismatches.nnz == 0

def make_symmetric_sparse(adj):
    """
    Symmetrize a sparse adjacency matrix by averaging it with its transpose.
    """
    mirrored = adj.transpose()
    # (adj + adj^T) / 2 leaves an already symmetric matrix unchanged.
    return (adj + mirrored) / 2

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    print("cuda:", torch.cuda.is_available())

    # GPUs to use (adjust to your machine). Ids that do not exist here are
    # dropped so that no worker is started on a missing device, which would
    # otherwise leave its share of the labels silently unscored.
    requested_gpus = [0, 1, 2, 3]
    available_gpus = torch.cuda.device_count()
    selected_gpus = [gpu for gpu in requested_gpus if gpu < available_gpus]
    if not selected_gpus:
        raise RuntimeError(
            f"None of the requested GPUs {requested_gpus} is available "
            f"(torch.cuda.device_count() == {available_gpus})."
        )
    if len(selected_gpus) < len(requested_gpus):
        print(f"Using GPUs {selected_gpus} out of requested {requested_gpus}.")

    # Symmetrize the sparse adjacency matrix.
    perturbed_adj = make_symmetric_sparse(perturbed_adj)

    # Convert to tensors.
    if not isinstance(perturbed_adj, torch.Tensor):
        features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
    else:
        features = features.to(device)
        perturbed_adj = perturbed_adj.to(device)
        labels = labels.to(device)

    # Normalize the adjacency matrix.
    if utils.is_sparse_tensor(perturbed_adj):
        adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
    else:
        adj_norm = utils.normalize_adj_tensor(perturbed_adj)

    # The scoring routine indexes individual entries, so work on dense tensors.
    features = features.to_dense()
    perturbed_adj = adj_norm.to_dense()

    # Compute the CG-score in parallel across the selected GPUs.
    calc_cg_score = multi_gpu_wrapper(
        perturbed_adj, features, labels,
        rep_num=1,
        unbalance_ratio=1,
        sub_term=False,
        batch_size=1024,
        gpu_ids=selected_gpus
    )
    if calc_cg_score is None:
        # multi_gpu_wrapper returns None when no worker delivered a result.
        raise RuntimeError("CG-score computation produced no result; check the worker output above.")
    save_cg_scores(calc_cg_score["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")

    print(" CG-score computation completed.")