# Listing metadata: file size 14,048 bytes, revision 92b9080
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from collections import namedtuple
from functools import lru_cache
from torch_scatter import scatter_add
from torch_geometric.utils import k_hop_subgraph
from deeprobust.graph.targeted_attack import BaseAttack
from deeprobust.graph import utils
# Immutable record bundling the index/weight tensors that describe the
# subgraph the attack operates on. Field order is part of the interface.
SubGraph = namedtuple(
    'SubGraph',
    (
        'edge_index',        # existing edges of the subgraph
        'non_edge_index',    # candidate edges that may be added
        'self_loop',         # self-loop index pairs
        'self_loop_weight',  # weights of the self loops
        'edge_weight',       # weights of the existing edges
        'non_edge_weight',   # weights of the candidate edges
        'edges_all',         # every index pair fed to the propagation
    ),
)
class SGAttack(BaseAttack):
    """SGAttack proposed in `Adversarial Attack on Large Scale Graph` TKDE 2021
    <https://arxiv.org/abs/2009.03488>

    SGAttack follows these steps::

        + train a surrogate SGC model with hop K
        + extract a K-hop subgraph centered at the target node
        + choose top-N attacker nodes that belong to the best wrong class of the target node
        + compute gradients w.r.t. the subgraph to add or remove edges iteratively

    Parameters
    ----------
    model :
        surrogate model to attack; the constructor reads ``model.predict()``,
        ``model.conv1.K`` and ``model.conv1.lin`` (weight and optional bias)
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    Examples
    --------
    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
    >>> from deeprobust.graph.defense import SGC
    >>> device = 'cuda'
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> surrogate = SGC(nfeat=features.shape[1], K=3, lr=0.1,
    ...                 nclass=labels.max().item() + 1, device=device)
    >>> surrogate = surrogate.to(device)
    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> surrogate.fit(pyg_data, train_iters=200, patience=200, verbose=True) # train with earlystopping
    >>> from deeprobust.graph.targeted_attack import SGAttack
    >>> # Setup Attack Model
    >>> target_node = 0
    >>> model = SGAttack(surrogate, attack_structure=True, device=device)
    >>> # Attack
    >>> model.attack(features, adj, labels, target_node, n_perturbations=5)
    >>> modified_adj = model.modified_adj
    >>> modified_features = model.modified_features
    """

    def __init__(self, model, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
        # The surrogate itself is not handed to the base class (model=None);
        # only its logits, hop count and linear layer are kept.
        super().__init__(model=None, nnodes=nnodes,
                         attack_structure=attack_structure,
                         attack_features=attack_features, device=device)

        self.target_node = None
        self.logits = model.predict()
        self.K = model.conv1.K

        weight = model.conv1.lin.weight.to(device)
        bias = model.conv1.lin.bias
        if bias is not None:
            bias = bias.to(device)
        self.weight, self.bias = weight, bias

    @lru_cache(maxsize=1)
    def compute_XW(self):
        """Return ``modified_features`` projected by the surrogate's linear weight.

        The result is memoised with ``self`` as the only key, so it does NOT
        notice changes to ``self.modified_features``; callers must invoke
        ``compute_XW.cache_clear()`` whenever the features change (done in
        `attack` and `compute_gradient`).
        """
        return F.linear(self.modified_features, self.weight)

    def attack(self, features, adj, labels, target_node, n_perturbations, direct=True, n_influencers=3, **kwargs):
        """Generate perturbations on the input graph.

        Results are stored on the instance: ``modified_adj`` (scipy CSR),
        ``modified_features`` (numpy array), ``structure_perturbations`` and
        ``feature_perturbations`` (lists of ``(u, v)`` index pairs).

        Parameters
        ----------
        features :
            Original (unperturbed) node feature matrix
        adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        target_node : int
            target_node node index to be attacked
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        direct: bool
            whether to conduct direct attack
        n_influencers : int
            number of the top influencers to choose. For direct attack, it will set as `n_perturbations`.
        """
        # `compute_XW` is cached per instance only; drop whatever a previous
        # call left behind so this run never reuses a stale projection.
        self.compute_XW.cache_clear()

        if sp.issparse(features):
            # to dense numpy array (`toarray()` replaces the deprecated `.A`)
            features = features.toarray()
        if torch.is_tensor(features):
            # Work on a private copy: the attack toggles `requires_grad` and
            # flips entries in place, which must not leak into the caller's tensor.
            features = features.detach().clone().to(self.device)
        else:
            features = torch.tensor(features, device=self.device)

        if torch.is_tensor(adj):
            # scipy sparse matrices expose `tocsr()`; there is no `csr()` method
            adj = utils.to_scipy(adj).tocsr()
        elif sp.issparse(adj):
            # `.indices` row access in `get_subgraph` needs CSR layout
            adj = adj.tocsr()

        self.modified_features = features.requires_grad_(bool(self.attack_features))

        target_label = torch.LongTensor([labels[target_node]])
        # Push the true class far down so argmax yields the strongest wrong class.
        best_wrong_label = torch.LongTensor([(self.logits[target_node].cpu() - 1000 * torch.eye(self.logits.size(1))[target_label]).argmax()])

        # node degree including the self loop
        self.selfloop_degree = torch.tensor(adj.sum(1).A1 + 1, device=self.device)
        self.target_label = target_label.to(self.device)
        self.best_wrong_label = best_wrong_label.to(self.device)
        self.n_perturbations = n_perturbations
        self.ori_adj = adj
        self.target_node = target_node
        self.direct = direct

        attacker_nodes = torch.where(torch.as_tensor(labels) == best_wrong_label)[0]
        subgraph = self.get_subgraph(attacker_nodes, n_influencers)

        if not direct:
            # for indirect attack, the edges adjacent to targeted node should not be considered
            mask = torch.logical_or(subgraph.edge_index[0] == target_node, subgraph.edge_index[1] == target_node).to(self.device)

        structure_perturbations = []
        feature_perturbations = []
        num_features = features.shape[-1]

        for _ in range(n_perturbations):
            edge_grad, non_edge_grad, features_grad = self.compute_gradient(subgraph)

            max_structure_score = max_feature_score = 0.

            if self.attack_structure:
                # (-2w + 1) is +1 for weight 0 and -1 for weight 1: it turns the
                # raw gradient into the score of flipping that entry.
                edge_grad *= (-2 * subgraph.edge_weight + 1)
                non_edge_grad *= -2 * subgraph.non_edge_weight + 1
                min_grad = min(edge_grad.min().item(), non_edge_grad.min().item())
                edge_grad -= min_grad
                non_edge_grad -= min_grad
                if not direct:
                    edge_grad[mask] = 0.
                max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0)
                max_non_edge_grad, max_non_edge_idx = torch.max(non_edge_grad, dim=0)
                max_structure_score = max(max_edge_grad.item(), max_non_edge_grad.item())

            if self.attack_features:
                features_grad *= -2 * self.modified_features + 1
                features_grad -= features_grad.min()
                if not direct:
                    features_grad[target_node] = 0.
                max_feature_grad, max_feature_idx = torch.max(features_grad.view(-1), dim=0)
                max_feature_score = max_feature_grad.item()

            # Only take the structure branch when structure attacks are enabled;
            # otherwise a zero feature score would reach the undefined edge scores.
            flip_structure = self.attack_structure and (
                not self.attack_features
                or max_structure_score >= max_feature_score
            )

            if flip_structure:
                # NOTE(review): weights are always written to 0 (edge) / 1 (non-edge),
                # so re-selecting an already flipped candidate does not undo it - confirm intended.
                if max_edge_grad > max_non_edge_grad:
                    # remove one edge
                    best_edge = subgraph.edge_index[:, max_edge_idx]
                    subgraph.edge_weight.data[max_edge_idx] = 0.0
                    self.selfloop_degree[best_edge] -= 1.0
                else:
                    # add one edge
                    best_edge = subgraph.non_edge_index[:, max_non_edge_idx]
                    subgraph.non_edge_weight.data[max_non_edge_idx] = 1.0
                    self.selfloop_degree[best_edge] += 1.0

                u, v = best_edge.tolist()
                structure_perturbations.append((u, v))
            else:
                # flat argmax index -> (node, feature) pair, then flip that entry
                u, v = divmod(max_feature_idx.item(), num_features)
                feature_perturbations.append((u, v))
                self.modified_features[u, v].data.fill_(1. - self.modified_features[u, v].data)

        if structure_perturbations:
            modified_adj = adj.tolil(copy=True)
            row, col = list(zip(*structure_perturbations))
            # flip the selected entries symmetrically
            modified_adj[row, col] = modified_adj[col, row] = 1 - modified_adj[row, col].toarray()
            modified_adj = modified_adj.tocsr(copy=False)
            modified_adj.eliminate_zeros()
        else:
            modified_adj = adj.copy()

        self.modified_adj = modified_adj
        self.modified_features = self.modified_features.detach().cpu().numpy()
        self.structure_perturbations = structure_perturbations
        self.feature_perturbations = feature_perturbations

        # Release the cached projection (and the reference to `self` the
        # lru_cache holds) now that the attack is finished.
        self.compute_XW.cache_clear()

    def get_subgraph(self, attacker_nodes, n_influencers=None):
        """Build the `SubGraph` the gradients are computed on.

        A first subgraph is built from the K-hop neighbourhood of the target
        plus all candidate attacker nodes. When ``n_influencers`` is given and
        structure attacks are enabled, the attacker nodes are narrowed down to
        the top-k by non-edge gradient and the subgraph is rebuilt.
        """
        target_node = self.target_node
        neighbors = self.ori_adj[target_node].indices
        sub_nodes, sub_edges = self.ego_subgraph()

        if self.direct or n_influencers is not None:
            influencers = [target_node]
            # nodes already adjacent to the target are not candidates for a new edge
            attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
        else:
            influencers = neighbors

        subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)

        if n_influencers is not None and self.attack_structure:
            if self.direct:
                influencers = [target_node]
                attacker_nodes = self.get_topk_influencers(subgraph, k=self.n_perturbations + 1)
            else:
                influencers = neighbors
                attacker_nodes = self.get_topk_influencers(subgraph, k=n_influencers)

            subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)

        return subgraph

    def get_topk_influencers(self, subgraph, k):
        """Return (as a numpy array) the second endpoints of the ``k``
        candidate non-edges with the largest gradient.

        ``k`` is clamped to the number of candidates so that a small candidate
        set does not make ``torch.topk`` raise.
        """
        _, non_edge_grad, _ = self.compute_gradient(subgraph)
        k = min(k, non_edge_grad.numel())
        _, topk_nodes = torch.topk(non_edge_grad, k=k, sorted=False)

        influencers = subgraph.non_edge_index[1][topk_nodes.cpu()]
        return influencers.cpu().numpy()

    def subgraph_processing(self, influencers, attacker_nodes, sub_nodes, sub_edges):
        """Assemble a `SubGraph` from existing edges and candidate edges.

        Without structure attacks only the existing edges (both directions)
        and self loops of ``sub_nodes`` are kept, with constant unit weights.
        Otherwise every (influencer, attacker) pair becomes a candidate
        non-edge with weight 0, and edge / non-edge weights require gradients.
        """
        if not self.attack_structure:
            self_loop = sub_nodes.repeat((2, 1))
            edges_all = torch.cat([sub_edges, sub_edges[[1, 0]], self_loop], dim=1)
            edge_weight = torch.ones(edges_all.size(1), device=self.device)

            return SubGraph(edge_index=sub_edges, non_edge_index=None,
                            self_loop=None, edges_all=edges_all,
                            edge_weight=edge_weight, non_edge_weight=None,
                            self_loop_weight=None)

        # Cartesian product influencers x attacker_nodes
        row = np.repeat(influencers, len(attacker_nodes))
        col = np.tile(attacker_nodes, len(influencers))
        non_edges = np.vstack([row, col])

        if len(influencers) > 1:
            # several influencers: drop pairs that are already connected
            mask = self.ori_adj[non_edges[0],
                                non_edges[1]].A1 == 0
            non_edges = non_edges[:, mask]

        non_edges = torch.as_tensor(non_edges, device=self.device)
        unique_nodes = np.union1d(sub_nodes.tolist(), attacker_nodes)
        unique_nodes = torch.as_tensor(unique_nodes, device=self.device)
        self_loop = unique_nodes.repeat((2, 1))

        # Order matters: `compute_gradient` concatenates the weights in the same order.
        edges_all = torch.cat([sub_edges, sub_edges[[1, 0]],
                               non_edges, non_edges[[1, 0]], self_loop], dim=1)

        edge_weight = torch.ones(sub_edges.size(1), device=self.device).requires_grad_(bool(self.attack_structure))
        non_edge_weight = torch.zeros(non_edges.size(1), device=self.device).requires_grad_(bool(self.attack_structure))
        self_loop_weight = torch.ones(self_loop.size(1), device=self.device)

        return SubGraph(edge_index=sub_edges, non_edge_index=non_edges,
                        self_loop=self_loop, edges_all=edges_all,
                        edge_weight=edge_weight, non_edge_weight=non_edge_weight,
                        self_loop_weight=self_loop_weight)

    def SGCCov(self, x, edge_index, edge_weight):
        """Propagate ``x`` over the weighted edges ``self.K`` times.

        Each hop gathers the source rows, scales them by the edge weight and
        scatter-adds them into the destination rows.
        """
        row, col = edge_index
        for _ in range(self.K):
            src = x[row] * edge_weight.view(-1, 1)
            x = scatter_add(src, col, dim=-2, dim_size=x.size(0))
        return x

    def compute_gradient(self, subgraph, eps=5.0):
        """Compute gradients of the attack loss on ``subgraph``.

        The loss is ``nll(target_label) - nll(best_wrong_label)`` on the
        target node's log-softmax output, with logits divided by ``eps``.

        Returns
        -------
        tuple
            ``(edge_grad, non_edge_grad, features_grad)``; entries that were
            not requested by the attack configuration are ``None``.
        """
        if self.attack_structure:
            edge_weight = subgraph.edge_weight
            non_edge_weight = subgraph.non_edge_weight
            self_loop_weight = subgraph.self_loop_weight
            # each undirected (non-)edge appears in both directions in `edges_all`
            weights = torch.cat([edge_weight, edge_weight,
                                 non_edge_weight, non_edge_weight,
                                 self_loop_weight], dim=0)
        else:
            weights = subgraph.edge_weight

        weights = self.gcn_norm(subgraph.edges_all, weights, self.selfloop_degree)
        logit = self.SGCCov(self.compute_XW(), subgraph.edges_all, weights)
        logit = logit[self.target_node]
        if self.bias is not None:
            logit = logit + self.bias

        # model calibration
        logit = F.log_softmax(logit.view(1, -1) / eps, dim=1)
        loss = F.nll_loss(logit, self.target_label) - F.nll_loss(logit, self.best_wrong_label)

        edge_grad = non_edge_grad = features_grad = None

        if self.attack_structure and self.attack_features:
            edge_grad, non_edge_grad, features_grad = torch.autograd.grad(loss, [edge_weight, non_edge_weight, self.modified_features], create_graph=False)
        elif self.attack_structure:
            edge_grad, non_edge_grad = torch.autograd.grad(loss, [edge_weight, non_edge_weight], create_graph=False)
        else:
            features_grad = torch.autograd.grad(loss, self.modified_features, create_graph=False)[0]

        if self.attack_features:
            # features may be flipped before the next call, so the cached XW is invalid
            self.compute_XW.cache_clear()

        return edge_grad, non_edge_grad, features_grad

    def ego_subgraph(self):
        """Return ``(sub_nodes, sub_edges)`` of the K-hop subgraph around the target.

        Only one direction of every edge is kept (source index < destination index).
        """
        edge_index = np.asarray(self.ori_adj.nonzero())
        edge_index = torch.as_tensor(edge_index, dtype=torch.long, device=self.device)
        sub_nodes, sub_edges, *_ = k_hop_subgraph(int(self.target_node), self.K, edge_index)
        sub_edges = sub_edges[:, sub_edges[0] < sub_edges[1]]

        return sub_nodes, sub_edges

    @staticmethod
    def gcn_norm(edge_index, weights, degree):
        """Scale each edge weight by ``degree[row]**-0.5 * degree[col]**-0.5``."""
        row, col = edge_index
        inv_degree = torch.pow(degree, -0.5)
        normed_weights = weights * inv_degree[row] * inv_degree[col]
        return normed_weights
|