File size: 7,291 Bytes
c91d7b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import torch
import numpy as np
from deeprobust.graph.defense import GCN
from deeprobust.graph.targeted_attack import SGAttack
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, Dpr2Pyg
from deeprobust.graph.defense import SGC
import argparse
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='cora',
                    choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'],
                    help='dataset')
# NOTE(review): --ptb_rate is parsed but never read in the visible code; the
# per-target budget used below is the target node's degree instead.
parser.add_argument('--ptb_rate', type=float, default=0.05, help='perturbation rate')
args = parser.parse_args()

args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

# Seed every RNG the script uses so node selection and training are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
data = Dataset(root='/tmp/', name=args.dataset)
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
# Nodes whose labels are not used for training (validation + test split).
idx_unlabeled = np.union1d(idx_val, idx_test)

# ---------------------------------------------------------------------------
# Surrogate model: SGC (K=2) that SGAttack differentiates through
# ---------------------------------------------------------------------------
surrogate = SGC(nfeat=features.shape[1],
                nclass=labels.max().item() + 1, K=2,
                lr=0.01, device=device).to(device)
pyg_data = Dpr2Pyg(data)
surrogate.fit(pyg_data, verbose=False)  # train with early stopping
surrogate.test()

# ---------------------------------------------------------------------------
# Attack model used by main()
# ---------------------------------------------------------------------------
target_node = 0
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if target_node not in idx_unlabeled:
    raise ValueError('target_node %d must be an unlabeled (validation/test) node' % target_node)
model = SGAttack(surrogate, attack_structure=True, attack_features=False, device=device)
# model = SGAttack(surrogate, attack_structure=True, attack_features=True, device=device)
model = model.to(device)
def main():
    """Attack the module-level ``target_node`` once and compare GCN results.

    Runs a direct SGAttack on ``target_node`` with a budget equal to its degree.
    It prints the structure/feature perturbations the attacker reports. It then
    trains and evaluates a GCN first on the clean graph and then on the perturbed
    graph (via ``test``).
    """
    degrees = adj.sum(0).A1
    # How many perturbations to perform. Default: degree of the node.
    n_perturbations = int(degrees[target_node])

    # Direct attack: edges incident to the target itself may be flipped.
    model.attack(features, adj, labels, target_node, n_perturbations, direct=True)
    # # Indirect / influencer attack:
    # model.attack(features, adj, labels, target_node, n_perturbations, direct=False, n_influencers=5)

    modified_adj = model.modified_adj
    modified_features = model.modified_features
    if modified_features is None:
        # The attacker is built with attack_features=False, so features are not
        # perturbed. If it reports no modified features, use the originals rather
        # than passing None to test(), which reads features.shape.
        modified_features = features

    print('=== Structure perturbations ===')
    print(model.structure_perturbations)
    print('=== Feature perturbations ===')
    print(model.feature_perturbations)

    print('=== testing GCN on original(clean) graph ===')
    test(adj, features, target_node)
    print('=== testing GCN on perturbed graph ===')
    test(modified_adj, modified_features, target_node)
def test(adj, features, target_node):
    """Train a fresh GCN on ``(adj, features)`` and report how it does.

    Prints the exponentiated model output for ``target_node``, shown as class
    probabilities (the output is presumably log-softmax). It also prints the
    accuracy on ``idx_test``. Returns that accuracy as the scalar produced by
    ``.item()``.
    """
    n_classes = labels.max().item() + 1
    gcn = GCN(nfeat=features.shape[1], nhid=16, nclass=n_classes,
              dropout=0.5, device=device).to(device)
    gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    gcn.eval()

    logits = gcn.predict()
    target_probs = torch.exp(logits[[target_node]])[0]
    print('Target node probs: {}'.format(target_probs.detach().cpu().numpy()))

    test_acc = accuracy(logits[idx_test], labels[idx_test]).item()
    print("Overall test set results:",
          "accuracy= {:.4f}".format(test_acc))
    return test_acc
