from collections import defaultdict
import pickle

import numpy as np
import pandas as pd
from tqdm import tqdm

from .video_base import VideoDataset
from utils.video_utils import write_numpy_to_mp4


class OpenXVideoDataset(VideoDataset):
    """Open X-Embodiment robotics video dataset.

    Downloads episodes from the public ``gs://gresearch/robotics`` TFDS
    buckets, writes one MP4 per camera view per episode, one pickled state
    dict per episode, and per-view metadata CSVs describing caption-based
    video segments.
    """

    def preprocess_record(self, record):
        """Attach the configured playback FPS to a metadata record.

        Args:
            record: mutable mapping for one metadata row.

        Returns:
            The same record, with ``record["fps"]`` set from the download config.
        """
        record["fps"] = self.cfg.download.openx_fps
        return record

    def download(self):
        """Download all episodes and build the consolidated metadata CSVs.

        Resumable: episodes whose ``states.pkl`` already exists are skipped,
        and the TFDS split is opened starting at the first missing episode.
        TensorFlow is imported lazily so it is only required for downloading.
        """
        import tensorflow_datasets as tfds
        import tensorflow as tf
        from utils.tf_utils import recursive_cast_to_numpy

        all_episode_dir = self.data_root / "episodes"
        all_episode_dir.mkdir(parents=True, exist_ok=True)

        builder = tfds.builder_from_directory(
            builder_dir=f"gs://gresearch/robotics/{self.cfg.download.openx_name}/{self.cfg.download.openx_version}"
        )
        info = builder.info
        n_episodes = info.splits["train"].num_examples

        # Resume support: count the leading run of episodes whose state file
        # already exists.  An explicit counter avoids the original off-by-one
        # (loop variable left at n_episodes - 1 when everything is already
        # downloaded) and the NameError when n_episodes == 0.
        episode_id = 0
        while (
            episode_id < n_episodes
            and (all_episode_dir / f"episode_{episode_id}" / "states.pkl").exists()
        ):
            episode_id += 1
        if episode_id > 0:
            print(f"Skipping {episode_id} already downloaded episodes")

        dataset = builder.as_dataset(split=f"train[{episode_id}:]")
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        for episode_data in tqdm(dataset, total=n_episodes - episode_id):
            episode_dir = all_episode_dir / f"episode_{episode_id}"
            state_path = episode_dir / "states.pkl"
            if state_path.exists():
                # Gap episode already on disk: advance the counter so the
                # remaining episodes keep their correct ids (the original
                # `continue` left episode_id stale and silently dropped or
                # mis-indexed everything after the first gap).
                episode_id += 1
                continue
            episode_dir.mkdir(parents=True, exist_ok=True)

            episode_records = defaultdict(list)
            episode = defaultdict(list)
            videos = defaultdict(list)
            # Set, not list: the original appended the same key once per step,
            # causing np.stack to be re-applied repeatedly at save time.
            fields_to_stack = set()

            # Non-step, episode-level fields go straight into the state dict.
            for k, v in episode_data.items():
                if k != "steps":
                    episode[k] = recursive_cast_to_numpy(v)

            # sometimes we can split a video into multiple segments based on caption
            segments = {
                "natural_language_instruction": [],
                "instruction": [],
                "language_instruction": [],
                "language_instruction_2": [],
                "language_instruction_3": [],
            }
            for idx, step in enumerate(episode_data["steps"]):
                step = recursive_cast_to_numpy(step)
                obs_dict = step["observation"]
                action_dict = step["action"]
                # Some datasets expose a bare array instead of a dict; wrap it.
                if hasattr(obs_dict, "shape"):
                    obs_dict = dict(observation=obs_dict)
                if hasattr(action_dict, "shape"):
                    action_dict = dict(action=action_dict)
                # some times caption field is here but mostly in observation
                for k, v in step.items():
                    if k in segments:
                        obs_dict[k] = v
                for k, v in obs_dict.items():
                    if hasattr(v, "shape") and len(v.shape) == 3 and v.shape[-1] == 3:
                        # H x W x 3 array: treat as one camera-view frame.
                        videos[k].append(v)
                    elif k in segments:
                        if (
                            k == "instruction"
                            and self.cfg.download.openx_name == "language_table"
                        ):
                            # special case for language table dataset: the
                            # instruction is a padded unicode-codepoint tensor.
                            v = tf.convert_to_tensor(v)
                            v = tf.strings.unicode_encode(v, output_encoding="UTF-8")
                            v = v.numpy().decode("utf-8").split("\x00")[0]
                        # Start a new segment whenever the caption changes.
                        if not segments[k] or segments[k][-1][1] != v:
                            segments[k].append((idx, v))
                    elif k != "natural_language_embedding":
                        if hasattr(v, "shape"):
                            fields_to_stack.add("observation/" + k)
                            episode["observation/" + k].append(v)
                for k, v in action_dict.items():
                    fields_to_stack.add("action/" + k)
                    episode["action/" + k].append(v)

            # Close every non-empty caption track with a sentinel end marker;
            # drop empty tracks.  `idx` is the last step index from the loop.
            for k in list(segments.keys()):
                if not segments[k]:
                    del segments[k]
                    continue
                segments[k].append((idx + 1, ""))
            if not segments:
                segments["not_captioned"] = [(0, ""), (idx + 1, "")]

            for view, frames in videos.items():
                frames = np.stack(frames)
                n, h, w, _ = frames.shape
                video_path = episode_dir / f"{view}.mp4"
                # MP4 encoders require even dimensions; crop one row/col if odd.
                if h % 2 != 0:
                    h = h - 1
                    frames = frames[:, :h, :, :]
                if w % 2 != 0:
                    w = w - 1
                    frames = frames[:, :, :w, :]
                write_numpy_to_mp4(frames, str(video_path))
                # One metadata record per (view, caption segment).
                for k, v in segments.items():
                    for s in range(len(v) - 1):
                        start_idx, caption = v[s]
                        end_idx = v[s + 1][0]
                        record = dict(
                            video_path=str(video_path.relative_to(self.data_root)),
                            state_path=str(state_path.relative_to(self.data_root)),
                            height=h,
                            width=w,
                            n_frames=end_idx - start_idx,
                            trim_start=start_idx,
                            trim_end=end_idx,
                            fps=self.cfg.download.openx_fps,
                            original_caption=caption,
                            # NOTE(review): this flags the FIRST segment's
                            # caption, not this segment's (`caption != ""`).
                            # Preserved as-is — confirm intent before changing.
                            has_caption=v[0][1] != "",
                        )
                        episode_records[view].append(record)

            for view, records in episode_records.items():
                df = pd.DataFrame.from_records(records)
                df.to_csv(episode_dir / f"{view}.csv", index=False)
            for k in fields_to_stack:
                episode[k] = np.stack(episode[k])
            with open(state_path, "wb") as f:
                pickle.dump(episode, f)
            episode_id += 1

        # Save metadata: merge every per-episode per-view CSV into one CSV
        # per view, then concatenate the configured views into the main file.
        metadata_path = self.data_root / self.metadata_path
        metadata_dir = metadata_path.parent
        metadata_dir.mkdir(parents=True, exist_ok=True)
        record_dict = defaultdict(list)
        for episode_dir in all_episode_dir.glob("episode_*"):
            for view_csv in episode_dir.glob("*.csv"):
                view_csv = view_csv.name
                view_df = pd.read_csv(episode_dir / view_csv)
                record_dict[view_csv].extend(view_df.to_dict("records"))
        all_df = []
        for view_csv, records in record_dict.items():
            df = pd.DataFrame.from_records(records)
            df.to_csv(metadata_dir / view_csv, index=False)
            print(
                f"Created metadata csv for view {view_csv.split('.')[0]} with {len(df)} records"
            )
            if view_csv.replace(".csv", "") in self.cfg.download.views:
                all_df.append(df)
        all_df = pd.concat(all_df)
        all_df.to_csv(metadata_path, index=False)
        print(f"Created metadata CSV with {len(all_df)} records")