import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

from algorithms.wan import WanImageToVideo, WanTextToVideo
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.ego4d import Ego4DVideoDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.mixture import MixtureDataset

from .exp_base import BaseLightningExperiment


class VideoPredictionExperiment(BaseLightningExperiment):
    """A video prediction experiment."""

    compatible_algorithms = dict(
        wan_i2v=WanImageToVideo,
        wan_t2v=WanTextToVideo,
        wan_toy=WanImageToVideo,
    )

    compatible_datasets = dict(
        mixture=MixtureDataset,
        mixture_robot=MixtureDataset,
        dummy=DummyVideoDataset,
        something_something=SomethingSomethingDataset,
        epic_kitchen=EpicKitchenDataset,
        pandas=PandasVideoDataset,
        ego4d=Ego4DVideoDataset,
        bridge=OpenXVideoDataset,
        droid=DroidVideoDataset,
        agibot_world=AgibotWorldDataset,
        language_table=OpenXVideoDataset,
    )
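
    # Both registries above are assumed to be consumed by
    # BaseLightningExperiment: the keys presumably match the algorithm /
    # dataset names given in the experiment config, and the values are the
    # classes instantiated for a run.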

    def _build_strategy(self):
        # Imported locally so Lightning's FSDP support is only required when
        # the FSDP strategy is actually selected.
        from lightning.pytorch.strategies.fsdp import FSDPStrategy

        if self.cfg.strategy == "ddp":
            return super()._build_strategy()
        elif self.cfg.strategy == "fsdp":
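            # For HYBRID_SHARD, the 2-D device mesh is (replica groups,
            # ranks per shard group): parameters are sharded within each
            # group and replicated across groups. The sizes below appear to
            # assume 4 GPUs per node: at 8+ nodes, shard within 32-rank
            # (8-node) groups; below that, shard across all available ranks.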
            if self.cfg.num_nodes >= 8:
                device_mesh = (self.cfg.num_nodes // 8, 32)
            else:
                device_mesh = (1, self.cfg.num_nodes * 4)
            return FSDPStrategy(
                # Keep parameters, gradient reductions, and buffers in bf16.
                mixed_precision=MixedPrecision(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.bfloat16,
                    buffer_dtype=torch.bfloat16,
                ),
                # Shard at the module granularity the algorithm requests.
                auto_wrap_policy=ModuleWrapPolicy(self.algo.classes_to_shard()),
                sharding_strategy="HYBRID_SHARD",
                device_mesh=device_mesh,
            )
        else:
            # Otherwise pass the strategy name through to Lightning as-is.
            return self.cfg.strategy

    def download_dataset(self):
        # Building the training dataset is expected to download / prepare it
        # as a side effect; the instance itself is not used further here.
        dataset = self._build_dataset("training")
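
# Usage sketch (hypothetical constructor signature, for illustration only):
# a launcher would typically build the experiment from its config and fetch
# data once before starting distributed training, e.g.
#
#     exp = VideoPredictionExperiment(cfg)  # cfg.algorithm keyed into
#     exp.download_dataset()                # compatible_algorithms above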
|