# Listing metadata (from the file viewer this was copied from): 2,361 bytes, revision 4113c4d.
import numpy as np
import torch
def remove_least_important_edges(adj, cgscore, remove_ratio=0.8, *, verbose=False):
    """Remove the edges with the lowest CGScore from an undirected graph.

    Only the strict upper triangle (``i < j``) is inspected: a pair ``(i, j)``
    counts as an edge when ``adj[i, j] > 0``, and its importance is
    ``cgscore[i, j]``. The ``int(num_edges * remove_ratio)`` edges with the
    smallest scores are zeroed at both ``(i, j)`` and ``(j, i)`` in a copy of
    ``adj``.

    Args:
        adj (torch.Tensor): Square adjacency matrix of shape (N, N). Edges are
            read from the upper triangle only, so a symmetric matrix is
            expected.
        cgscore (np.ndarray | torch.Tensor): CGScore matrix of shape (N, N).
            Only its upper triangle is read.
        remove_ratio (float): Fraction of existing edges to remove, in [0, 1]
            (default: 0.8).
        verbose (bool): If True, print edge counts for debugging.

    Returns:
        torch.Tensor: A new adjacency matrix with the selected edges removed.
        ``adj`` itself is not modified.

    Raises:
        ValueError: If ``adj`` is not square, if the shapes of ``adj`` and
            ``cgscore`` differ, or if ``remove_ratio`` is outside [0, 1].
    """
    # A negative ratio would make the slice below negative and silently drop
    # almost every edge, so reject it up front (NaN also fails this check).
    if not 0.0 <= remove_ratio <= 1.0:
        raise ValueError(f"remove_ratio must be in [0, 1], got {remove_ratio}")
    if adj.dim() != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(f"adj must be a square 2-D matrix, got shape {tuple(adj.shape)}")

    # as_tensor accepts numpy arrays and tensors alike (torch.tensor warns on
    # tensors) and places the scores on adj's device so all indexing lines up.
    scores = torch.as_tensor(cgscore, dtype=torch.float32, device=adj.device)
    if adj.shape != scores.shape:
        raise ValueError(
            "adj and cgscore must have the same shape, got "
            f"{tuple(adj.shape)} and {tuple(scores.shape)}"
        )

    n = adj.shape[0]
    # Row/column indices of the strict upper triangle (diagonal excluded).
    rows, cols = torch.triu_indices(n, n, offset=1, device=adj.device)
    # Keep only the pairs that are actual (positive-weight) edges.
    is_edge = adj[rows, cols] > 0
    rows, cols = rows[is_edge], cols[is_edge]
    edge_scores = scores[rows, cols]

    num_edges = edge_scores.numel()
    # int() floors the product, so e.g. 5 edges at ratio 0.3 removes 1 edge.
    num_to_remove = int(num_edges * remove_ratio)
    if verbose:
        print("len(sorted_indices)", num_edges)
        print("remove_ratio:", remove_ratio)
        print("num_edges_to_remove", num_to_remove)

    # Ascending sort: the least important (lowest-score) edges come first.
    lowest = torch.argsort(edge_scores)[:num_to_remove]
    drop_rows, drop_cols = rows[lowest], cols[lowest]

    adj_new = adj.clone()
    adj_new[drop_rows, drop_cols] = 0
    adj_new[drop_cols, drop_rows] = 0  # mirror the removal to keep symmetry
    return adj_new
if __name__ == "__main__":
    # Demo only: guarded so that importing this module has no side effects.
    # Example adjacency matrix (4-node undirected graph) and CGScore matrix.
    # Note: remove_least_important_edges reads only the upper triangle of
    # cgscore, so the differing lower-triangle values below are ignored.
    adj = torch.tensor([
        [0, 1, 1, 0],
        [1, 0, 1, 1],
        [1, 1, 0, 1],
        [0, 1, 1, 0],
    ], dtype=torch.float32)
    cgscore = np.array([
        [0.0, 0.8, 0.6, 0.0],
        [0.8, 0.0, 0.1, 1.2],
        [0.6, 0.7, 0.0, 1.9],
        [0.0, 1.2, 1.1, 0.0],
    ], dtype=np.float32)

    # Call the function: drop the lowest-scoring 20% of edges.
    adj_new = remove_least_important_edges(adj, cgscore, remove_ratio=0.2)

    # Print the results ("original adjacency matrix" / "adjusted adjacency matrix").
    print("原始邻接矩阵:")
    print(adj)
    print("调整后的邻接矩阵:")
    print(adj_new)