import numpy as np
import random

from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree


class ReplayBuffer(object):
    def __init__(self, size):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done):
        data = (obs_t, action, reward, obs_tp1, done)

        if self._next_idx >= len(self._storage):
            # Buffer not yet full: append a new slot.
            self._storage.append(data)
        else:
            # Buffer full: overwrite the oldest transition in place.
            self._storage[self._next_idx] = data
        # Advance the write cursor, wrapping around at capacity.
        self._next_idx = (self._next_idx + 1) % self._maxsize
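
    # Illustrative trace (an addition, not in the original source): with
    # size=3, adding transitions a, b, c, d leaves storage as [d, b, c] --
    # the fourth add wraps _next_idx back to 0 and overwrites the oldest
    # entry, a.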

    def _encode_sample(self, idxes):
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            data = self._storage[i]
            obs_t, action, reward, obs_tp1, done = data
            # copy=False avoids copying values that are already ndarrays;
            # the np.array calls below then stack them into batch arrays.
            obses_t.append(np.array(obs_t, copy=False))
            actions.append(np.array(action, copy=False))
            rewards.append(reward)
            obses_tp1.append(np.array(obs_tp1, copy=False))
            dones.append(done)
        return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)

    def sample(self, batch_size):
        """Sample a batch of experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        """
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)
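

# Usage sketch (an illustrative addition, not part of the original module):
# fill a small buffer and draw one batch. The 1-D float observations and
# integer actions below are placeholder choices.
#
#     buf = ReplayBuffer(size=100)
#     obs = np.zeros(4, dtype=np.float32)
#     for t in range(10):
#         next_obs = np.random.randn(4).astype(np.float32)
#         buf.add(obs, action=0, reward=1.0, obs_tp1=next_obs, done=False)
#         obs = next_obs
#     obs_b, act_b, rew_b, next_obs_b, done_b = buf.sample(batch_size=4)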


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        # The segment trees require a power-of-two capacity; round size up.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
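
    # Worked example (an illustrative addition): with size=100 the loop above
    # rounds the tree capacity up to 128. Priorities are stored as p**alpha,
    # so with alpha=0.6 a transition with priority 2.0 carries mass
    # 2.0**0.6 ~= 1.52 in the sum tree, while alpha=0 would flatten every
    # transition to mass 1.0, i.e. uniform sampling.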

    def add(self, *args, **kwargs):
        """See ReplayBuffer.add"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        # New transitions enter with the maximum priority seen so far, which
        # guarantees they are sampled at least once before being down-weighted.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        # The segment tree's `sum` treats its end bound as exclusive, so pass
        # len(self._storage) to include the most recently written slot.
        p_total = self._it_sum.sum(0, len(self._storage))
        # Stratified sampling: split [0, p_total) into batch_size equal
        # strata and draw one mass uniformly from each, which lowers batch
        # variance compared to batch_size fully independent draws.
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            # Descend the sum tree to the index whose prefix sum covers mass.
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res
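
    # Worked example (an illustrative addition): with stored masses [3, 1]
    # (priorities [3, 1] at alpha=1), p_total is 4, so index 0 owns the mass
    # interval [0, 3) and index 1 owns [3, 4); a transition is therefore
    # drawn with probability p_i / sum_k p_k, i.e. 3/4 and 1/4 here.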

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and the indices of the sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indices in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        # The importance-sampling weight for transition i is
        # (N * P(i)) ** -beta, normalized by the largest possible weight
        # (attained by the minimum-priority transition) so weights are <= 1.
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of indices of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled indices denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            # Both trees store priority ** alpha: the sum tree drives
            # proportional sampling, the min tree normalizes the weights.
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            # Track the running maximum so new transitions start at it.
            self._max_priority = max(self._max_priority, priority)
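

# A minimal end-to-end sketch (an illustrative addition, not part of the
# original module): fill a prioritized buffer, sample with importance
# weights, then feed surrogate TD errors back as new priorities. The
# hyperparameters (alpha=0.6, beta=0.4) and the random "TD errors" are
# placeholder choices, not values prescribed by this module.
if __name__ == "__main__":
    buf = PrioritizedReplayBuffer(size=64, alpha=0.6)
    obs = np.zeros(4, dtype=np.float32)
    for t in range(32):
        next_obs = np.random.randn(4).astype(np.float32)
        buf.add(obs, 0, 1.0, next_obs, False)
        obs = next_obs

    obs_b, act_b, rew_b, next_obs_b, done_b, weights, idxes = buf.sample(batch_size=8, beta=0.4)
    # In a real agent, weights would scale the per-sample loss and the new
    # priorities would be |TD error| plus a small epsilon to keep them > 0.
    new_priorities = np.abs(np.random.randn(len(idxes))) + 1e-6
    buf.update_priorities(idxes, new_priorities)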