# Yaning1001's picture
# Add files using upload-large-folder tool
# 92b9080 verified
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from collections import namedtuple
from functools import lru_cache
from torch_scatter import scatter_add
from torch_geometric.utils import k_hop_subgraph
from deeprobust.graph.targeted_attack import BaseAttack
from deeprobust.graph import utils
# Lightweight record describing the attacked subgraph: existing edges,
# candidate (non-)edges, self loops, their weights, and the concatenated
# edge list that is fed to the propagation step.
SubGraph = namedtuple(
    'SubGraph',
    'edge_index non_edge_index '
    'self_loop self_loop_weight '
    'edge_weight non_edge_weight '
    'edges_all',
)
class SGAttack(BaseAttack):
    """SGAttack proposed in `Adversarial Attack on Large Scale Graph` TKDE 2021
    <https://arxiv.org/abs/2009.03488>

    SGAttack follows these steps::

        + training a surrogate SGC model with hop K
        + extract a K-hop subgraph centered at target node
        + choose top-N attacker nodes that belong to the best wrong classes of the target node
        + compute gradients w.r.t to the subgraph to add or remove edges iteratively

    Parameters
    ----------
    model :
        trained surrogate model; it must provide ``predict()`` as well as
        ``conv1.K``, ``conv1.lin.weight`` and ``conv1.lin.bias``
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    Examples
    --------
    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
    >>> from deeprobust.graph.defense import SGC
    >>> device = 'cuda'
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> surrogate = SGC(nfeat=features.shape[1], K=3, lr=0.1,
    ...                 nclass=labels.max().item() + 1, device=device)
    >>> surrogate = surrogate.to(device)
    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> surrogate.fit(pyg_data, train_iters=200, patience=200, verbose=True) # train with earlystopping
    >>> from deeprobust.graph.targeted_attack import SGAttack
    >>> # Setup Attack Model
    >>> target_node = 0
    >>> model = SGAttack(surrogate, attack_structure=True, device=device)
    >>> # Attack
    >>> model.attack(features, adj, labels, target_node, n_perturbations=5)
    >>> modified_adj = model.modified_adj
    >>> modified_features = model.modified_features
    """

    def __init__(self, model, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
        super(SGAttack, self).__init__(model=None, nnodes=nnodes,
                                       attack_structure=attack_structure,
                                       attack_features=attack_features,
                                       device=device)

        self.target_node = None
        self.logits = model.predict()
        self.K = model.conv1.K
        W = model.conv1.lin.weight.to(device)
        b = model.conv1.lin.bias
        if b is not None:
            b = b.to(device)

        self.weight, self.bias = W, b
        # Per-instance cache for `compute_XW`; reset whenever the features change.
        self._cached_XW = None

    def compute_XW(self):
        """Return ``F.linear(self.modified_features, self.weight)``.

        The product is cached on the instance so repeated gradient
        computations can reuse it. The cache is reset at the start and end
        of `attack` and after every gradient step when features are being
        attacked, so a stale product is never reused.
        """
        cached = getattr(self, '_cached_XW', None)
        if cached is None:
            cached = F.linear(self.modified_features, self.weight)
            self._cached_XW = cached
        return cached

    def attack(self, features, adj, labels, target_node, n_perturbations, direct=True, n_influencers=3, **kwargs):
        """Generate perturbations on the input graph.

        The results are stored in ``self.modified_adj``,
        ``self.modified_features``, ``self.structure_perturbations`` and
        ``self.feature_perturbations``.

        Parameters
        ----------
        features :
            Original (unperturbed) node feature matrix
        adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        target_node : int
            target_node node index to be attacked
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        direct: bool
            whether to conduct direct attack
        n_influencers : int
            number of the top influencers to choose. For direct attack, it will set as `n_perturbations`.
        **kwargs :
            accepted for call compatibility and ignored
        """
        if sp.issparse(features):
            # to dense numpy matrix
            features = features.toarray()

        if not torch.is_tensor(features):
            features = torch.tensor(features, device=self.device)

        if torch.is_tensor(adj):
            adj = utils.to_scipy(adj).tocsr()
        elif sp.issparse(adj):
            # Row slicing with `.indices` below requires CSR storage.
            adj = adj.tocsr()

        self.modified_features = features.requires_grad_(bool(self.attack_features))
        # Features may differ from a previous call: drop any cached product.
        self._cached_XW = None

        target_label = torch.LongTensor([labels[target_node]])
        # Most confident class other than the true one (true class is masked out by -1000).
        best_wrong_label = torch.LongTensor([(self.logits[target_node].cpu() - 1000 * torch.eye(self.logits.size(1))[target_label]).argmax()])

        degree = np.asarray(adj.sum(1)).ravel() + 1
        if not np.issubdtype(degree.dtype, np.floating):
            # Degrees are updated in place with float steps and raised to -0.5.
            degree = degree.astype(np.float32)
        self.selfloop_degree = torch.tensor(degree, device=self.device)

        self.target_label = target_label.to(self.device)
        self.best_wrong_label = best_wrong_label.to(self.device)
        self.n_perturbations = n_perturbations
        self.ori_adj = adj
        self.target_node = target_node
        self.direct = direct

        attacker_nodes = torch.where(torch.as_tensor(labels) == best_wrong_label)[0]
        subgraph = self.get_subgraph(attacker_nodes, n_influencers)

        mask = None
        if not direct:
            # for indirect attack, the edges adjacent to targeted node should not be considered
            mask = torch.logical_or(subgraph.edge_index[0] == target_node,
                                    subgraph.edge_index[1] == target_node).to(self.device)

        structure_perturbations = []
        feature_perturbations = []
        num_features = features.shape[-1]

        for _ in range(n_perturbations):
            edge_grad, non_edge_grad, features_grad = self.compute_gradient(subgraph)
            max_structure_score = max_feature_score = 0.

            # Score bookkeeping only: no autograd graph is needed here.
            with torch.no_grad():
                if self.attack_structure:
                    # Flip the sign for entries whose weight is 1 so that a
                    # larger score always means a more harmful flip.
                    edge_grad *= -2 * subgraph.edge_weight + 1
                    non_edge_grad *= -2 * subgraph.non_edge_weight + 1
                    min_grad = min(edge_grad.min().item(), non_edge_grad.min().item())
                    edge_grad -= min_grad
                    non_edge_grad -= min_grad
                    if not direct:
                        edge_grad[mask] = 0.
                    max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0)
                    max_non_edge_grad, max_non_edge_idx = torch.max(non_edge_grad, dim=0)
                    max_structure_score = max(max_edge_grad.item(), max_non_edge_grad.item())

                if self.attack_features:
                    features_grad *= -2 * self.modified_features + 1
                    features_grad -= features_grad.min()
                    if not direct:
                        features_grad[target_node] = 0.
                    max_feature_grad, max_feature_idx = torch.max(features_grad.view(-1), dim=0)
                    max_feature_score = max_feature_grad.item()

            # Only take the structure branch when structure scores were
            # actually computed; otherwise `max_edge_grad` does not exist.
            flip_structure = self.attack_structure and (
                not self.attack_features or max_structure_score >= max_feature_score)

            if flip_structure:
                if max_edge_grad > max_non_edge_grad:
                    # remove one edge
                    # NOTE(review): an edge whose weight is already 0 can be
                    # selected again; the weight then stays 0 while the degree
                    # is decremented once more. Confirm this is intended.
                    best_edge = subgraph.edge_index[:, max_edge_idx]
                    subgraph.edge_weight.data[max_edge_idx] = 0.0
                    self.selfloop_degree[best_edge] -= 1.0
                else:
                    # add one edge
                    best_edge = subgraph.non_edge_index[:, max_non_edge_idx]
                    subgraph.non_edge_weight.data[max_non_edge_idx] = 1.0
                    self.selfloop_degree[best_edge] += 1.0

                u, v = best_edge.tolist()
                structure_perturbations.append((u, v))
            else:
                u, v = divmod(max_feature_idx.item(), num_features)
                feature_perturbations.append((u, v))
                # Flip the selected feature entry (x -> 1 - x) in place.
                self.modified_features[u, v].data.fill_(1. - self.modified_features[u, v].data)

        self.modified_adj = self._apply_structure_perturbations(adj, structure_perturbations)
        self.modified_features = self.modified_features.detach().cpu().numpy()
        self.structure_perturbations = structure_perturbations
        self.feature_perturbations = feature_perturbations
        # The cached product refers to tensors of this run only.
        self._cached_XW = None

    @staticmethod
    def _apply_structure_perturbations(adj, structure_perturbations):
        """Return a CSR copy of `adj` with every listed (u, v) entry flipped symmetrically."""
        if not structure_perturbations:
            return adj.copy()

        modified_adj = adj.tolil(copy=True)
        row, col = np.asarray(structure_perturbations).T
        flipped = 1 - modified_adj[row, col].toarray()
        modified_adj[row, col] = flipped
        modified_adj[col, row] = flipped
        modified_adj = modified_adj.tocsr(copy=False)
        modified_adj.eliminate_zeros()
        return modified_adj

    def get_subgraph(self, attacker_nodes, n_influencers=None):
        """Build the subgraph (K-hop edges plus candidate non-edges) to attack.

        A first subgraph is built from all candidate attacker nodes. When
        `n_influencers` is given and the structure is attacked, the
        candidates are narrowed to the top-scoring ones and the subgraph is
        rebuilt.
        """
        target_node = self.target_node
        neighbors = self.ori_adj[target_node].indices
        sub_nodes, sub_edges = self.ego_subgraph()

        if self.direct or n_influencers is not None:
            influencers = [target_node]
            # Existing neighbors cannot be the endpoint of a *new* edge.
            attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
        else:
            influencers = neighbors

        subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)

        if n_influencers is not None and self.attack_structure:
            if self.direct:
                influencers = [target_node]
                attacker_nodes = self.get_topk_influencers(subgraph, k=self.n_perturbations + 1)
            else:
                influencers = neighbors
                attacker_nodes = self.get_topk_influencers(subgraph, k=n_influencers)

            subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)

        return subgraph

    def get_topk_influencers(self, subgraph, k):
        """Return (as a numpy array) the second endpoints of the `k` non-edges with the largest gradient.

        `k` is clamped to the number of candidate non-edges so that small
        candidate sets do not make ``torch.topk`` fail.
        """
        _, non_edge_grad, _ = self.compute_gradient(subgraph)
        k = min(k, non_edge_grad.numel())
        _, topk_nodes = torch.topk(non_edge_grad, k=k, sorted=False)

        influencers = subgraph.non_edge_index[1][topk_nodes.cpu()]
        return influencers.cpu().numpy()

    def subgraph_processing(self, influencers, attacker_nodes, sub_nodes, sub_edges):
        """Assemble a `SubGraph` from K-hop edges and influencer x attacker candidate edges.

        When the structure is not attacked, only the existing edges (both
        directions) and self loops are kept, all with weight one.
        """
        if not self.attack_structure:
            self_loop = sub_nodes.repeat((2, 1))
            edges_all = torch.cat([sub_edges, sub_edges[[1, 0]], self_loop], dim=1)
            edge_weight = torch.ones(edges_all.size(1), device=self.device)

            return SubGraph(edge_index=sub_edges, non_edge_index=None,
                            self_loop=None, edges_all=edges_all,
                            edge_weight=edge_weight, non_edge_weight=None,
                            self_loop_weight=None)

        # Every (influencer, attacker) pair is a candidate edge to insert.
        row = np.repeat(influencers, len(attacker_nodes))
        col = np.tile(attacker_nodes, len(influencers))
        non_edges = np.vstack([row, col])

        if len(influencers) > 1:
            # Drop candidate pairs that are already connected in the original graph.
            existing = np.asarray(self.ori_adj[non_edges[0], non_edges[1]]).ravel()
            non_edges = non_edges[:, existing == 0]

        non_edges = torch.as_tensor(non_edges, dtype=torch.long, device=self.device)
        unique_nodes = np.union1d(sub_nodes.tolist(), attacker_nodes)
        unique_nodes = torch.as_tensor(unique_nodes, dtype=torch.long, device=self.device)
        self_loop = unique_nodes.repeat((2, 1))
        edges_all = torch.cat([sub_edges, sub_edges[[1, 0]],
                               non_edges, non_edges[[1, 0]], self_loop], dim=1)

        track_grad = bool(self.attack_structure)
        edge_weight = torch.ones(sub_edges.size(1), device=self.device).requires_grad_(track_grad)
        non_edge_weight = torch.zeros(non_edges.size(1), device=self.device).requires_grad_(track_grad)
        self_loop_weight = torch.ones(self_loop.size(1), device=self.device)

        return SubGraph(edge_index=sub_edges, non_edge_index=non_edges,
                        self_loop=self_loop, edges_all=edges_all,
                        edge_weight=edge_weight, non_edge_weight=non_edge_weight,
                        self_loop_weight=self_loop_weight)

    def SGCCov(self, x, edge_index, edge_weight):
        """Apply `self.K` rounds of weighted neighborhood aggregation to `x`.

        Each round scales the rows gathered at ``edge_index[0]`` by
        `edge_weight` and sums them into the rows given by ``edge_index[1]``.
        """
        row, col = edge_index
        for _ in range(self.K):
            src = x[row] * edge_weight.view(-1, 1)
            x = scatter_add(src, col, dim=-2, dim_size=x.size(0))
        return x

    def compute_gradient(self, subgraph, eps=5.0):
        """Return gradients of the attack loss w.r.t. edge weights, non-edge weights and features.

        The loss is the difference between the negative log-likelihood of
        the target label and that of the best wrong label, computed on
        logits divided by `eps`. Entries that are not being attacked are
        returned as ``None``.
        """
        if self.attack_structure:
            edge_weight = subgraph.edge_weight
            non_edge_weight = subgraph.non_edge_weight
            self_loop_weight = subgraph.self_loop_weight
            # Same layout as `edges_all`: both directions of edges and non-edges, then self loops.
            weights = torch.cat([edge_weight, edge_weight,
                                 non_edge_weight, non_edge_weight,
                                 self_loop_weight], dim=0)
        else:
            weights = subgraph.edge_weight

        weights = self.gcn_norm(subgraph.edges_all, weights, self.selfloop_degree)
        logit = self.SGCCov(self.compute_XW(), subgraph.edges_all, weights)
        logit = logit[self.target_node]

        if self.bias is not None:
            # Out-of-place add: `logit` is an indexed view of the propagated features.
            logit = logit + self.bias

        # model calibration
        logit = F.log_softmax(logit.view(1, -1) / eps, dim=1)
        loss = F.nll_loss(logit, self.target_label) - F.nll_loss(logit, self.best_wrong_label)

        edge_grad = non_edge_grad = features_grad = None

        if self.attack_structure and self.attack_features:
            edge_grad, non_edge_grad, features_grad = torch.autograd.grad(
                loss, [edge_weight, non_edge_weight, self.modified_features], create_graph=False)
        elif self.attack_structure:
            edge_grad, non_edge_grad = torch.autograd.grad(
                loss, [edge_weight, non_edge_weight], create_graph=False)
        else:
            features_grad = torch.autograd.grad(loss, self.modified_features, create_graph=False)[0]

        if self.attack_features:
            # Features may be flipped next, so the cached product must be recomputed.
            self._cached_XW = None

        return edge_grad, non_edge_grad, features_grad

    def ego_subgraph(self):
        """Return the nodes and (one direction of the) edges of the K-hop subgraph around the target node."""
        edge_index = np.asarray(self.ori_adj.nonzero())
        edge_index = torch.as_tensor(edge_index, dtype=torch.long, device=self.device)
        # Pass the node count explicitly so nodes beyond the largest edge endpoint are still valid.
        sub_nodes, sub_edges, *_ = k_hop_subgraph(int(self.target_node), self.K, edge_index,
                                                  num_nodes=self.ori_adj.shape[0])
        # Keep a single direction per edge; the reverse direction is added back in `subgraph_processing`.
        sub_edges = sub_edges[:, sub_edges[0] < sub_edges[1]]

        return sub_nodes, sub_edges

    @staticmethod
    def gcn_norm(edge_index, weights, degree):
        """Scale each edge weight by ``degree[row] ** -0.5 * degree[col] ** -0.5``."""
        row, col = edge_index
        inv_degree = torch.pow(degree, -0.5)
        normed_weights = weights * inv_degree[row] * inv_degree[col]
        return normed_weights