def select_nodes(target_gcn=None, n_high=10, n_low=10, n_random=20):
    """Select attack targets following the protocol of the Nettack paper.

    Among test nodes with a non-negative classification margin (i.e. correctly
    classified), take:
      (i)   the ``n_high`` nodes with the highest margin (clearly correct),
      (ii)  the ``n_low`` nodes with the lowest margin (but still correct),
      (iii) ``n_random`` more nodes drawn without replacement from the rest.

    Args:
        target_gcn: trained GCN used to score nodes. When None, a fresh GCN is
            trained on the clean graph.
        n_high, n_low, n_random: group sizes. The defaults (10/10/20) reproduce
            the paper's 40-node setup.

    Returns:
        List of node indices ordered high-margin, low-margin, then random.
        If fewer candidates exist than requested, the groups are shrunk so no
        node appears twice and no error is raised.
    """
    if target_gcn is None:
        target_gcn = GCN(nfeat=features.shape[1],
                         nhid=16,
                         nclass=labels.max().item() + 1,
                         dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()
    output = target_gcn.predict()

    margin_dict = {}
    for idx in idx_test:
        margin = classification_margin(output[idx], labels[idx])
        if margin < 0:  # only keep the nodes correctly classified
            continue
        margin_dict[idx] = margin
    sorted_margins = sorted(margin_dict.items(), key=lambda x: x[1], reverse=True)
    ranked = [node for node, _ in sorted_margins]

    # Clamp the group boundaries so the high and low groups never overlap.
    # When there are at least n_high + n_low candidates, this is exactly
    # ranked[:n_high], ranked[-n_low:] and ranked[n_high:-n_low].
    n_ranked = len(ranked)
    high_end = min(n_high, n_ranked)
    low_start = max(high_end, n_ranked - n_low)
    high = ranked[:high_end]
    low = ranked[low_start:]
    middle = ranked[high_end:low_start]

    # np.random.choice(..., replace=False) raises when asked for more items than
    # exist, so bound the draw by the pool size.
    n_sample = min(n_random, len(middle))
    other = np.random.choice(middle, n_sample, replace=False).tolist() if n_sample > 0 else []
    return high + low + other
def multi_test_poison():
    """Poisoning evaluation over the nodes chosen by ``select_nodes()``.

    For each target, run a direct SGAttack with a budget equal to the node's
    degree. Then retrain a GCN on the perturbed graph (``single_test`` with no
    model) and count the targets it misclassifies.

    Returns:
        The misclassification rate (also printed), or None when no target nodes
        were selected.
    """
    cnt = 0
    degrees = adj.sum(0).A1
    node_list = select_nodes()
    num = len(node_list)
    print('=== [Poisoning] Attacking %s nodes respectively ===' % num)
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        # A new attacker per target, presumably so state from the previous
        # target's attack does not carry over.
        model = SGAttack(surrogate, attack_structure=True, attack_features=False, device=device)
        model = model.to(device)
        model.attack(features, adj, labels, target_node, n_perturbations, direct=True, verbose=False)
        modified_adj = model.modified_adj
        modified_features = model.modified_features
        if modified_features is None:
            # attack_features=False: features are unperturbed; avoid passing None.
            modified_features = features
        acc = single_test(modified_adj, modified_features, target_node)
        if not acc:
            cnt += 1
    if num == 0:
        # Avoid ZeroDivisionError when no test node was correctly classified.
        print('misclassification rate : undefined (no target nodes selected)')
        return None
    rate = cnt / num
    print('misclassification rate : %s' % rate)
    return rate
def single_test(adj, features, target_node, gcn=None):
    """Check whether a GCN classifies ``target_node`` correctly on a given graph.

    Args:
        adj, features: the (possibly perturbed) graph to evaluate on.
        target_node: index of the node to check.
        gcn: when None (poisoning setting), a fresh GCN is trained on
            ``adj``/``features``. Otherwise (evasion setting), the given
            already-trained model is only run on them.

    Returns:
        A scalar (from ``.item()``) that is truthy iff the predicted class of
        ``target_node`` equals its label.
    """
    if gcn is None:
        # Poisoning: the victim model is trained on the perturbed graph.
        gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device=device)
        gcn = gcn.to(device)
        gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
        gcn.eval()
        output = gcn.predict()
    else:
        # Evasion: the victim model was trained on the clean graph beforehand.
        output = gcn.predict(features, adj)
    acc_test = (output.argmax(1)[target_node] == labels[target_node])
    return acc_test.item()
def multi_test_evasion():
    """Evasion evaluation over the nodes chosen by ``select_nodes(target_gcn)``.

    Train one GCN on the clean graph. Then, for each target, run a direct
    SGAttack with a budget equal to the node's degree, and count the targets this
    fixed model misclassifies on the perturbed graph.

    Returns:
        The misclassification rate (also printed), or None when no target nodes
        were selected.
    """
    target_gcn = GCN(nfeat=features.shape[1],
                     nhid=16,
                     nclass=labels.max().item() + 1,
                     dropout=0.5, device=device)
    target_gcn = target_gcn.to(device)
    target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    # Put the fixed victim model in inference mode before any evaluation.
    target_gcn.eval()

    cnt = 0
    degrees = adj.sum(0).A1
    node_list = select_nodes(target_gcn)
    num = len(node_list)
    print('=== [Evasion] Attacking %s nodes respectively ===' % num)
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        # A new attacker per target, presumably so state from the previous
        # target's attack does not carry over.
        model = SGAttack(surrogate, attack_structure=True, attack_features=False, device=device)
        model = model.to(device)
        model.attack(features, adj, labels, target_node, n_perturbations, direct=True, verbose=False)
        modified_adj = model.modified_adj
        modified_features = model.modified_features
        if modified_features is None:
            # attack_features=False: features are unperturbed; avoid passing None.
            modified_features = features
        acc = single_test(modified_adj, modified_features, target_node, gcn=target_gcn)
        if not acc:
            cnt += 1
    if num == 0:
        # Avoid ZeroDivisionError when no test node was correctly classified.
        print('misclassification rate : undefined (no target nodes selected)')
        return None
    rate = cnt / num
    print('misclassification rate : %s' % rate)
    return rate
if __name__ == '__main__':
    # Single-target demo first, then the multi-node poisoning and evasion runs.
    for experiment in (main, multi_test_poison, multi_test_evasion):
        experiment()
|