File size: 1,602 Bytes
912c7e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | from __future__ import annotations
import zarr
import numpy as np
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.codecs.imagecodecs_numcodecs import register_codec, Jpeg2k
# Module-level side effect: runs once when this module is imported.
# NOTE(review): presumably this registers the Jpeg2k codec with the numcodecs
# codec registry so stored arrays that reference it can be decoded later —
# confirm against diffusion_policy.codecs.imagecodecs_numcodecs.
register_codec(Jpeg2k)
class RobotReplayBuffer(ReplayBuffer):
    """ReplayBuffer that can ingest an episode given as a list of per-step dicts.

    Adds two convenience entry points on top of ``ReplayBuffer.add_episode``:
    ``add_episode_from_list`` (stack and forward as-is) and
    ``add_episode_from_list_compressed`` (additionally store every ``rgb*``
    key with one-step chunks and a Jpeg2k compressor).
    """

    def __init__(self, root: zarr.Group):
        super().__init__(root)
        # A single codec instance, shared by every rgb* key in
        # add_episode_from_list_compressed.
        self.jpeg_compressor = Jpeg2k()

    @staticmethod
    def _stack_steps(data_list: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
        """Stack per-step dicts into one dict of arrays with a new leading step axis.

        The keys of the first step define the output keys; every step must
        provide each of those keys (a missing key raises ``KeyError``).

        Raises:
            ValueError: if ``data_list`` is empty.
        """
        if not data_list:
            raise ValueError("data_list must contain at least one step")
        return {key: np.stack([step[key] for step in data_list]) for key in data_list[0]}

    def add_episode_from_list(self, data_list: list[dict[str, np.ndarray]], **kwargs):
        """Append one episode given as a list of per-step dicts.

        Each dict in ``data_list`` contains the data for one step. Values for
        the same key are stacked with ``np.stack`` (one row per step) and the
        resulting dict is passed to ``add_episode`` along with ``kwargs``.

        Raises:
            ValueError: if ``data_list`` is empty.
        """
        self.add_episode(self._stack_steps(data_list), **kwargs)

    def add_episode_from_list_compressed(self, data_list: list[dict[str, np.ndarray]], **kwargs):
        """Append one episode, storing every ``rgb*`` key with Jpeg2k compression.

        ``data_list`` is a list of per-step dicts, as for
        ``add_episode_from_list``. Arrays whose key starts with ``"rgb"`` are
        given one-step chunks and ``self.jpeg_compressor``; chunking and
        compression of all other keys are left to ``add_episode``.

        WARNING: decoding (i.e. reading) is broken.

        Raises:
            ValueError: if ``data_list`` is empty.
        """
        data_dict = self._stack_steps(data_list)
        rgb_keys = [key for key in data_dict if key.startswith("rgb")]
        # One step per chunk: (1, *per-step shape), so each frame is encoded
        # as its own image.
        chunks = {key: (1, *data_dict[key].shape[1:]) for key in rgb_keys}
        compressors = {key: self.jpeg_compressor for key in rgb_keys}
        # Passed positionally, relying on add_episode's 2nd/3rd parameters being
        # chunks and compressors (as in the original call). Because they are
        # bound positionally, also passing chunks=/compressors= in kwargs
        # raises TypeError.
        self.add_episode(data_dict, chunks, compressors, **kwargs)
|