CGSCORE / examples /graph /cgscore_save.py
Yaning1001's picture
Add files using upload-large-folder tool
c91d7b1 verified
def calc_cg_score_gnn_with_sampling(  # stable training and defense effect
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
    """Score each edge by how much removing it changes ``y^T H^-1 y``.

    Node representations are the row-normalised product ``A @ X``.  For every
    label group, the nodes of that label (targets ``+1``) are combined with a
    random subset of the remaining nodes (targets ``-1``), the kernel matrix
    ``H = G * (pi - arccos(G)) / (2 * pi)`` is built from their Gram matrix
    ``G`` (diagonal forced to 0.5, plus a ``1e-6`` jitter), and the quadratic
    form ``y^T H^-1 y`` is evaluated.  Each edge ``(i, j)`` with ``i`` in the
    current group and ``j > i`` is then removed in turn and the quadratic
    form is re-evaluated; the difference ``original - perturbed`` is the
    edge score.

    NOTE: this file defines several functions with this same name; Python
    keeps only the definition that is executed last.

    Args:
        A: Dense ``N x N`` adjacency tensor (indexed as ``A[i, j]`` and
            multiplied with ``X``).
        X: Node feature tensor with ``N`` rows.
        labels: 1-D tensor holding one label per node.
        device: Device on which all tensor work is performed.
        rep_num: Number of sampling repetitions; scores are averaged over
            the number of times each edge was visited.
        unbalance_ratio: Negatives drawn per positive node (capped at the
            number of available negatives).
        sub_term: If True return the full dictionary (``"vi"``, ``"ab"``,
            ``"a2"``, ``"b2"``, ``"times"``); otherwise only ``"vi"``.
            Only ``"vi"`` and ``"times"`` are written by this function; the
            other entries stay zero.

    Returns:
        Either the ``N x N`` NumPy array of edge scores or the dictionary of
        NumPy arrays described above.
    """
    N = A.shape[0]
    cg_scores = {
        name: np.zeros((N, N)) for name in ("vi", "ab", "a2", "b2", "times")
    }

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    def kernel_error(samples, targets):
        """Return ``targets^T H^-1 targets`` for the kernel built on *samples*."""
        gram = torch.clamp(torch.matmul(samples, samples.T), min=-1.0, max=1.0)
        H = gram * (np.pi - torch.acos(gram)) / (2 * np.pi)
        H.fill_diagonal_(0.5)
        H += 1e-6 * torch.eye(H.size(0), dtype=H.dtype, device=H.device)
        return targets @ (torch.inverse(H) @ targets)

    # Nothing here is differentiated (results leave as NumPy floats), so the
    # whole computation runs without autograd bookkeeping.
    with torch.no_grad():
        # These quantities do not depend on the repetition, so build them once.
        AX = torch.matmul(A, X)
        norm_AX = AX / (torch.norm(AX, dim=1, keepdim=True) + 1e-8)

        # Node indices per label, in order of first appearance of each label.
        members = defaultdict(list)
        for node, label in enumerate(labels.tolist()):
            members[label].append(node)
        groups = {
            label: torch.tensor(nodes, dtype=torch.long, device=device)
            for label, nodes in members.items()
        }

        # Negative pool of a label = nodes of every other label.  With a
        # single class the pool is empty instead of crashing in torch.cat.
        neg_pools = {}
        for label in groups:
            others = [nodes for other, nodes in groups.items() if other != label]
            neg_pools[label] = (
                torch.cat(others)
                if others
                else torch.empty(0, dtype=torch.long, device=device)
            )

        for _ in range(rep_num):
            for curr_label, curr_nodes in tqdm(groups.items(), desc="Label groups"):
                curr_num = len(curr_nodes)
                # All positives are used; the shuffle only affects visiting order.
                order = np.random.choice(curr_num, curr_num, replace=False)
                chosen_curr_indices = curr_nodes[
                    torch.as_tensor(order, dtype=torch.long, device=device)
                ]

                neg_pool = neg_pools[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_pool))
                picked = torch.randperm(len(neg_pool))[:neg_num].to(device)
                chosen_neg_indices = neg_pool[picked]

                sample_nodes = torch.cat([chosen_curr_indices, chosen_neg_indices])
                base_samples = norm_AX[sample_nodes]
                # Targets share the dtype of the features so that float64
                # inputs do not hit a dtype mismatch in the matmul below.
                y = torch.cat(
                    [
                        torch.ones(curr_num, dtype=AX.dtype, device=device),
                        -torch.ones(neg_num, dtype=AX.dtype, device=device),
                    ]
                )
                original_error = kernel_error(base_samples, y)

                for idx_i in tqdm(
                    chosen_curr_indices.tolist(), desc=f"Nodes in label {curr_label}"
                ):
                    # Upper-triangular neighbours of idx_i, found in one call
                    # instead of testing A[idx_i, j] element by element.
                    neighbours = (
                        torch.nonzero(A[idx_i, idx_i + 1:]).flatten() + idx_i + 1
                    )
                    rows_of_i = sample_nodes == idx_i
                    for j in neighbours.tolist():
                        # Removing edge (i, j) only changes rows i and j of A @ X.
                        AX1_i = AX[idx_i] - A[idx_i, j] * X[j]
                        AX1_j = AX[j] - A[j, idx_i] * X[idx_i]

                        # Patch only the sampled rows; rows that are not part
                        # of the sample never reach the kernel.
                        perturbed = base_samples.clone()
                        perturbed[rows_of_i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
                        perturbed[sample_nodes == j] = AX1_j / (
                            torch.norm(AX1_j) + 1e-8
                        )

                        error_A1 = kernel_error(perturbed, y)
                        score = (original_error - error_A1).item()
                        cg_scores["vi"][idx_i, j] += score
                        cg_scores["vi"][j, idx_i] = cg_scores["vi"][idx_i, j]
                        cg_scores["times"][idx_i, j] += 1
                        cg_scores["times"][j, idx_i] += 1

    # Average over the number of visits; unvisited entries are divided by 1.
    visits = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / visits
    return cg_scores if sub_term else cg_scores["vi"]
def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
):
    """Score each edge by how much removing it changes ``y^T H^-1 y`` (batched).

    Node representations are the row-normalised product ``A @ X``.  For every
    label group, the nodes of that label (targets ``+1``) are combined with a
    random subset of the remaining nodes (targets ``-1``), the kernel matrix
    ``H = G * (pi - arccos(G)) / (2 * pi)`` is built from their Gram matrix
    ``G`` (diagonal forced to 0.5, plus a ``1e-6`` jitter), and the quadratic
    form ``y^T H^-1 y`` is evaluated.  Every edge ``(i, j)`` with ``i`` in the
    current group and ``j > i`` is then removed in turn; the perturbed
    quadratic forms are evaluated ``batch_size`` edges at a time and the
    difference ``original - perturbed`` is the edge score.

    The reference and the perturbed kernels are built by the same helper, so
    both carry the same diagonal value and the same jitter.

    NOTE: this file defines several functions with this same name; Python
    keeps only the definition that is executed last.

    Args:
        A: Dense ``N x N`` adjacency tensor (indexed as ``A[i, j]`` and
            multiplied with ``X``).
        X: Node feature tensor with ``N`` rows.
        labels: 1-D tensor holding one label per node.
        device: Device on which all tensor work is performed.
        rep_num: Number of sampling repetitions; scores are averaged over
            the number of times each edge was visited.
        unbalance_ratio: Negatives drawn per positive node (capped at the
            number of available negatives).
        sub_term: If True return the full dictionary (``"vi"``, ``"ab"``,
            ``"a2"``, ``"b2"``, ``"times"``); otherwise only ``"vi"``.
            Only ``"vi"`` and ``"times"`` are written by this function; the
            other entries stay zero.
        batch_size: Number of edge removals evaluated per batched inverse.

    Returns:
        Either the ``N x N`` NumPy array of edge scores or the dictionary of
        NumPy arrays described above.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    N = A.shape[0]
    cg_scores = {
        name: np.zeros((N, N)) for name in ("vi", "ab", "a2", "b2", "times")
    }

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    def kernel_error(samples, targets):
        """Return ``targets^T H^-1 targets``; *samples* is ``[M, D]`` or ``[B, M, D]``."""
        gram = torch.matmul(samples, samples.transpose(-1, -2))
        gram = torch.clamp(gram, min=-1.0, max=1.0)
        H = gram * (np.pi - torch.acos(gram)) / (2 * np.pi)
        # Set the diagonal first, then add the jitter, so the regularisation
        # is not overwritten.
        H.diagonal(dim1=-2, dim2=-1).fill_(0.5)
        H = H + 1e-6 * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
        solved = torch.matmul(torch.inverse(H), targets)
        return torch.matmul(solved, targets)

    # Nothing here is differentiated (results leave as NumPy floats), so the
    # whole computation runs without autograd bookkeeping.
    with torch.no_grad():
        # These quantities do not depend on the repetition, so build them once.
        AX = torch.matmul(A, X)
        norm_AX = AX / (torch.norm(AX, dim=1, keepdim=True) + 1e-8)

        # Node indices per label, in order of first appearance of each label.
        members = defaultdict(list)
        for node, label in enumerate(labels.tolist()):
            members[label].append(node)
        groups = {
            label: torch.tensor(nodes, dtype=torch.long, device=device)
            for label, nodes in members.items()
        }

        # Negative pool of a label = nodes of every other label.  With a
        # single class the pool is empty instead of crashing in torch.cat.
        neg_pools = {}
        for label in groups:
            others = [nodes for other, nodes in groups.items() if other != label]
            neg_pools[label] = (
                torch.cat(others)
                if others
                else torch.empty(0, dtype=torch.long, device=device)
            )

        for _ in range(rep_num):
            for curr_label, curr_nodes in tqdm(groups.items(), desc="Label groups"):
                curr_num = len(curr_nodes)
                # All positives are used; the shuffle only affects visiting order.
                order = np.random.choice(curr_num, curr_num, replace=False)
                chosen_curr_indices = curr_nodes[
                    torch.as_tensor(order, dtype=torch.long, device=device)
                ]

                neg_pool = neg_pools[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(neg_pool))
                picked = torch.randperm(len(neg_pool))[:neg_num].to(device)
                chosen_neg_indices = neg_pool[picked]

                sample_nodes = torch.cat([chosen_curr_indices, chosen_neg_indices])
                base_samples = norm_AX[sample_nodes]  # [M, D]
                # Targets share the dtype of the features so that float64
                # inputs do not hit a dtype mismatch in the matmul.
                y = torch.cat(
                    [
                        torch.ones(curr_num, dtype=AX.dtype, device=device),
                        -torch.ones(neg_num, dtype=AX.dtype, device=device),
                    ]
                )
                original_error = kernel_error(base_samples, y)

                # Candidate edges: upper-triangular neighbours of each
                # positive node, found with one nonzero() call per node.
                edge_batch = []
                for idx_i in chosen_curr_indices.tolist():
                    neighbours = (
                        torch.nonzero(A[idx_i, idx_i + 1:]).flatten() + idx_i + 1
                    )
                    edge_batch.extend((idx_i, j) for j in neighbours.tolist())

                for start in tqdm(
                    range(0, len(edge_batch), batch_size),
                    desc="Edge batches",
                    leave=False,
                ):
                    batch = edge_batch[start:start + batch_size]
                    src = torch.tensor(
                        [edge[0] for edge in batch], dtype=torch.long, device=device
                    )
                    dst = torch.tensor(
                        [edge[1] for edge in batch], dtype=torch.long, device=device
                    )

                    # Removing edge (i, j) only changes rows i and j of A @ X.
                    new_src = AX[src] - A[src, dst].unsqueeze(1) * X[dst]
                    new_dst = AX[dst] - A[dst, src].unsqueeze(1) * X[src]
                    new_src = new_src / (
                        torch.norm(new_src, dim=1, keepdim=True) + 1e-8
                    )
                    new_dst = new_dst / (
                        torch.norm(new_dst, dim=1, keepdim=True) + 1e-8
                    )

                    # Patch only the sampled rows ([B, M, D]) rather than
                    # replicating the full [N, D] representation per edge.
                    src_hit = (sample_nodes.unsqueeze(0) == src.unsqueeze(1)).unsqueeze(-1)
                    dst_hit = (sample_nodes.unsqueeze(0) == dst.unsqueeze(1)).unsqueeze(-1)
                    perturbed = torch.where(
                        src_hit, new_src.unsqueeze(1), base_samples.unsqueeze(0)
                    )
                    perturbed = torch.where(dst_hit, new_dst.unsqueeze(1), perturbed)

                    error_A1 = kernel_error(perturbed, y)  # [B]
                    batch_scores = (original_error - error_A1).tolist()
                    for (i, j), score in zip(batch, batch_scores):
                        cg_scores["vi"][i, j] += score
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average over the number of visits; unvisited entries are divided by 1.
    visits = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / visits
    return cg_scores if sub_term else cg_scores["vi"]
def calc_cg_score_gnn_with_sampling( # based on the front code, remove more data to GPU, effect is approxiamate
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
"""
Calculate CG-score for each edge in a graph with node labels and random sampling.
Args:
A: torch.Tensor
Adjacency matrix of the graph (size: N x N).
X: torch.Tensor
Node features matrix (size: N x F).
labels: torch.Tensor
Node labels (size: N).
device: torch.device
Device to perform calculations.
rep_num: int
Number of repetitions for Monte Carlo sampling.
unbalance_ratio: float
Ratio of unbalanced data (1:unbalance_ratio).
sub_term: bool
If True, calculate and return sub-terms.
Returns:
cg_scores: dict
Dictionary containing CG-scores for edges and optionally sub-terms.
"""
N = A.shape[0]
cg_scores = {
"vi": np.zeros((N, N)),
"ab": np.zeros((N, N)),
"a2": np.zeros((N, N)),
"b2": np.zeros((N, N)),
"times": np.zeros((N, N)),
}
with torch.no_grad():
for _ in range(rep_num):
# Compute AX (node representations)
AX = torch.matmul(A, X).to(device)
norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)
# Group nodes by their labels
dataset = defaultdict(list)
data_idx = defaultdict(list)
for i, label in enumerate(labels):
dataset[label.item()].append(norm_AX[i].unsqueeze(0)) # Store normalized data
data_idx[label.item()].append(i) # Store indices
# Convert to tensors
for label, data_list in dataset.items():
dataset[label] = torch.cat(data_list, dim=0)
data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
# Calculate CG-scores for each label group
for curr_label, curr_samples in dataset.items():
curr_indices = data_idx[curr_label]
curr_num = len(curr_samples)
# Randomly sample a subset of current label examples
chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
chosen_curr_samples = curr_samples[chosen_curr_idx]
chosen_curr_indices = curr_indices[chosen_curr_idx]
# Sample negative examples from other classes
neg_samples = torch.cat(
[dataset[l] for l in dataset if l != curr_label], dim=0
)
neg_indices = torch.cat(
[data_idx[l] for l in data_idx if l != curr_label], dim=0
)
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
chosen_neg_samples = neg_samples[
torch.randperm(len(neg_samples))[:neg_num]
]
# Combine positive and negative samples
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
y = torch.cat(
[torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
).to(device)
# Compute the Gram matrix H^\infty
H_inner = torch.matmul(combined_samples, combined_samples.T)
del combined_samples
###
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
###
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
del H_inner
H.fill_diagonal_(0.5)
##
epsilon = 1e-6
H = H + epsilon * torch.eye(H.size(0), device=H.device)
##
invH = torch.inverse(H)
del H
original_error = y @ (invH @ y)
# Compute CG-scores for each edge
for i in chosen_curr_indices:
print("the node index:", i)
for j in range(i + 1, N): # Upper triangular traversal
# print(j)
if A[i, j] == 0: # Skip if no edge exists
continue
# Remove edge (i, j) to create A1
A1 = A.clone()
A1[i, j] = A1[j, i] = 0
# Recompute AX with A1
AX1 = torch.matmul(A1, X).to(device)
norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)
# Repeat error calculation with A1
curr_samples_A1 = norm_AX1[chosen_curr_indices]
neg_samples_A1 = norm_AX1[neg_indices]
chosen_neg_samples_A1 = neg_samples_A1[
torch.randperm(len(neg_samples_A1))[:neg_num]
]
combined_samples_A1 = torch.cat(
[curr_samples_A1, chosen_neg_samples_A1], dim=0
)
H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
del combined_samples_A1
### trick1
H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
###
H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
del H_inner_A1
H_A1.fill_diagonal_(0.5)
### trick2
epsilon = 1e-6
H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
###
invH_A1 = torch.inverse(H_A1)
del H_A1
error_A1 = y @ (invH_A1 @ y)
print("i:", i)
print("j:", j)
print("current score:", (original_error - error_A1).item())
# Compute the difference in error (CG-score)
cg_scores["vi"][i, j] += (original_error - error_A1).item()
cg_scores["vi"][j, i] = cg_scores["vi"][i, j] # Symmetric
cg_scores["times"][i, j] += 1
cg_scores["times"][j, i] += 1
# Normalize CG-scores by repetition count
for key, values in cg_scores.items():
if key == "times":
continue
cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)