CGSCORE / deeprobust /graph /rl /nipa_nstep_replay_mem.py
Yaning1001's picture
Add files using upload-large-folder tool
d38bce3 verified
'''
This part of the code is adopted from https://github.com/Hanjun-Dai/graph_adversarial_attack (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le)
but has been modified to be integrated into this repository.
'''
import random
import numpy as np
from deeprobust.graph.rl.nstep_replay_mem import *
def nipa_hash_state_action(s_t, a_t):
    """Return an integer hash key for a (state, action) pair.

    The key is a polynomial rolling hash over, in order: ``s_t[0]``, both
    endpoints of every edge in ``s_t[1].directed_edges``, ``s_t[2]`` (or 0
    when it is ``None``) and finally ``a_t``.

    Args:
        s_t: Indexable state. ``s_t[0]`` and ``s_t[2]`` are mixed in as
            numbers (``s_t[2]`` may be ``None``); ``s_t[1]`` must expose a
            ``directed_edges`` iterable of pairs indexable with ``[0]`` and
            ``[1]``.
        a_t: Action value mixed into the key last.

    Returns:
        The hash key; for integer inputs it lies in ``[0, 179424673)``.
    """
    modulus = 179424673
    # The multiplier MUST differ from the modulus: (key * m + x) % m is just
    # x % m, which would throw away everything accumulated so far and make
    # the result depend on a_t alone. Both constants are prime, so two
    # sequences that differ in a single position never collide. The product
    # key * multiplier also stays far below 2**63.
    multiplier = 1000003

    def mix(acc, value):
        # One step of the rolling hash.
        return (acc * multiplier + value) % modulus

    key = s_t[0]
    for edge in s_t[1].directed_edges:
        key = mix(key, edge[0])
        key = mix(key, edge[1])

    # A missing third state component contributes 0.
    key = mix(key, 0 if s_t[2] is None else s_t[2])
    return mix(key, a_t)
class NstepReplayMem(object):
    """Replay memory that keeps one ``NstepReplayMemCell`` per step index.

    Transitions recorded at step ``t`` are stored in ``mem_cells[t]``. Only
    the cell of the last step receives the ``balance_sample`` flag; all
    earlier cells are created with ``False``.
    """

    def __init__(self, memory_size, n_steps, balance_sample=False):
        """Create ``n_steps`` cells, each built with ``memory_size``."""
        cells = [NstepReplayMemCell(memory_size, False) for _ in range(n_steps - 1)]
        cells.append(NstepReplayMemCell(memory_size, balance_sample))
        self.mem_cells = cells

        self.n_steps = n_steps
        self.memory_size = memory_size

    def add(self, s_t, a_t, r_t, s_prime, terminal, t):
        """Store one transition in the cell for step ``t``.

        Asserts that ``t`` is a valid step index and that ``terminal`` is
        truthy exactly when ``t`` is the last step.
        """
        assert 0 <= t < self.n_steps
        is_last_step = t == self.n_steps - 1
        assert terminal if is_last_step else not terminal
        self.mem_cells[t].add(s_t, a_t, r_t, s_prime, terminal)

    def add_list(self, list_st, list_at, list_rt, list_sp, list_term, t):
        """Store a batch of transitions for step ``t``, one ``add`` per entry.

        When ``list_sp`` is ``None`` every transition gets the placeholder
        next state ``(None, None, None)``.
        """
        for idx, state in enumerate(list_st):
            next_state = (None, None, None) if list_sp is None else list_sp[idx]
            self.add(state, list_at[idx], list_rt[idx], next_state, list_term[idx], t)

    def sample(self, batch_size, t=None):
        """Sample ``batch_size`` transitions from the cell of step ``t``.

        If ``t`` is ``None`` a step index is drawn with
        ``np.random.randint(self.n_steps)``. Returns the step index followed
        by the five lists produced by the cell's ``sample``.
        """
        if t is None:
            t = np.random.randint(self.n_steps)
        states, actions, rewards, next_states, terminals = self.mem_cells[t].sample(batch_size)
        return t, states, actions, rewards, next_states, terminals