from collections import deque

import numpy as np


class TemporaryBuffer:
    """Sliding window of recent states and actions for a fixed delay.

    The buffer keeps two bounded deques:

    * ``states``  -- at most ``delayed_steps + 2`` entries
    * ``actions`` -- at most ``2 * delayed_steps + 1`` entries

    This class has no ``append`` method of its own; presumably callers push
    new items onto ``states`` and ``actions`` directly (verify against the
    caller). Because both deques are created with ``maxlen``, appending to a
    full deque silently drops its oldest entry.

    NOTE(review): the names suggest this is used for delayed-observation
    reinforcement learning, where an "augmented state" is an old observation
    plus the actions taken since. That interpretation is not established by
    this file -- confirm with the training loop.
    """

    def __init__(self, delayed_steps: int) -> None:
        """Create empty deques sized for ``delayed_steps``.

        Args:
            delayed_steps: Number of actions appended to a state when
                building an augmented state. Stored as ``self.d``.

        Raises:
            ValueError: Raised by ``deque`` when ``delayed_steps`` is
                negative, because ``2 * delayed_steps + 1`` is then a
                negative ``maxlen``.
        """
        self.d = delayed_steps
        self.states = deque(maxlen=delayed_steps + 2)
        self.actions = deque(maxlen=2 * delayed_steps + 1)

    def clear(self) -> None:
        """Remove every stored state and action."""
        self.states.clear()
        self.actions.clear()

    def get_augmented_state(self, last_observed_state, first_action_idx: int):
        """Concatenate a state with ``self.d`` consecutive buffered actions.

        Args:
            last_observed_state: Array-like placed first in the result.
            first_action_idx: Index into ``self.actions`` of the first action
                to append; actions ``first_action_idx`` through
                ``first_action_idx + self.d - 1`` are used.

        Returns:
            A new ``numpy.ndarray`` holding the state followed by the
            selected actions. When ``self.d`` is 0 no action is appended and
            the result is a copy of the state.

        Raises:
            IndexError: If any requested action index is outside
                ``self.actions``.
        """
        # Index one by one (rather than slicing) so that a window reaching
        # past the end of the deque raises IndexError instead of silently
        # yielding a shorter augmented state.
        window = [
            self.actions[idx]
            for idx in range(first_action_idx, first_action_idx + self.d)
        ]
        # A single concatenate avoids re-copying the growing array once per
        # action.
        return np.concatenate([last_observed_state, *window])

    def get_tuple(self) -> tuple:
        """Build one tuple from the full buffer and drop the oldest entries.

        Returns:
            ``(aug_s, s, a, next_aug_s, next_s)`` where

            * ``aug_s`` is ``states[0]`` concatenated with ``actions[0:d]``
            * ``s`` is ``states[-2]``
            * ``a`` is ``actions[d]``
            * ``next_aug_s`` is ``states[1]`` concatenated with
              ``actions[1:d + 1]``
            * ``next_s`` is ``states[-1]``

            ``s``, ``a`` and ``next_s`` are the stored objects themselves,
            not copies.

        Raises:
            AssertionError: If either deque is not completely full, i.e.
                ``len(states) != d + 2`` or ``len(actions) != 2 * d + 1``.
        """
        expected_states = self.d + 2
        expected_actions = 2 * self.d + 1
        # Explicit check instead of a bare ``assert`` so the guard survives
        # ``python -O``; the exception type is kept for existing callers.
        if (
            len(self.states) != expected_states
            or len(self.actions) != expected_actions
        ):
            raise AssertionError(
                "buffer is not full: "
                f"have {len(self.states)} states / {len(self.actions)} actions, "
                f"need {expected_states} / {expected_actions}"
            )

        aug_s = self.get_augmented_state(self.states[0], 0)
        s = self.states[-2]
        a = self.actions[self.d]
        next_aug_s = self.get_augmented_state(self.states[1], 1)
        next_s = self.states[-1]

        # Advance the window by one step; everything needed from the oldest
        # entries has already been read above.
        self.states.popleft()
        self.actions.popleft()
        return aug_s, s, a, next_aug_s, next_s