|
|
from collections import defaultdict |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pickle |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .video_base import VideoDataset |
|
|
from utils.video_utils import write_numpy_to_mp4 |
|
|
|
|
|
|
|
|
class OpenXVideoDataset(VideoDataset):
    """Video dataset backed by an Open X-Embodiment (OpenX) robotics dataset.

    Streams episodes from the public ``gs://gresearch/robotics`` TFDS bucket,
    writes one MP4 per camera view plus a ``states.pkl`` of non-image fields
    per episode, and builds per-view metadata CSVs describing caption-aligned
    clips. Download is resumable: episodes whose ``states.pkl`` already exists
    are skipped.
    """

    def preprocess_record(self, record):
        """Attach the configured OpenX frame rate to a metadata record.

        Args:
            record: dict-like metadata row for a single clip.

        Returns:
            The same record, with ``fps`` overwritten from the download config.
        """
        record["fps"] = self.cfg.download.openx_fps
        return record

    def download(self):
        """Download and convert all episodes of the configured OpenX dataset.

        For every episode: image-like observation fields (rank-3 arrays with a
        trailing channel dim of 3) become per-view MP4s; remaining observation
        and action arrays are stacked and pickled; language-instruction fields
        are tracked as (start_frame, caption) change points and turned into
        per-view, per-caption-segment metadata records. Finally, all episode
        CSVs are aggregated into one CSV per view and a combined metadata CSV
        for the views selected in ``cfg.download.views``.
        """
        # Local imports: TF is heavyweight and only needed for downloading.
        import tensorflow_datasets as tfds
        import tensorflow as tf
        from utils.tf_utils import recursive_cast_to_numpy

        all_episode_dir = self.data_root / "episodes"
        all_episode_dir.mkdir(parents=True, exist_ok=True)

        builder = tfds.builder_from_directory(
            builder_dir=f"gs://gresearch/robotics/{self.cfg.download.openx_name}/{self.cfg.download.openx_version}"
        )
        info = builder.info
        n_episodes = info.splits["train"].num_examples

        # Resume support: find the first episode whose states.pkl is missing
        # and start streaming from there.
        episode_id = 0  # guards against n_episodes == 0 leaving it unbound
        for episode_id in range(n_episodes):
            episode_dir = all_episode_dir / f"episode_{episode_id}"
            state_path = episode_dir / "states.pkl"
            if not state_path.exists():
                break

        if episode_id > 0:
            print(f"Skipping {episode_id} already downloaded episodes")
        dataset = builder.as_dataset(split=f"train[{episode_id}:]")

        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        for episode_data in tqdm(dataset, total=n_episodes - episode_id):
            episode_dir = all_episode_dir / f"episode_{episode_id}"
            episode_dir.mkdir(parents=True, exist_ok=True)
            episode_records = defaultdict(list)
            state_path = episode_dir / "states.pkl"
            # states.pkl is written last, so its presence marks a complete episode.
            if state_path.exists():
                continue

            episode = defaultdict(list)
            videos = defaultdict(list)
            # Set (was a list): the old list gained one duplicate entry per
            # field per step, growing unboundedly and re-stacking the same
            # field many times at episode end.
            fields_to_stack = set()
            for k, v in episode_data.items():
                if k != "steps":
                    episode[k] = recursive_cast_to_numpy(v)

            # caption key -> [(start_frame, caption), ...] change points;
            # only keys that actually appear in this episode survive below.
            segments = {
                "natural_language_instruction": [],
                "instruction": [],
                "language_instruction": [],
                "language_instruction_2": [],
                "language_instruction_3": [],
            }
            idx = -1  # guards against an episode with zero steps
            for idx, step in enumerate(episode_data["steps"]):
                step = recursive_cast_to_numpy(step)
                obs_dict = step["observation"]
                action_dict = step["action"]
                # Some datasets store a bare array instead of a dict; wrap it.
                if hasattr(obs_dict, "shape"):
                    obs_dict = dict(observation=obs_dict)
                if hasattr(action_dict, "shape"):
                    action_dict = dict(action=action_dict)

                # Instructions sometimes live at the step level; fold them
                # into the observation dict so one loop handles both.
                for k, v in step.items():
                    if k in segments:
                        obs_dict[k] = v

                for k, v in obs_dict.items():
                    # Heuristic for an RGB frame: rank-3 with 3 channels last.
                    if hasattr(v, "shape") and len(v.shape) == 3 and v.shape[-1] == 3:
                        videos[k].append(v)
                    elif k in segments:
                        if (
                            k == "instruction"
                            and self.cfg.download.openx_name == "language_table"
                        ):
                            # language_table stores instructions as padded
                            # unicode code points; decode and strip the
                            # NUL padding.
                            v = tf.convert_to_tensor(v)
                            v = tf.strings.unicode_encode(v, output_encoding="UTF-8")
                            v = v.numpy().decode("utf-8").split("\x00")[0]
                        # Record only changes, yielding caption segments.
                        if not segments[k] or segments[k][-1][1] != v:
                            segments[k].append((idx, v))
                    elif k != "natural_language_embedding":
                        if hasattr(v, "shape"):
                            fields_to_stack.add("observation/" + k)
                            episode["observation/" + k].append(v)

                for k, v in action_dict.items():
                    fields_to_stack.add("action/" + k)
                    episode["action/" + k].append(v)

            # Close every observed caption stream with an end sentinel; drop
            # caption keys that never appeared in this episode.
            for k in list(segments.keys()):
                if not segments[k]:
                    del segments[k]
                    continue
                segments[k].append((idx + 1, ""))
            if not segments:
                # No instruction fields at all: one uncaptioned full-length segment.
                segments["not_captioned"] = [(0, ""), (idx + 1, "")]

            for view, frames in videos.items():
                frames = np.stack(frames)
                n, h, w, _ = frames.shape
                video_path = episode_dir / f"{view}.mp4"

                # Most codecs require even frame dimensions; crop a single
                # row/column of pixels when needed.
                if h % 2 != 0:
                    h = h - 1
                    frames = frames[:, :h, :, :]
                if w % 2 != 0:
                    w = w - 1
                    frames = frames[:, :, :w, :]
                write_numpy_to_mp4(frames, str(video_path))

                # One metadata record per (caption segment, view): the clip
                # spans frames [trim_start, trim_end) of this view's video.
                for k, v in segments.items():
                    for s in range(len(v) - 1):
                        start_idx, caption = v[s]
                        end_idx = v[s + 1][0]
                        record = dict(
                            video_path=str(video_path.relative_to(self.data_root)),
                            state_path=str(state_path.relative_to(self.data_root)),
                            height=h,
                            width=w,
                            n_frames=end_idx - start_idx,
                            trim_start=start_idx,
                            trim_end=end_idx,
                            fps=self.cfg.download.openx_fps,
                            original_caption=caption,
                            # Episode-level flag: true iff the first segment
                            # carries a non-empty caption.
                            has_caption=v[0][1] != "",
                        )
                        episode_records[view].append(record)
            for view, records in episode_records.items():
                df = pd.DataFrame.from_records(records)
                df.to_csv(episode_dir / f"{view}.csv", index=False)

            for k in fields_to_stack:
                episode[k] = np.stack(episode[k])
            # Written last: acts as the completion marker checked above.
            with open(state_path, "wb") as f:
                pickle.dump(episode, f)
            episode_id += 1

        # Aggregate per-episode CSVs into one CSV per view, then concatenate
        # the configured views into the final metadata file.
        metadata_path = self.data_root / self.metadata_path
        metadata_dir = metadata_path.parent
        metadata_dir.mkdir(parents=True, exist_ok=True)
        record_dict = defaultdict(list)
        for episode_dir in all_episode_dir.glob("episode_*"):
            for view_csv in episode_dir.glob("*.csv"):
                view_csv = view_csv.name
                view_df = pd.read_csv(episode_dir / view_csv)
                record_dict[view_csv].extend(view_df.to_dict("records"))
        all_df = []
        for view_csv, records in record_dict.items():
            df = pd.DataFrame.from_records(records)
            df.to_csv(metadata_dir / view_csv, index=False)
            print(
                f"Created metadata csv for view {view_csv.split('.')[0]} with {len(df)} records"
            )
            if view_csv.replace(".csv", "") in self.cfg.download.views:
                all_df.append(df)
        all_df = pd.concat(all_df)
        all_df.to_csv(metadata_path, index=False)
        print(f"Created metadata CSV with {len(all_df)} records")
|
|
|