|
|
""" |
|
|
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018. |
|
|
https://arxiv.org/abs/1806.02371 |
|
|
Author's Implementation |
|
|
https://github.com/Hanjun-Dai/graph_adversarial_attack |
|
|
This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) |
|
|
but modified to be integrated into the repository. |
|
|
""" |
|
|
|
|
|
import os
import os.path as osp
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

from deeprobust.graph.rl.q_net_node import QNetNode, NStepQNetNode, node_greedy_actions
from deeprobust.graph.rl.env import NodeAttackEnv
from deeprobust.graph.rl.nstep_replay_mem import NstepReplayMem


class RLS2V(object):
    """Reinforcement learning agent for the RL-S2V attack.

    Parameters
    ----------
    env : NodeAttackEnv
        node attack environment
    features :
        node feature matrix
    labels :
        node labels
    idx_meta :
        indices of the nodes used for evaluation
    idx_test :
        indices of the nodes used for training the attacker
    list_action_space : list
        list of admissible actions for each target node
    num_mod :
        number of modifications (perturbations) allowed on the graph
    reward_type : str
        type of reward (e.g., 'binary')
    batch_size :
        batch size for training the DQN
    num_wrong :
        number of target nodes already misclassified before the attack
        (included in the evaluation denominator)
    bilin_q : int
        whether to use a bilinear form for the Q-function
    embed_dim : int
        dimension of the structure2vec node embeddings
    gm : str
        graph embedding model, e.g., 'mean_field'
    mlp_hidden : int
        hidden size of the MLP layers in the Q-network
    max_lv : int
        number of message-passing levels in structure2vec
    save_dir :
        directory for saving model checkpoints
    device : str
        'cpu' or 'cuda'

    Examples
    --------
See details in https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_rl_s2v.py |
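
    A minimal usage sketch, assuming ``env``, ``features``, ``labels`` and the
    index arrays have been prepared as in the example script above (the names
    below are illustrative)::

        agent = RLS2V(env, features, labels, idx_meta, idx_test,
                      list_action_space, num_mod=1, reward_type='binary',
                      device='cuda')
        agent.train(num_steps=100000, lr=0.001)
        agent.eval(training=False)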
|
|
""" |
|
|
|
|
|
    def __init__(self, env, features, labels, idx_meta, idx_test,
                 list_action_space, num_mod, reward_type, batch_size=10,
                 num_wrong=0, bilin_q=1, embed_dim=64, gm='mean_field',
                 mlp_hidden=64, max_lv=1, save_dir='checkpoint_dqn', device=None):

        assert device is not None, "'device' cannot be None, please specify it"

        self.features = features
        self.labels = labels
        self.idx_meta = idx_meta
        self.idx_test = idx_test
        self.num_wrong = num_wrong
        self.list_action_space = list_action_space
        self.num_mod = num_mod
        self.reward_type = reward_type
        self.batch_size = batch_size
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        self.gm = gm
        self.device = device

        # a full episode takes 2 * num_mod steps: each edge modification is
        # decomposed into two actions, picking the edge's two endpoints
        self.mem_pool = NstepReplayMem(memory_size=500000, n_steps=2 * num_mod,
                                       balance_sample=reward_type == 'binary')
        self.env = env

        # online Q-network and a periodically-synced target network
        self.net = NStepQNetNode(2 * num_mod, features, labels, list_action_space,
                                 bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
                                 max_lv=max_lv, gm=gm, device=device)

        self.old_net = NStepQNetNode(2 * num_mod, features, labels, list_action_space,
                                     bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
                                     max_lv=max_lv, gm=gm, device=device)

        self.net = self.net.to(device)
        self.old_net = self.old_net.to(device)

        # epsilon-greedy schedule: anneal linearly from eps_start to eps_end
        # over the first eps_step training steps
        self.eps_start = 1.0
        self.eps_end = 0.05
        self.eps_step = 100000
        self.burn_in = 10
        self.step = 0
        self.pos = 0
        self.best_eval = None
        self.take_snapshot()

    def take_snapshot(self):
        """Copy the online network's weights into the target network."""
        self.old_net.load_state_dict(self.net.state_dict())

    def make_actions(self, time_t, greedy=False):
        self.eps = self.eps_end + max(0., (self.eps_start - self.eps_end)
                                      * (self.eps_step - max(0., self.step)) / self.eps_step)

        # explore uniformly at random with probability eps, unless greedy
        if random.random() < self.eps and not greedy:
            actions = self.env.uniformRandActions()
        else:
            cur_state = self.env.getStateRef()
            actions, values = self.net(time_t, cur_state, None, greedy_acts=True, is_inference=True)
            actions = list(actions.cpu().numpy())

        return actions

    def run_simulation(self):
        # cycle through idx_test in mini-batches, reshuffling after each pass
        if (self.pos + 1) * self.batch_size > len(self.idx_test):
            self.pos = 0
            random.shuffle(self.idx_test)

        selected_idx = self.idx_test[self.pos * self.batch_size : (self.pos + 1) * self.batch_size]
        self.pos += 1
        self.env.setup(selected_idx)

        t = 0
        list_of_list_st = []
        list_of_list_at = []

        # roll out one episode: clone the state, act, step the environment,
        # and store the transition in the replay memory
        while not self.env.isTerminal():
            list_at = self.make_actions(t)
            list_st = self.env.cloneState()

            self.env.step(list_at)

            env = self.env
            # rewards are only exposed once the episode terminates
            assert (env.rewards is not None) == env.isTerminal()
            if env.isTerminal():
                rewards = env.rewards
                s_prime = None
            else:
                rewards = np.zeros(len(list_at), dtype=np.float32)
                s_prime = self.env.cloneState()

            self.mem_pool.add_list(list_st, list_at, rewards, s_prime,
                                   [env.isTerminal()] * len(list_at), t)
            list_of_list_st.append(deepcopy(list_st))
            list_of_list_at.append(deepcopy(list_at))
            t += 1

        if self.reward_type == 'nll':
            return

        # augment the sparse binary reward signal: re-label the recorded
        # trajectories with other target nodes for which the same sequence
        # of perturbations also produces a positive reward
        T = t
        cands = self.env.sample_pos_rewards(len(selected_idx))
        if len(cands):
            for c in cands:
                sample_idx, target = c
                doable = True
                for t in range(T):
                    if self.list_action_space[target] is not None and (
                            list_of_list_at[t][sample_idx] not in self.list_action_space[target]):
                        doable = False
                        break
                if not doable:
                    continue

                for t in range(T):
                    s_t = list_of_list_st[t][sample_idx]
                    a_t = list_of_list_at[t][sample_idx]
                    s_t = [target, deepcopy(s_t[1]), s_t[2]]
                    if t + 1 == T:
                        s_prime = (None, None, None)
                        r = 1.0
                        term = True
                    else:
                        s_prime = list_of_list_st[t + 1][sample_idx]
                        s_prime = [target, deepcopy(s_prime[1]), s_prime[2]]
                        r = 0.0
                        term = False
                    self.mem_pool.mem_cells[t].add(s_t, a_t, r, s_prime, term)

    def eval(self, training=True):
        """Evaluate the RL agent on the nodes in ``idx_meta``."""
        self.env.setup(self.idx_meta)
        t = 0

        while not self.env.isTerminal():
            list_at = self.make_actions(t, greedy=True)
            self.env.step(list_at)
            t += 1

        # binary_rewards are +1 for a successful attack and -1 otherwise,
        # so acc is the classifier's remaining accuracy on the attacked nodes
        acc = 1 - (self.env.binary_rewards + 1.0) / 2.0
        acc = np.sum(acc) / (len(self.idx_meta) + self.num_wrong)
        print('\033[93m average test: acc %.5f\033[0m' % acc)

        if training and (self.best_eval is None or acc < self.best_eval):
            print('----saving to best attacker since this is the best attack rate so far.----')
            torch.save(self.net.state_dict(), osp.join(self.save_dir, 'epoch-best.model'))
            with open(osp.join(self.save_dir, 'epoch-best.txt'), 'w') as f:
                f.write('%.4f\n' % acc)
            with open(osp.join(self.save_dir, 'attack_solution.txt'), 'w') as f:
                for i in range(len(self.idx_meta)):
                    f.write('%d: [' % self.idx_meta[i])
                    for e in self.env.modified_list[i].directed_edges:
                        f.write('(%d %d)' % e)
                    f.write('] succ: %d\n' % self.env.binary_rewards[i])
            self.best_eval = acc

    def train(self, num_steps=100000, lr=0.001):
        """Train the DQN-based RL agent."""
        # burn-in: fill the replay memory before any gradient update
        pbar = tqdm(range(self.burn_in), unit='batch')
        for p in pbar:
            self.run_simulation()

        pbar = tqdm(range(num_steps), unit='steps')
        optimizer = optim.Adam(self.net.parameters(), lr=lr)

        for self.step in pbar:
            self.run_simulation()

            # periodically sync the target network and evaluate the attack
            if self.step % 123 == 0:
                self.take_snapshot()
            if self.step % 500 == 0:
                self.eval()

            cur_time, list_st, list_at, list_rt, list_s_primes, list_term = \
                self.mem_pool.sample(batch_size=self.batch_size)
            list_target = torch.Tensor(list_rt).to(self.device)

            # for non-terminal transitions, bootstrap the target from the
            # greedy action value under the (frozen) target network
            if not list_term[0]:
                target_nodes, _, picked_nodes = zip(*list_s_primes)
                _, q_t_plus_1 = self.old_net(cur_time + 1, list_s_primes, None)
                _, q_rhs = node_greedy_actions(target_nodes, picked_nodes, q_t_plus_1, self.old_net)
                list_target += q_rhs

            # temporal-difference regression: fit Q(s_t, a_t) to the target
            list_target = list_target.view(-1, 1)
            _, q_sa = self.net(cur_time, list_st, list_at)
            q_sa = torch.cat(q_sa, dim=0)
            loss = F.mse_loss(q_sa, list_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description('eps: %.5f, loss: %0.5f, q_val: %.5f' % (self.eps, loss, torch.mean(q_sa)))