import numpy as np
import torch
import utils


class ReplayBuffer(object):
    """Fixed-capacity circular buffer of environment transitions.

    Transitions are written at ``self.idx``; once ``capacity`` slots have been
    written the position wraps to 0 and the oldest entries are overwritten.
    """

    def __init__(self, obs_shape, action_shape, capacity, device, window=1,
                 store_image=False, image_size=300):
        """Pre-allocate storage for ``capacity`` transitions.

        Args:
            obs_shape: shape of one observation. A 1-D shape is stored as
                float32 (proprioceptive state); any other rank as uint8 (pixels).
            action_shape: shape of one action, stored as float32.
            capacity: maximum number of transitions kept.
            device: torch device that sampled tensors are placed on.
            window: number of transitions written by each ``add_batch`` call.
            store_image: if True, also allocate one RGB uint8 image per
                transition; ``relabel_with_predictor`` then feeds images
                (instead of state-action pairs) to the predictor.
            image_size: height and width of each stored square image.
        """
        self.capacity = capacity
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)
        self.window = window

        self.store_image = store_image
        if self.store_image:
            self.images = np.empty((capacity, image_size, image_size, 3), dtype=np.uint8)

        self.idx = 0  # next slot to write
        self.last_save = 0
        self.full = False  # True once the write position has wrapped at least once

    def __len__(self):
        """Return the number of valid (written) transitions."""
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, next_obs, done, done_no_max, image=None):
        """Write one transition at the current position and advance it.

        ``done`` / ``done_no_max`` are stored negated, as ``not_dones`` /
        ``not_dones_no_max``. ``image`` is stored only when the buffer was
        built with ``store_image=True``.
        """
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.not_dones_no_max[self.idx], not done_no_max)
        if image is not None and self.store_image:
            np.copyto(self.images[self.idx], image)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def _write_slice(self, start, stop, obs, action, reward, next_obs, done, done_no_max):
        """Copy a contiguous run of transitions into rows ``start:stop``.

        Entries with ``done <= 0`` (resp. ``done_no_max <= 0``) are stored as
        not-done.
        """
        np.copyto(self.obses[start:stop], obs)
        np.copyto(self.actions[start:stop], action)
        np.copyto(self.rewards[start:stop], reward)
        np.copyto(self.next_obses[start:stop], next_obs)
        np.copyto(self.not_dones[start:stop], done <= 0)
        np.copyto(self.not_dones_no_max[start:stop], done_no_max <= 0)

    def add_batch(self, obs, action, reward, next_obs, done, done_no_max):
        """Write ``self.window`` consecutive transitions, wrapping at capacity.

        Each argument is indexed along axis 0 and is expected to hold
        ``self.window`` rows (NOTE(review): not validated here). Images are
        not stored by this method.
        """
        next_index = self.idx + self.window
        if next_index >= self.capacity:
            self.full = True
            head = self.capacity - self.idx  # rows that fit before the end of storage
            self._write_slice(self.idx, self.capacity,
                              obs[:head], action[:head], reward[:head],
                              next_obs[:head], done[:head], done_no_max[:head])
            remain = self.window - head  # rows that wrap to the front
            if remain > 0:
                self._write_slice(0, remain,
                                  obs[head:], action[head:], reward[head:],
                                  next_obs[head:], done[head:], done_no_max[head:])
            # Always move the write position: when the batch ends exactly at
            # the end of storage (remain == 0) it must wrap to 0, otherwise
            # the same slots would be overwritten by every later call.
            self.idx = remain
        else:
            self._write_slice(self.idx, next_index,
                              obs, action, reward, next_obs, done, done_no_max)
            self.idx = next_index

    def relabel_with_predictor(self, predictor):
        """Overwrite stored rewards with ``predictor.r_hat_batch`` predictions.

        Every valid transition is relabelled (``len(self)`` of them, i.e. the
        whole storage once the buffer has wrapped), so no stale rewards from an
        earlier predictor remain sampleable. Inputs are chunked 200 at a time
        as concatenated ``[obs, action]`` rows, or, with ``store_image``, 32 at a
        time as images transposed to (N, 3, H, W) and scaled to [0, 1].
        NOTE(review): ``add_batch`` never writes images, so image relabelling
        only sees meaningful pixels for transitions added via ``add``.
        """
        batch_size = 32 if self.store_image else 200
        num_valid = len(self)

        for start in range(0, num_valid, batch_size):
            stop = min(start + batch_size, num_valid)
            if self.store_image:
                inputs = self.images[start:stop]
                inputs = np.transpose(inputs, (0, 3, 1, 2))
                inputs = inputs.astype(np.float32) / 255.0
            else:
                inputs = np.concatenate(
                    [self.obses[start:stop], self.actions[start:stop]], axis=-1)

            self.rewards[start:stop] = predictor.r_hat_batch(inputs)
            torch.cuda.empty_cache()

    def _sample_idxs(self, batch_size):
        """Draw ``batch_size`` indices uniformly, with replacement, from valid slots."""
        return np.random.randint(0, len(self), size=batch_size)

    def _gather(self, idxs):
        """Convert the transitions at ``idxs`` to tensors on ``self.device``.

        Observations are cast to float; the other arrays keep their stored dtype.
        """
        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], device=self.device)
        return obses, actions, rewards, next_obses, not_dones, not_dones_no_max

    def sample(self, batch_size):
        """Sample a uniform random batch of transitions.

        Returns:
            (obses, actions, rewards, next_obses, not_dones, not_dones_no_max)
        """
        return self._gather(self._sample_idxs(batch_size))

    def sample_state_ent(self, batch_size):
        """Sample a batch and also return every valid stored observation.

        Returns:
            (obses, full_obs, actions, rewards, next_obses, not_dones,
            not_dones_no_max), where ``full_obs`` holds all ``len(self)``
            stored observations in their stored dtype (not cast to float).
        """
        obses, actions, rewards, next_obses, not_dones, not_dones_no_max = \
            self._gather(self._sample_idxs(batch_size))

        full_obs = self.obses if self.full else self.obses[: self.idx]
        full_obs = torch.as_tensor(full_obs, device=self.device)

        return obses, full_obs, actions, rewards, next_obses, not_dones, not_dones_no_max