| from __future__ import annotations |
| import torch |
| import numpy as np |
| from diffusion_policy.common.sampler import SequenceSampler |
| from pfp.data.replay_buffer import RobotReplayBuffer |
| from pfp import DATA_DIRS |
|
|
|
|
class RobotDatasetImages(torch.utils.data.Dataset):
    """Map-style dataset of subsampled image / robot-state windows.

    Every item is one window drawn by a ``SequenceSampler`` over a replay
    buffer. The window is split at the "current" step into an observation part
    (images and robot states) and a prediction part (robot states only), and
    only every ``subs_factor``-th raw step is kept in each part.
    """

    def __init__(
        self,
        data_path: str,
        n_obs_steps: int,
        n_pred_steps: int,
        subs_factor: int = 1,
        **kwargs,
    ) -> None:
        """Open the replay buffer at ``data_path`` and configure the sampler.

        The raw window handed to the sampler spans
        ``(n_obs_steps + n_pred_steps) * subs_factor - (subs_factor - 1)``
        steps, so it equals ``n_obs_steps + n_pred_steps`` only when
        ``subs_factor == 1``.

        Args:
            data_path: Location passed to ``RobotReplayBuffer.create_from_path``
                (opened with ``mode="r"``).
            n_obs_steps: Number of subsampled observation steps per item.
            n_pred_steps: Number of subsampled prediction steps per item.
            subs_factor: Keep every ``subs_factor``-th raw step. Must be >= 1.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``subs_factor`` is smaller than 1.
        """
        # subs_factor is used as a slice step in __getitem__: 0 would only fail
        # there (once per sample) and negative values make the window
        # arithmetic below meaningless, so reject both before doing any I/O.
        if subs_factor < 1:
            raise ValueError(f"subs_factor must be >= 1, got {subs_factor}")

        replay_buffer = RobotReplayBuffer.create_from_path(data_path, mode="r")

        # The last subsampled step sits at raw index
        # (n_obs_steps + n_pred_steps - 1) * subs_factor, hence the trailing
        # (subs_factor - 1) raw steps are not part of the window.
        raw_window_len = (n_obs_steps + n_pred_steps) * subs_factor - (subs_factor - 1)

        self.sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=raw_window_len,
            pad_before=(n_obs_steps - 1) * subs_factor,
            pad_after=(n_pred_steps - 1) * subs_factor + (subs_factor - 1),
            keys=["robot_state", "images"],
            # Images are only read from the observation part of the window.
            # NOTE(review): presumably this limits how many leading image
            # steps the sampler loads -- confirm against SequenceSampler.
            key_first_k={"images": n_obs_steps * subs_factor},
        )
        self.n_obs_steps = n_obs_steps
        self.n_prediction_steps = n_pred_steps
        self.subs_factor = subs_factor
        # Not read by any method of this class; kept as a public attribute.
        self.rng = np.random.default_rng()

    def __len__(self) -> int:
        """Return the number of windows the sampler can produce."""
        return len(self.sampler)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, ...]:
        """Return ``(images, robot_state_obs, robot_state_pred)`` for window ``idx``.

        The values are strided slices of whatever the sampler returns for the
        ``"images"`` and ``"robot_state"`` keys (annotated here as NumPy
        arrays); no conversion to tensors happens in this method.
        """
        sample: dict[str, np.ndarray] = self.sampler.sample_sequence(idx)

        # Raw index of the first prediction step: everything before it is
        # observation, everything from it onwards is prediction.
        split_at = self.n_obs_steps * self.subs_factor
        obs_steps = slice(None, split_at, self.subs_factor)
        pred_steps = slice(split_at, None, self.subs_factor)

        robot_state = sample["robot_state"]
        return sample["images"][obs_steps], robot_state[obs_steps], robot_state[pred_steps]
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the dataset from the "open_fridge" training
    # split, fetch one item and print its robot-state observation/prediction.
    train_dir = DATA_DIRS.PFP / "open_fridge" / "train"
    demo_dataset = RobotDatasetImages(
        data_path=train_dir, n_obs_steps=2, n_pred_steps=8, subs_factor=5
    )
    # The images (first element) are fetched but not printed.
    _, state_history, state_future = demo_dataset[20]
    print("robot_state_obs: ", state_history)
    print("robot_state_pred: ", state_future)
    print("done")
|
|