|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
def remove_least_important_edges(adj, cgscore, remove_ratio=0.8, *, verbose=False):
    """
    Remove the least important edges of a graph based on CGScore.

    Only the strict upper triangle (i < j) is inspected: (i, j) is a
    candidate edge when ``adj[i, j] > 0`` and its importance is
    ``cgscore[i, j]``. Entries on or below the diagonal of ``cgscore`` are
    ignored. The ``int(num_candidates * remove_ratio)`` lowest-scoring
    candidates are removed symmetrically: both ``[i, j]`` and ``[j, i]`` are
    set to 0 in the result. Ties are broken by row-major position in the
    upper triangle, so the result is deterministic.

    Args:
        adj (torch.Tensor): Adjacency matrix (N x N). It is not modified.
        cgscore (np.ndarray | torch.Tensor): CGScore matrix (N x N). It is
            converted to float32 on ``adj``'s device.
        remove_ratio (float): Fraction of candidate edges to remove, in
            [0, 1] (default: 0.8).
        verbose (bool): Keyword-only. If True, print the candidate/removal
            counts (default: False).

    Returns:
        torch.Tensor: A copy of ``adj`` with the selected edges zeroed.

    Raises:
        ValueError: If ``remove_ratio`` is outside [0, 1] (or NaN), if
            ``adj`` is not a square 2-D matrix, or if ``adj`` and
            ``cgscore`` shapes differ.
    """
    # A negative ratio would otherwise become a negative slice bound and
    # silently remove almost every edge.
    if not 0.0 <= remove_ratio <= 1.0:
        raise ValueError(f"remove_ratio must be in [0, 1], got {remove_ratio}")

    # as_tensor avoids the copy warning when cgscore is already a tensor, and
    # placing it on adj's device keeps the boolean masking below on one device.
    cgscore = torch.as_tensor(cgscore, dtype=torch.float32, device=adj.device)

    if adj.dim() != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(
            f"adj must be a square 2-D matrix, got shape {tuple(adj.shape)}"
        )
    if adj.shape != cgscore.shape:
        raise ValueError("adj and cgscore must have the same shape")

    n = adj.shape[0]
    # Row/column indices of every strictly-upper-triangular position.
    rows, cols = torch.triu_indices(n, n, offset=1, device=adj.device)

    # Keep only positions that actually hold an edge.
    present = adj[rows, cols] > 0
    rows, cols = rows[present], cols[present]
    scores = cgscore[rows, cols]

    num_edges = scores.numel()
    num_edges_to_remove = int(num_edges * remove_ratio)
    if verbose:
        print("len(sorted_indices)", num_edges)
        print("remove_ratio:", remove_ratio)
        print("num_edges_to_remove", num_edges_to_remove)

    # Stable ascending sort gives reproducible tie-breaking.
    order = torch.sort(scores, stable=True).indices
    drop = order[:num_edges_to_remove]
    drop_rows, drop_cols = rows[drop], cols[drop]

    adj_new = adj.clone()
    # Zero both directions so the result stays symmetric for removed edges.
    adj_new[drop_rows, drop_cols] = 0
    adj_new[drop_cols, drop_rows] = 0
    return adj_new
|
|
|
|
|
|
|
|
|
|
|
def _demo():
    """Prune a small 4-node example graph and print it before and after."""
    # Symmetric 4-node graph with upper-triangle edges
    # (0,1), (0,2), (1,2), (1,3), (2,3).
    adj = torch.tensor([
        [0, 1, 1, 0],
        [1, 0, 1, 1],
        [1, 1, 0, 1],
        [0, 1, 1, 0],
    ], dtype=torch.float32)

    # Per-edge importance scores. This matrix is not symmetric
    # (e.g. [1, 2] = 0.1 vs [2, 1] = 0.7); remove_least_important_edges
    # reads only entries above the diagonal.
    # NOTE(review): confirm whether the asymmetry is intentional.
    cgscore = np.array([
        [0.0, 0.8, 0.6, 0.0],
        [0.8, 0.0, 0.1, 1.2],
        [0.6, 0.7, 0.0, 1.9],
        [0.0, 1.2, 1.1, 0.0],
    ], dtype=np.float32)

    # int(5 * 0.2) == 1, so only the lowest-scoring edge (1, 2) is removed.
    adj_new = remove_least_important_edges(adj, cgscore, remove_ratio=0.2)

    print("原始邻接矩阵:")  # "Original adjacency matrix:"
    print(adj)
    print("调整后的邻接矩阵:")  # "Adjusted adjacency matrix:"
    print(adj_new)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    _demo()
|
|
|
|
|
|