File size: 16,936 Bytes
c91d7b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def calc_cg_score_gnn_with_sampling(   # stable training and defense effect
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
    """
    Calculate an edge-wise CG-score by removing one edge at a time.

    For each label group, the group's nodes (target +1) and a random subset
    of nodes from all other groups (target -1) are represented by the
    row-normalised rows of ``A @ X``.  The reference error is
    ``y^T H^{-1} y``.  Here ``H`` is the H^inf Gram matrix
    ``u * (pi - arccos(u)) / (2 * pi)`` of the sample inner products ``u``,
    with the diagonal set to 0.5 and a 1e-6 ridge added.  For every edge
    ``(i, j)`` with ``i`` in the group and ``j > i``, only rows ``i`` and
    ``j`` of ``A @ X`` are recomputed without that edge.  The edge's score
    is the decrease of the error.

    NOTE(review): this name is redefined further down in this file, so only
    the last definition is bound at import time.

    Args:
        A: adjacency matrix tensor of shape (N, N).  It is sliced and indexed
            elementwise, so a dense (strided) tensor is assumed.
        X: node feature tensor of shape (N, F); must share A's dtype.
        labels: 1-D tensor of N node labels.  At least two distinct labels
            are needed to draw negatives.
        device: device used for all tensor computation.
        rep_num: number of Monte Carlo repetitions.
        unbalance_ratio: negatives drawn per positive node, capped by the
            number of available negatives.
        sub_term: if True, return the whole score dictionary.

    Returns:
        An (N, N) numpy array of scores, averaged over the number of times
        each edge was scored.  When ``sub_term`` is True, returns instead the
        dict with keys "vi", "ab", "a2", "b2" (averaged) and "times"
        (counts).  The "ab", "a2" and "b2" entries are never filled and
        stay zero.
    """
    N = A.shape[0]
    cg_scores = {key: np.zeros((N, N)) for key in ("vi", "ab", "a2", "b2", "times")}

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    def normalize(tensor):
        # L2-normalise along the last dim; the epsilon guards all-zero rows.
        return tensor / (torch.norm(tensor, dim=-1, keepdim=True) + 1e-8)

    def kernel_error(samples, y):
        # y^T H^{-1} y for the H^inf Gram matrix of samples shaped (..., M, D).
        inner = torch.clamp(samples @ samples.transpose(-2, -1), min=-1.0, max=1.0)
        H = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
        H.diagonal(dim1=-2, dim2=-1).fill_(0.5)
        H = H + 1e-6 * torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        rhs = y.expand(H.shape[:-1]).unsqueeze(-1)
        # solve() is cheaper and more stable than forming H^{-1} explicitly.
        solution = torch.linalg.solve(H, rhs).squeeze(-1)
        return (y * solution).sum(dim=-1)

    # Scores are read back with .item(); no autograd graph is needed.
    with torch.no_grad():
        for _ in range(rep_num):
            AX = torch.matmul(A, X)
            norm_AX = normalize(AX)

            # Node indices per label, in order of first appearance.
            groups = defaultdict(list)
            for node, label in enumerate(labels.tolist()):
                groups[label].append(node)
            data_idx = {
                label: torch.tensor(nodes, dtype=torch.long, device=device)
                for label, nodes in groups.items()
            }
            # Negative pool of each label: the nodes of every other label.
            neg_indices_dict = {
                label: torch.cat([idx for other, idx in data_idx.items() if other != label])
                for label in data_idx
            }

            for curr_label, curr_indices in tqdm(data_idx.items(), desc="Label groups"):
                curr_num = len(curr_indices)
                chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
                chosen_curr_indices = curr_indices[chosen_curr_idx]

                neg_indices = neg_indices_dict[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
                rand_idx = torch.randperm(len(neg_indices))[:neg_num]
                chosen_neg_indices = neg_indices[rand_idx.to(neg_indices.device)]

                sample_indices = torch.cat([chosen_curr_indices, chosen_neg_indices])
                samples = norm_AX[sample_indices]
                # Row of each sampled node inside `samples` (sampled nodes are unique).
                position = {node: row for row, node in enumerate(sample_indices.tolist())}
                # Match the samples' dtype so float64 inputs don't break y^T H^{-1} y.
                y = torch.cat([
                    torch.ones(curr_num, dtype=samples.dtype, device=device),
                    -torch.ones(neg_num, dtype=samples.dtype, device=device),
                ])
                original_error = kernel_error(samples, y)

                for idx_i in tqdm(chosen_curr_indices.tolist(), desc=f"Nodes in label {curr_label}"):
                    # Neighbours j > idx_i via one nonzero() instead of N scalar reads.
                    offsets = torch.nonzero(A[idx_i, idx_i + 1:]).flatten()
                    for j in (offsets + idx_i + 1).tolist():
                        # Dropping edge (idx_i, j) only changes rows idx_i and j of A @ X.
                        new_rows = {
                            idx_i: AX[idx_i] - A[idx_i, j] * X[j],
                            j: AX[j] - A[j, idx_i] * X[idx_i],
                        }
                        # Copy only the M sampled rows, not all N rows.
                        samples_A1 = samples.clone()
                        for node, row in new_rows.items():
                            if node in position:
                                samples_A1[position[node]] = normalize(row)

                        score = (original_error - kernel_error(samples_A1, y)).item()
                        cg_scores["vi"][idx_i, j] += score
                        cg_scores["vi"][j, idx_i] = cg_scores["vi"][idx_i, j]
                        cg_scores["times"][idx_i, j] += 1
                        cg_scores["times"][j, idx_i] += 1

    # Average accumulated scores over how often each edge was scored.
    counts = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in ("vi", "ab", "a2", "b2"):
        cg_scores[key] = cg_scores[key] / counts

    return cg_scores if sub_term else cg_scores["vi"]


def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
):
    """
    Calculate an edge-wise CG-score, evaluating candidate edges in batches.

    For each label group, the group's nodes (target +1) and a random subset
    of nodes from all other groups (target -1) are represented by the
    row-normalised rows of ``A @ X``.  The reference error is
    ``y^T H^{-1} y``.  Here ``H`` is the H^inf Gram matrix
    ``u * (pi - arccos(u)) / (2 * pi)`` of the sample inner products ``u``,
    with the diagonal set to 0.5 and a 1e-6 ridge added.  Edges ``(i, j)``
    with ``i`` in the group and ``j > i`` are processed ``batch_size`` at a
    time.  For each edge, only rows ``i`` and ``j`` of ``A @ X`` are
    recomputed without it, and the edge's score is the decrease of the
    error.

    NOTE(review): this name is redefined further down in this file, so only
    the last definition is bound at import time.

    Args:
        A: adjacency matrix tensor of shape (N, N).  It is sliced and indexed
            elementwise, so a dense (strided) tensor is assumed.
        X: node feature tensor of shape (N, F); must share A's dtype.
        labels: 1-D tensor of N node labels.  At least two distinct labels
            are needed to draw negatives.
        device: device used for all tensor computation.
        rep_num: number of Monte Carlo repetitions.
        unbalance_ratio: negatives drawn per positive node, capped by the
            number of available negatives.
        sub_term: if True, return the whole score dictionary.
        batch_size: number of edges whose Gram matrices are solved together
            (must be positive).

    Returns:
        An (N, N) numpy array of scores, averaged over the number of times
        each edge was scored.  When ``sub_term`` is True, returns instead the
        dict with keys "vi", "ab", "a2", "b2" (averaged) and "times"
        (counts).  The "ab", "a2" and "b2" entries are never filled and
        stay zero.
    """
    N = A.shape[0]
    cg_scores = {key: np.zeros((N, N)) for key in ("vi", "ab", "a2", "b2", "times")}

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    def normalize(tensor):
        # L2-normalise along the last dim; the epsilon guards all-zero rows.
        return tensor / (torch.norm(tensor, dim=-1, keepdim=True) + 1e-8)

    def kernel_error(samples, y):
        # y^T H^{-1} y for the H^inf Gram matrix of samples shaped (..., M, D).
        inner = torch.clamp(samples @ samples.transpose(-2, -1), min=-1.0, max=1.0)
        H = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
        # Set the diagonal BEFORE adding the ridge.  The reference and the
        # batched perturbed errors must use the same construction, otherwise
        # even an edge whose removal changes nothing gets a nonzero score.
        H.diagonal(dim1=-2, dim2=-1).fill_(0.5)
        H = H + 1e-6 * torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        rhs = y.expand(H.shape[:-1]).unsqueeze(-1)
        # solve() is cheaper and more stable than forming H^{-1} explicitly.
        solution = torch.linalg.solve(H, rhs).squeeze(-1)
        return (y * solution).sum(dim=-1)

    # Scores are read back as Python floats; no autograd graph is needed.
    with torch.no_grad():
        for _ in range(rep_num):
            AX = torch.matmul(A, X)
            norm_AX = normalize(AX)

            # Node indices per label, in order of first appearance.
            groups = defaultdict(list)
            for node, label in enumerate(labels.tolist()):
                groups[label].append(node)
            data_idx = {
                label: torch.tensor(nodes, dtype=torch.long, device=device)
                for label, nodes in groups.items()
            }
            # Negative pool of each label: the nodes of every other label.
            neg_indices_dict = {
                label: torch.cat([idx for other, idx in data_idx.items() if other != label])
                for label in data_idx
            }

            for curr_label, curr_indices in tqdm(data_idx.items(), desc="Label groups"):
                curr_num = len(curr_indices)
                chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
                chosen_curr_indices = curr_indices[chosen_curr_idx]

                neg_indices = neg_indices_dict[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
                rand_idx = torch.randperm(len(neg_indices))[:neg_num]
                chosen_neg_indices = neg_indices[rand_idx.to(neg_indices.device)]

                # Loop invariants for every edge batch of this label group.
                sample_indices = torch.cat([chosen_curr_indices, chosen_neg_indices])
                samples = norm_AX[sample_indices]  # (M, D)
                position = {node: row for row, node in enumerate(sample_indices.tolist())}
                # Match the samples' dtype so float64 inputs don't break y^T H^{-1} y.
                y = torch.cat([
                    torch.ones(curr_num, dtype=samples.dtype, device=device),
                    -torch.ones(neg_num, dtype=samples.dtype, device=device),
                ])
                original_error = kernel_error(samples, y)

                # Candidate edges (i, j) with j > i, one nonzero() per node.
                edges = []
                for i in chosen_curr_indices.tolist():
                    offsets = torch.nonzero(A[i, i + 1:]).flatten()
                    edges.extend((i, j) for j in (offsets + i + 1).tolist())

                for start in tqdm(range(0, len(edges), batch_size), desc="Edge batches", leave=False):
                    batch = edges[start:start + batch_size]
                    src = torch.tensor([edge[0] for edge in batch], dtype=torch.long, device=device)
                    dst = torch.tensor([edge[1] for edge in batch], dtype=torch.long, device=device)

                    # Dropping edge (i, j) only changes rows i and j of A @ X.
                    new_src = normalize(AX[src] - A[src, dst].unsqueeze(-1) * X[dst])
                    new_dst = normalize(AX[dst] - A[dst, src].unsqueeze(-1) * X[src])

                    # Copy only the M sampled rows per edge, not all N rows.
                    batch_samples = samples.repeat(len(batch), 1, 1)  # (B, M, D)
                    for b, (i, j) in enumerate(batch):
                        if i in position:
                            batch_samples[b, position[i]] = new_src[b]
                        if j in position:
                            batch_samples[b, position[j]] = new_dst[b]

                    scores = (original_error - kernel_error(batch_samples, y)).tolist()
                    for (i, j), score in zip(batch, scores):
                        cg_scores["vi"][i, j] += score
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average accumulated scores over how often each edge was scored.
    counts = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in ("vi", "ab", "a2", "b2"):
        cg_scores[key] = cg_scores[key] / counts

    return cg_scores if sub_term else cg_scores["vi"]

def calc_cg_score_gnn_with_sampling(      # based on the front code, remove more data to GPU, effect is approxiamate
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
    """
    Calculate CG-score for each edge in a graph with node labels and random sampling.

    Args:
        A: torch.Tensor
            Adjacency matrix of the graph (size: N x N).
        X: torch.Tensor
            Node features matrix (size: N x F).
        labels: torch.Tensor
            Node labels (size: N).
        device: torch.device
            Device to perform calculations.
        rep_num: int
            Number of repetitions for Monte Carlo sampling.
        unbalance_ratio: float
            Ratio of unbalanced data (1:unbalance_ratio).
        sub_term: bool
            If True, calculate and return sub-terms.

    Returns:
        cg_scores: dict
            Dictionary containing CG-scores for edges and optionally sub-terms.
    """
    N = A.shape[0]
    cg_scores = {
        "vi": np.zeros((N, N)),
        "ab": np.zeros((N, N)),
        "a2": np.zeros((N, N)),
        "b2": np.zeros((N, N)),
        "times": np.zeros((N, N)),
    }

    with torch.no_grad():
        for _ in range(rep_num):
            # Compute AX (node representations)
            AX = torch.matmul(A, X).to(device)
            norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)

            # Group nodes by their labels
            dataset = defaultdict(list)
            data_idx = defaultdict(list)
            for i, label in enumerate(labels):
                dataset[label.item()].append(norm_AX[i].unsqueeze(0))  # Store normalized data
                data_idx[label.item()].append(i)  # Store indices

            # Convert to tensors
            for label, data_list in dataset.items():
                dataset[label] = torch.cat(data_list, dim=0)
                data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)

            # Calculate CG-scores for each label group
            for curr_label, curr_samples in dataset.items():
                curr_indices = data_idx[curr_label]
                curr_num = len(curr_samples)

                # Randomly sample a subset of current label examples
                chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
                chosen_curr_samples = curr_samples[chosen_curr_idx]
                chosen_curr_indices = curr_indices[chosen_curr_idx]

                # Sample negative examples from other classes
                neg_samples = torch.cat(
                    [dataset[l] for l in dataset if l != curr_label], dim=0
                )
                neg_indices = torch.cat(
                    [data_idx[l] for l in data_idx if l != curr_label], dim=0
                )
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
                chosen_neg_samples = neg_samples[
                    torch.randperm(len(neg_samples))[:neg_num]
                ]

                # Combine positive and negative samples
                combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
                y = torch.cat(
                    [torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
                ).to(device)

                # Compute the Gram matrix H^\infty
                H_inner = torch.matmul(combined_samples, combined_samples.T)
                del combined_samples
                ###
                H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
                ###
                H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
                del H_inner

                H.fill_diagonal_(0.5)
                ##
                epsilon = 1e-6
                H = H + epsilon * torch.eye(H.size(0), device=H.device)
                ##
                invH = torch.inverse(H)
                del H
                original_error = y @ (invH @ y)

                # Compute CG-scores for each edge
                for i in chosen_curr_indices:
                    print("the node index:", i)
                    for j in range(i + 1, N):  # Upper triangular traversal
                        # print(j)
                        if A[i, j] == 0:  # Skip if no edge exists
                            continue

                        # Remove edge (i, j) to create A1
                        A1 = A.clone()
                        A1[i, j] = A1[j, i] = 0

                        # Recompute AX with A1
                        AX1 = torch.matmul(A1, X).to(device)
                        norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)

                        # Repeat error calculation with A1
                        curr_samples_A1 = norm_AX1[chosen_curr_indices]
                        neg_samples_A1 = norm_AX1[neg_indices]
                        chosen_neg_samples_A1 = neg_samples_A1[
                            torch.randperm(len(neg_samples_A1))[:neg_num]
                        ]
                        combined_samples_A1 = torch.cat(
                            [curr_samples_A1, chosen_neg_samples_A1], dim=0
                        )
                        H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)

                        del combined_samples_A1
                        
                        ### trick1
                        H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
                        ###

                        H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
                        del H_inner_A1
                        H_A1.fill_diagonal_(0.5)

                        ### trick2
                        epsilon = 1e-6
                        H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
                        ###
                        invH_A1 = torch.inverse(H_A1)
                        del H_A1

                        error_A1 = y @ (invH_A1 @ y)
                         
                        print("i:", i)
                        print("j:", j)
                        print("current score:", (original_error - error_A1).item())
                        # Compute the difference in error (CG-score)
                        cg_scores["vi"][i, j] += (original_error - error_A1).item()
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]  # Symmetric
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Normalize CG-scores by repetition count
    for key, values in cg_scores.items():
        if key == "times":
            continue
        cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)