|
|
from typing import List |
|
|
from torch.utils.data import IterableDataset, Dataset |
|
|
from omegaconf import DictConfig |
|
|
import torch |
|
|
import numpy as np |
|
|
from datasets.dummy import DummyVideoDataset |
|
|
from datasets.openx_base import OpenXVideoDataset |
|
|
from datasets.droid import DroidVideoDataset |
|
|
from datasets.something_something import SomethingSomethingDataset |
|
|
from datasets.epic_kitchen import EpicKitchenDataset |
|
|
from datasets.pandas import PandasVideoDataset |
|
|
from datasets.deprecated.video_1x_wm import WorldModel1XDataset |
|
|
from datasets.agibot_world import AgibotWorldDataset |
|
|
from datasets.ego4d import Ego4DVideoDataset |
|
|
|
|
|
# Names of the Open X-Embodiment subsets; they all share the generic
# OpenX loader class rather than having a dedicated dataset class each.
_OPENX_SUBSET_NAMES = (
    "austin_buds",
    "austin_sailor",
    "austin_sirius",
    "bc_z",
    "berkeley_autolab",
    "berkeley_cable",
    "berkeley_fanuc",
    "bridge",
    "cmu_stretch",
    "dlr_edan",
    "dobbe",
    "fmb",
    "fractal",
    "iamlab_cmu",
    "jaco_play",
    "language_table",
    "nyu_franka",
    "roboturk",
    "stanford_hydra",
    "taco_play",
    "toto",
    "ucsd_kitchen",
    "utaustin_mutex",
    "viola",
)

# Maps each subset name (as it appears in the mixture config) to the
# dataset class responsible for loading it.
subset_classes = {
    "dummy": DummyVideoDataset,
    "something_something": SomethingSomethingDataset,
    "epic_kitchen": EpicKitchenDataset,
    "pandas": PandasVideoDataset,
    "agibot_world": AgibotWorldDataset,
    "video_1x_wm": WorldModel1XDataset,
    "ego4d": Ego4DVideoDataset,
    "droid": DroidVideoDataset,
    **{name: OpenXVideoDataset for name in _OPENX_SUBSET_NAMES},
}
|
|
|
|
|
|
|
|
class MixtureDataset(IterableDataset):
    """
    A fault tolerant mixture of video datasets.

    Instantiates one dataset per configured subset, normalizes the per-split
    sampling weights, and yields samples indefinitely: each step first picks
    a subset (weighted by the normalized weights) and then a uniformly random
    index within it.  Exceptions raised while fetching a sample are printed
    and skipped unless ``cfg.debug`` is set, in which case they propagate.
    """

    # Shared (mixture-level) config keys that are copied into every subset's
    # own config before the subset dataset is constructed.
    _SHARED_CFG_KEYS = (
        "height",
        "width",
        "n_frames",
        "fps",
        "load_video_latent",
        "load_prompt_embed",
        "max_text_tokens",
        "image_to_video",
    )

    def __init__(self, cfg: "DictConfig", split: str = "training"):
        """
        Args:
            cfg: Mixture configuration.  Per-subset configs live under keys
                of the form ``subset/<name>``; per-split weights under
                ``cfg[split].weight`` with ``cfg[split].weight_type`` either
                ``"relative"`` (scaled by dataset size) or absolute.
            split: Dataset split, e.g. ``"training"``.  Must not be ``"all"``.

        Raises:
            ValueError: If ``split == "all"`` or a weighted subset name has
                no matching ``subset/<name>`` configuration entry.
        """
        super().__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.split = split
        # Snapshot of the current numpy RNG seed word (kept for reproducibility
        # bookkeeping; not otherwise used in this class).
        self.random_seed = np.random.get_state()[1][0]
        self.subset_cfg = {
            k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
        }
        if split == "all":
            raise ValueError("split cannot be `all` for MixtureDataset`")
        weight = dict(self.cfg[split].weight)

        for key in weight:
            if key not in self.subset_cfg:
                raise ValueError(
                    f"Dataset '{key}' specified in weights but not found in configuration"
                )
        # Keep only subsets that actually have a weight for this split.
        self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}
        weight_type = self.cfg[split].weight_type
        self.subsets: List[Dataset] = []
        for subset_name, subset_cfg in self.subset_cfg.items():
            # Propagate the shared mixture-level settings into each subset.
            for shared_key in self._SHARED_CFG_KEYS:
                subset_cfg[shared_key] = self.cfg[shared_key]
            self.subsets.append(subset_classes[subset_name](subset_cfg, split))
            if weight_type == "relative":
                # Relative weights are scaled by dataset size so the weight
                # applies per *sample* rather than per dataset.
                weight[subset_name] = weight[subset_name] * len(self.subsets[-1])

        total_weight = sum(weight.values())
        self.normalized_weights = {k: v / total_weight for k, v in weight.items()}

        dataset_sizes = {
            subset_name: len(subset)
            for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
        }

        # Human-readable summary table, sorted by descending normalized weight.
        print("\nDataset information for split '{}':".format(self.split))
        print("-" * 60)
        print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
        print("-" * 60)
        for subset_name, norm_weight in sorted(
            self.normalized_weights.items(), key=lambda x: -x[1]
        ):
            size = dataset_sizes[subset_name]
            orig_weight = self.cfg[split].weight[subset_name]
            print(
                f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
            )
        print("-" * 60)

        # Cumulative distribution over subsets, in subset_cfg insertion order;
        # __iter__ samples by inverse-CDF lookup on these values.
        self.cumsum_weights = {}
        cumsum = 0
        for k, v in self.normalized_weights.items():
            cumsum += v
            self.cumsum_weights[k] = cumsum

        # Flattened view of every subset's records (metadata aggregation).
        self.records = []
        for subset in self.subsets:
            self.records.extend(subset.records)

    def __iter__(self):
        """Yield samples forever, choosing subsets by normalized weight."""
        subset_names = list(self.subset_cfg.keys())
        while True:
            rand = np.random.random()
            # Default to the last subset: floating-point rounding can leave
            # the final cumulative weight slightly below 1.0, in which case
            # the scan below would otherwise complete without a match and
            # leave `selected_subset` unbound (or stale from a prior loop).
            selected_subset = subset_names[-1]
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break

            subset_idx = subset_names.index(selected_subset)

            try:
                dataset = self.subsets[subset_idx]
                idx = np.random.randint(len(dataset))
                sample = dataset[idx]
                yield sample
            except Exception as e:
                if self.debug:
                    raise  # preserve the original traceback in debug mode
                # Fault tolerance: log the failure and move on to a new draw.
                print(f"Error sampling from {selected_subset}: {str(e)}")
                continue
|
|