| from typing import Optional |
| import numpy as np |
| import numba |
| from diffusion_policy.common.replay_buffer import ReplayBuffer |
|
|
|
|
@numba.jit(nopython=True)
def create_indices(
        episode_ends:np.ndarray, sequence_length:int,
        episode_mask: np.ndarray,
        pad_before: int=0, pad_after: int=0,
        debug:bool=True) -> np.ndarray:
    """
    Enumerate every fixed-length window that can be sampled from the
    selected episodes.

    episode_ends: 1-D array, episode_ends[i] is the exclusive end index of
        episode i in the flat buffer (episode i starts at episode_ends[i-1],
        or 0 for the first episode).
    sequence_length: length of each sampled window.
    episode_mask: boolean array with the same shape as episode_ends;
        episodes whose entry is False are skipped.
    pad_before / pad_after: how many steps a window may extend before the
        episode start / past the episode end. Both are clamped to
        [0, sequence_length-1].
    debug: when True, assert the internal offset invariants per window.

    Returns an array with one row per window:
        [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]
    buffer_*: range of the flat buffer to read (window clipped to episode).
    sample_*: where that data sits inside the sequence_length-long window.
    """
    # The original statement was a bare comparison whose result was
    # discarded; make it an actual check, since a mismatched mask would
    # otherwise be indexed out of range below.
    assert episode_mask.shape == episode_ends.shape
    pad_before = min(max(pad_before, 0), sequence_length-1)
    pad_after = min(max(pad_after, 0), sequence_length-1)

    indices = list()
    for i in range(len(episode_ends)):
        if not episode_mask[i]:
            # skip episode
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        # window start positions, relative to the episode start
        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            # clip the window to the episode, then shift to buffer coordinates
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            # number of window steps that fall outside the episode on each side
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            if debug:
                assert(start_offset >= 0)
                assert(end_offset >= 0)
                assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices
|
|
|
|
def get_val_mask(n_episodes, val_ratio, seed=0):
    """
    Build a boolean mask selecting which episodes are held out for validation.

    n_episodes: total number of episodes.
    val_ratio: fraction of episodes to hold out; <= 0 selects none.
    seed: seed for the random choice of validation episodes.

    Returns a bool array of length n_episodes, True for validation episodes.
    When val_ratio > 0, at least one episode is selected and at least one is
    left unselected, as long as n_episodes >= 2. With fewer than two episodes
    nothing can be held out and the mask is all False.
    """
    val_mask = np.zeros(n_episodes, dtype=bool)
    if val_ratio <= 0:
        return val_mask

    # have at least 1 episode for validation, and at least 1 episode for train
    n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1)
    if n_val <= 0:
        # n_episodes is 0 or 1: the clamp above yields -1 or 0, and a
        # negative sample size would make rng.choice raise.
        return val_mask
    rng = np.random.default_rng(seed=seed)
    val_idxs = rng.choice(n_episodes, size=n_val, replace=False)
    val_mask[val_idxs] = True
    return val_mask
|
|
|
|
def downsample_mask(mask, max_n, seed=0):
    """
    Limit the number of selected entries in a boolean mask.

    mask: boolean array marking the currently selected entries.
    max_n: maximum number of entries to keep; None means no limit.
    seed: seed for the random subset choice.

    Returns `mask` itself when no limit applies or it already selects at
    most max_n entries. Otherwise returns a new mask of the same shape with
    exactly int(max_n) of the originally selected entries, chosen at random.
    """
    over_limit = (max_n is not None) and (np.sum(mask) > max_n)
    if not over_limit:
        return mask

    keep_count = int(max_n)
    rng = np.random.default_rng(seed=seed)
    # positions currently selected; a random subset of these survives
    selected = np.nonzero(mask)[0]
    picks = rng.choice(selected.shape[0], size=keep_count, replace=False)

    reduced = np.zeros_like(mask)
    reduced[selected[picks]] = True
    assert np.sum(reduced) == keep_count
    return reduced
|
|
class SequenceSampler:
    """
    Samples fixed-length, edge-padded windows from a replay buffer.

    The constructor precomputes one index row per window; sample_sequence
    turns a row into a dict of arrays, one entry per key.
    """
    def __init__(self,
            replay_buffer: "ReplayBuffer",
            sequence_length:int,
            pad_before:int=0,
            pad_after:int=0,
            keys=None,
            key_first_k: Optional[dict]=None,
            episode_mask: Optional[np.ndarray]=None,
            ):
        """
        replay_buffer: source of the data; must provide keys(),
            episode_ends and per-key array access via replay_buffer[key].
        sequence_length: length of every returned window (>= 1).
        pad_before / pad_after: forwarded to create_indices.
        keys: keys to sample; defaults to all keys of the replay buffer.
        key_first_k: dict str: int
            Only take first k data from these keys (to improve perf)
        episode_mask: boolean array selecting which episodes to sample from;
            defaults to all episodes.
        """
        super().__init__()
        assert(sequence_length >= 1)
        if keys is None:
            keys = list(replay_buffer.keys())
        if key_first_k is None:
            # fresh dict per instance; a dict() default in the signature
            # would be shared by every sampler
            key_first_k = dict()

        episode_ends = replay_buffer.episode_ends[:]
        if episode_mask is None:
            episode_mask = np.ones(episode_ends.shape, dtype=bool)

        if np.any(episode_mask):
            indices = create_indices(episode_ends,
                sequence_length=sequence_length,
                pad_before=pad_before,
                pad_after=pad_after,
                episode_mask=episode_mask
                )
        else:
            # no episode selected: empty index table with the usual 4 columns
            indices = np.zeros((0,4), dtype=np.int64)

        # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
        self.indices = indices
        self.keys = list(keys) # prevent OmegaConf list performance problem
        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.key_first_k = key_first_k

    def __len__(self):
        """Number of windows that can be sampled."""
        return len(self.indices)

    def sample_sequence(self, idx):
        """
        Return the idx-th window as a dict mapping each key to an array whose
        first dimension is sequence_length.

        Steps of the window that fall outside the episode are filled by
        repeating the first (before) or last (after) fetched element.
        For keys in key_first_k only the first k elements are read from the
        buffer; the remaining fetched rows hold NaN (0 for integer dtypes).
        """
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \
            = self.indices[idx]
        result = dict()
        for key in self.keys:
            input_arr = self.replay_buffer[key]
            # performance optimization, avoid small allocation if possible
            if key not in self.key_first_k:
                sample = input_arr[buffer_start_idx:buffer_end_idx]
            else:
                # performance optimization, only load used obs steps
                n_data = buffer_end_idx - buffer_start_idx
                k_data = min(self.key_first_k[key], n_data)
                # Placeholder for the rows that are not loaded. NaN cannot be
                # represented by integer dtypes (the cast result is undefined
                # and recent NumPy warns), so those get an explicit 0.
                if np.issubdtype(input_arr.dtype, np.integer):
                    fill_value = 0
                else:
                    fill_value = np.nan
                sample = np.full((n_data,) + input_arr.shape[1:],
                    fill_value=fill_value, dtype=input_arr.dtype)
                # Errors here propagate to the caller instead of dropping
                # into an interactive debugger.
                sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data]
            data = sample
            if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
                # window extends past the episode: pad by edge repetition
                data = np.zeros(
                    shape=(self.sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype)
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < self.sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample
            result[key] = data
        return result
|
|