""" |
|
|
Run this command to interactively debug: |
|
|
PYTHONPATH=. python cosmos_predict1/diffusion/training/datasets/dataset_multiview.py |
|
|
|
|
|
Adapted from: |
|
|
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py |
|
|
""" |

import os
import pickle
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import torch
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo


class Dataset(Dataset):  # note: intentionally shadows torch.utils.data.Dataset imported above
    def __init__(
        self,
        dataset_dir,
        sequence_interval,
        num_frames,
        view_keys,
        video_size,
        start_frame_interval=1,
    ):
"""Dataset class for loading image-text-to-video generation data. |
|
|
|
|
|
Args: |
|
|
dataset_dir (str): Base path to the dataset directory |
|
|
sequence_interval (int): Interval between sampled frames in a sequence |
|
|
num_frames (int): Number of frames to load per sequence |
|
|
video_size (list): Target size [H,W] for video frames |
|
|
|
|
|
Returns dict with: |
|
|
- video: RGB frames tensor [T,C,H,W] |
|
|
- video_name: Dict with episode/frame metadata |
|
|
""" |
|
|
|
|
|
        super().__init__()
        self.dataset_dir = dataset_dir
        self.start_frame_interval = start_frame_interval
        self.sequence_interval = sequence_interval
        self.sequence_length = num_frames
        self.view_keys = view_keys

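        # Enumerate videos from the first view's directory; sibling views are expected to mirror its file names.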
        video_dir = os.path.join(self.dataset_dir, "videos")
        self.video_paths = [
            os.path.join(video_dir, view_keys[0], f) for f in os.listdir(os.path.join(video_dir, view_keys[0]))
        ]
        print(f"{len(self.video_paths)} videos in total")

        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")
        self.samples = self._init_samples(self.video_paths)
        self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0]))
        print(f"{len(self.samples)} samples in total")
        self.wrong_number = 0
        self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])
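
        # Each view has a cached "prefix" T5 embedding; __getitem__ prepends it to the per-clip embedding.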
        cache_dir = os.path.join(self.dataset_dir, "cache")
        self.prefix_t5_embeddings = {}
        for view_key in view_keys:
            with open(os.path.join(cache_dir, f"prefix_t5_embeddings_{view_key}.pickle"), "rb") as f:
                self.prefix_t5_embeddings[view_key] = pickle.load(f)[0]

    def __str__(self):
        return f"{len(self.video_paths)} videos from {self.dataset_dir}"

    def _init_samples(self, video_paths):
        samples = []
        # Index all videos in parallel; each worker returns the valid samples for one video.
        with ThreadPoolExecutor(32) as executor:
            future_to_video_path = {
                executor.submit(self._load_and_process_video_path, video_path): video_path
                for video_path in video_paths
            }
            for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)):
                samples.extend(future.result())
        return samples

    def _load_and_process_video_path(self, video_path):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        n_frames = len(vr)

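        # Slide a window over the video: every start_frame_interval-th frame starts a candidate sample
        # whose frame ids are spaced sequence_interval apart.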
        samples = []
        for frame_i in range(0, n_frames, self.start_frame_interval):
            sample = dict()
            sample["video_path"] = video_path
            sample["t5_embedding_path"] = os.path.join(
                self.t5_dir,
                os.path.basename(os.path.dirname(video_path)),
                os.path.basename(video_path).replace(".mp4", ".pickle"),
            )
            sample["frame_ids"] = []
            curr_frame_i = frame_i
            while True:
                if curr_frame_i > (n_frames - 1):
                    break
                sample["frame_ids"].append(curr_frame_i)
                if len(sample["frame_ids"]) == self.sequence_length:
                    break
                curr_frame_i += self.sequence_interval

            # Keep only windows that yield a full sequence; truncated tails near the end are dropped.
            if len(sample["frame_ids"]) == self.sequence_length:
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def _load_video(self, video_path, frame_ids):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        assert (np.array(frame_ids) < len(vr)).all()
        assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        try:
            fps = vr.get_avg_fps()
        except Exception:
            # Some containers carry no fps metadata; fall back to a default of 24.
            fps = 24
        return frame_data, fps

    def _get_frames(self, video_path, frame_ids):
        frames, fps = self._load_video(video_path, frame_ids)
        frames = frames.astype(np.uint8)
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # [T,H,W,C] -> [T,C,H,W]
        frames = self.preprocess(frames)
        frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)
        return frames, fps

    def __getitem__(self, index):
        try:
            sample = self.samples[index]
            video_path = sample["video_path"]
            frame_ids = sample["frame_ids"]
            t5_embedding_path = sample["t5_embedding_path"]

            data = dict()

            videos = []
            t5_embeddings = []
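            # Load the same frame window from every camera view; sibling views live in parallel
            # directories and share the video file name.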
            for view_key in self.view_keys:
                video, fps = self._get_frames(
                    os.path.join(os.path.dirname(os.path.dirname(video_path)), view_key, os.path.basename(video_path)),
                    frame_ids,
                )
                video = video.permute(1, 0, 2, 3)  # [T,C,H,W] -> [C,T,H,W]
                videos.append(video)
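
                # Load this view's per-clip T5 embedding, prepend the cached per-view prefix embedding,
                # and zero-pad the token axis to a fixed length of 512.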
                with open(
                    os.path.join(
                        os.path.dirname(os.path.dirname(t5_embedding_path)),
                        view_key,
                        os.path.basename(t5_embedding_path),
                    ),
                    "rb",
                ) as f:
                    t5_embedding = pickle.load(f)[0]
                t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0)
                t5_embedding = torch.from_numpy(t5_embedding)
                if t5_embedding.shape[0] < 512:
                    t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0)
                t5_embeddings.append(t5_embedding)
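
            # Stack views: along the time axis for video frames and along the token axis for text embeddings.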
            video = torch.cat(videos, dim=1)  # [C, V*T, H, W]
            t5_embedding = torch.cat(t5_embeddings, dim=0)

            data["video"] = video
            data["video_name"] = {
                "video_path": video_path,
                "t5_embedding_path": t5_embedding_path,
                "start_frame_id": str(frame_ids[0]),
            }
            data["t5_text_embeddings"] = t5_embedding
            data["t5_text_mask"] = torch.ones(512 * len(self.view_keys), dtype=torch.int64)
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])  # [H, W, H, W]
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception:
            warnings.warn(
                f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped "
                f"(by randomly sampling another sample in the same dataset).\n"
                f"FULL TRACEBACK:\n{traceback.format_exc()}"
            )
            self.wrong_number += 1
            print(f"Number of invalid samples so far: {self.wrong_number}")
            return self[np.random.randint(len(self.samples))]


if __name__ == "__main__":
    dataset = Dataset(
        dataset_dir="datasets/waymo/",
        sequence_interval=1,
        num_frames=57,
        view_keys=[
            "pinhole_front_left",
            "pinhole_front",
            "pinhole_front_right",
            "pinhole_side_left",
            "pinhole_side_right",
        ],
        video_size=[240, 360],
    )

    indices = [0, 13, 200, -1]
    for idx in indices:
        data = dataset[idx]
        print(
            (
                f"{idx=} "
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data['video_name']=}\n"
                f"{data['t5_text_embeddings'].shape=}\n"
                "---"
            )
        )
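
    # A minimal sketch of wrapping the dataset in a torch DataLoader for training.
    # Batch size and worker count below are illustrative assumptions, not values from this repo:
    #
    #   from torch.utils.data import DataLoader
    #
    #   loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8)
    #   batch = next(iter(loader))  # dict of batched tensors, e.g. batch["video"]: [1, C, V*T, H, W]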