Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,913 Bytes
142a1ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from collections import defaultdict
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
from .video_base import VideoDataset
from utils.video_utils import write_numpy_to_mp4
class OpenXVideoDataset(VideoDataset):
    """Video dataset built from an OpenX robotics dataset hosted on GCS.

    ``download`` reads the TFDS dataset named by ``cfg.download.openx_name`` /
    ``cfg.download.openx_version`` from ``gs://gresearch/robotics`` and writes,
    for every episode, one mp4 per camera view, one CSV of caption segments per
    view and a ``states.pkl`` holding the non-image data.
    """

    # Keys (in a step or in its observation) that are treated as captions.
    # The order is kept because it determines the order of the CSV records.
    _CAPTION_KEYS = (
        "natural_language_instruction",
        "instruction",
        "language_instruction",
        "language_instruction_2",
        "language_instruction_3",
    )

    def preprocess_record(self, record):
        """Set the configured frame rate on a metadata record.

        Args:
            record: Mutable mapping describing one video clip; modified in
                place.

        Returns:
            The same ``record`` object, with ``record["fps"]`` set to
            ``cfg.download.openx_fps``.
        """
        record["fps"] = self.cfg.download.openx_fps
        return record

    def download(self):
        """Download all episodes and (re)build the metadata CSV files.

        Episodes are written to ``<data_root>/episodes/episode_<id>``. An
        episode counts as finished once its ``states.pkl`` exists, so an
        interrupted download can be resumed: the leading run of finished
        episodes is not even requested from TFDS, and any finished episode
        encountered later is skipped individually.

        Raises:
            ValueError: If none of the downloaded views is listed in
                ``cfg.download.views`` (nothing to put in the combined CSV).
        """
        # Imported here rather than at module level (presumably so that
        # TensorFlow is only required when actually downloading).
        import tensorflow_datasets as tfds
        import tensorflow as tf

        all_episode_dir = self.data_root / "episodes"
        all_episode_dir.mkdir(parents=True, exist_ok=True)

        builder = tfds.builder_from_directory(
            builder_dir=f"gs://gresearch/robotics/{self.cfg.download.openx_name}/{self.cfg.download.openx_version}"
        )
        n_episodes = builder.info.splits["train"].num_examples

        first_missing = self._count_downloaded_prefix(all_episode_dir, n_episodes)
        if first_missing > 0:
            print(f"Skipping {first_missing} already downloaded episodes")

        if first_missing < n_episodes:
            dataset = builder.as_dataset(split=f"train[{first_missing}:]")
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            progress = tqdm(dataset, total=n_episodes - first_missing)
            # enumerate() keeps the episode id in lock-step with the dataset
            # even when an already-downloaded episode is skipped.
            for episode_id, episode_data in enumerate(progress, start=first_missing):
                episode_dir = all_episode_dir / f"episode_{episode_id}"
                episode_dir.mkdir(parents=True, exist_ok=True)
                if (episode_dir / "states.pkl").exists():
                    continue
                self._save_episode(episode_data, episode_dir)

        self._write_metadata(all_episode_dir)

    @staticmethod
    def _count_downloaded_prefix(all_episode_dir, n_episodes):
        """Return how many consecutive episodes, from 0, have a ``states.pkl``."""
        for episode_id in range(n_episodes):
            state_path = all_episode_dir / f"episode_{episode_id}" / "states.pkl"
            if not state_path.exists():
                return episode_id
        return n_episodes

    def _save_episode(self, episode_data, episode_dir):
        """Write the videos, per-view CSVs and ``states.pkl`` of one episode.

        Args:
            episode_data: One dataset element; a mapping with a ``"steps"``
                iterable whose items have ``"observation"`` and ``"action"``
                entries, plus optional episode-level fields.
            episode_dir: Existing directory that receives the output files.
        """
        import tensorflow as tf
        from utils.tf_utils import recursive_cast_to_numpy

        state_path = episode_dir / "states.pkl"
        episode = defaultdict(list)
        videos = defaultdict(list)
        # Insertion-ordered set of episode keys to stack once all steps are
        # read; a dict avoids one duplicate entry per step.
        fields_to_stack = {}

        for k, v in episode_data.items():
            if k != "steps":
                episode[k] = recursive_cast_to_numpy(v)

        # Sometimes a video can be split into multiple segments based on the
        # caption: each list holds (first_step_index, caption) change points.
        segments = {k: [] for k in self._CAPTION_KEYS}

        n_steps = 0
        for idx, step in enumerate(episode_data["steps"]):
            n_steps = idx + 1
            step = recursive_cast_to_numpy(step)
            obs_dict = step["observation"]
            action_dict = step["action"]
            # A bare array is wrapped so both are always handled as mappings.
            if hasattr(obs_dict, "shape"):
                obs_dict = dict(observation=obs_dict)
            if hasattr(action_dict, "shape"):
                action_dict = dict(action=action_dict)

            # The caption is sometimes a step-level field, but mostly lives in
            # the observation; fold it in so it is handled in one place.
            for k, v in step.items():
                if k in segments:
                    obs_dict[k] = v

            for k, v in obs_dict.items():
                if hasattr(v, "shape") and len(v.shape) == 3 and v.shape[-1] == 3:
                    # Rank-3 array with 3 trailing channels: a video frame.
                    videos[k].append(v)
                elif k in segments:
                    if (
                        k == "instruction"
                        and self.cfg.download.openx_name == "language_table"
                    ):
                        # Special case for the language table dataset: the
                        # instruction is a NUL-padded sequence of code points.
                        v = tf.convert_to_tensor(v)
                        v = tf.strings.unicode_encode(v, output_encoding="UTF-8")
                        v = v.numpy().decode("utf-8").split("\x00")[0]
                    # Only record a new segment when the caption changes.
                    if not segments[k] or segments[k][-1][1] != v:
                        segments[k].append((idx, v))
                elif k != "natural_language_embedding":
                    if hasattr(v, "shape"):
                        fields_to_stack["observation/" + k] = None
                    episode["observation/" + k].append(v)

            for k, v in action_dict.items():
                fields_to_stack["action/" + k] = None
                episode["action/" + k].append(v)

        # Drop caption keys that never appeared and close every remaining list
        # with an end sentinel so consecutive pairs delimit the segments.
        segments = {
            k: change_points + [(n_steps, "")]
            for k, change_points in segments.items()
            if change_points
        }
        if not segments:
            segments = {"not_captioned": [(0, ""), (n_steps, "")]}

        episode_records = defaultdict(list)
        rel_state_path = str(state_path.relative_to(self.data_root))
        for view, frames in videos.items():
            frames = np.stack(frames)
            _, h, w, _ = frames.shape
            video_path = episode_dir / f"{view}.mp4"
            # Crop odd dimensions to even ones (presumably a requirement of
            # the video encoder behind write_numpy_to_mp4).
            if h % 2 != 0:
                h = h - 1
                frames = frames[:, :h, :, :]
            if w % 2 != 0:
                w = w - 1
                frames = frames[:, :, :w, :]
            write_numpy_to_mp4(frames, str(video_path))

            rel_video_path = str(video_path.relative_to(self.data_root))
            for change_points in segments.values():
                # NOTE(review): has_caption reflects the first segment of this
                # caption key, not the individual segment -- confirm intent.
                has_caption = change_points[0][1] != ""
                for (start_idx, caption), (end_idx, _) in zip(
                    change_points, change_points[1:]
                ):
                    episode_records[view].append(
                        dict(
                            video_path=rel_video_path,
                            state_path=rel_state_path,
                            height=h,
                            width=w,
                            n_frames=end_idx - start_idx,
                            trim_start=start_idx,
                            trim_end=end_idx,
                            fps=self.cfg.download.openx_fps,
                            original_caption=caption,
                            has_caption=has_caption,
                        )
                    )

        for view, records in episode_records.items():
            df = pd.DataFrame.from_records(records)
            df.to_csv(episode_dir / f"{view}.csv", index=False)

        for k in fields_to_stack:
            episode[k] = np.stack(episode[k])

        # Written last: the existence of this file marks the episode as done.
        with open(state_path, "wb") as f:
            pickle.dump(episode, f)

    def _write_metadata(self, all_episode_dir):
        """Merge all per-episode CSVs into per-view and combined metadata CSVs.

        Args:
            all_episode_dir: Directory containing the ``episode_*`` folders.

        Raises:
            ValueError: If no view found on disk is in ``cfg.download.views``.
        """
        metadata_path = self.data_root / self.metadata_path
        metadata_dir = metadata_path.parent
        metadata_dir.mkdir(parents=True, exist_ok=True)

        # Sort by (length, name) so episode_2 precedes episode_10 and the
        # output no longer depends on filesystem listing order.
        episode_dirs = sorted(
            all_episode_dir.glob("episode_*"),
            key=lambda p: (len(p.name), p.name),
        )
        record_dict = defaultdict(list)
        for episode_dir in episode_dirs:
            for csv_path in sorted(episode_dir.glob("*.csv")):
                view_df = pd.read_csv(csv_path)
                record_dict[csv_path.stem].extend(view_df.to_dict("records"))

        selected = []
        for view, records in record_dict.items():
            df = pd.DataFrame.from_records(records)
            df.to_csv(metadata_dir / f"{view}.csv", index=False)
            print(f"Created metadata csv for view {view} with {len(df)} records")
            if view in self.cfg.download.views:
                selected.append(df)

        if not selected:
            raise ValueError(
                "No downloaded view matches cfg.download.views; "
                "cannot create the combined metadata CSV"
            )
        all_df = pd.concat(selected)
        all_df.to_csv(metadata_path, index=False)
        print(f"Created metadata CSV with {len(all_df)} records")
|