Spaces: Running on Zero
| import torch | |
| from torch.distributed.fsdp import MixedPrecision | |
| from torch.distributed.fsdp.wrap import ModuleWrapPolicy | |
| # from algorithms.cogvideo import CogVideoXImageToVideo, CogVideoXVAE | |
| from algorithms.wan import WanImageToVideo, WanTextToVideo | |
| from datasets.dummy import DummyVideoDataset | |
| from datasets.openx_base import OpenXVideoDataset | |
| from datasets.droid import DroidVideoDataset | |
| from datasets.something_something import SomethingSomethingDataset | |
| from datasets.epic_kitchen import EpicKitchenDataset | |
| from datasets.pandas import PandasVideoDataset | |
| from datasets.ego4d import Ego4DVideoDataset | |
| from datasets.agibot_world import AgibotWorldDataset | |
| from datasets.mixture import MixtureDataset | |
| from .exp_base import BaseLightningExperiment | |
| class VideoPredictionExperiment(BaseLightningExperiment): | |
| """ | |
| A video prediction experiment | |
| """ | |
| compatible_algorithms = dict( | |
| wan_i2v=WanImageToVideo, | |
| wan_t2v=WanTextToVideo, | |
| wan_toy=WanImageToVideo, | |
| ) | |
| compatible_datasets = dict( | |
| mixture=MixtureDataset, | |
| mixture_robot=MixtureDataset, | |
| dummy=DummyVideoDataset, | |
| something_something=SomethingSomethingDataset, | |
| epic_kitchen=EpicKitchenDataset, | |
| pandas=PandasVideoDataset, | |
| ego4d=Ego4DVideoDataset, | |
| bridge=OpenXVideoDataset, | |
| droid=DroidVideoDataset, | |
| agibot_world=AgibotWorldDataset, | |
| language_table=OpenXVideoDataset, | |
| # austin_buds=OpenXVideoDataset, | |
| # austin_sailor=OpenXVideoDataset, | |
| # austin_sirius=OpenXVideoDataset, | |
| # bc_z=OpenXVideoDataset, | |
| # berkeley_autolab=OpenXVideoDataset, | |
| # berkeley_cable=OpenXVideoDataset, | |
| # berkeley_fanuc=OpenXVideoDataset, | |
| # cmu_stretch=OpenXVideoDataset, | |
| # dlr_edan=OpenXVideoDataset, | |
| # dobbe=OpenXVideoDataset, | |
| # fmb=OpenXVideoDataset, | |
| # fractal=OpenXVideoDataset, | |
| # iamlab_cmu=OpenXVideoDataset, | |
| # jaco_play=OpenXVideoDataset, | |
| # nyu_franka=OpenXVideoDataset, | |
| # roboturk=OpenXVideoDataset, | |
| # stanford_hydra=OpenXVideoDataset, | |
| # taco_play=OpenXVideoDataset, | |
| # toto=OpenXVideoDataset, | |
| # ucsd_kitchen=OpenXVideoDataset, | |
| # utaustin_mutex=OpenXVideoDataset, | |
| # viola=OpenXVideoDataset, | |
| ) | |
| def _build_strategy(self): | |
| from lightning.pytorch.strategies.fsdp import FSDPStrategy | |
| if self.cfg.strategy == "ddp": | |
| return super()._build_strategy() | |
| elif self.cfg.strategy == "fsdp": | |
| if self.cfg.num_nodes >= 8: | |
| device_mesh = (self.cfg.num_nodes // 8, 32) | |
| else: | |
| device_mesh = (1, self.cfg.num_nodes * 4) | |
| return FSDPStrategy( | |
| mixed_precision=MixedPrecision( | |
| param_dtype=torch.bfloat16, | |
| reduce_dtype=torch.bfloat16, | |
| buffer_dtype=torch.bfloat16, | |
| ), | |
| auto_wrap_policy=ModuleWrapPolicy(self.algo.classes_to_shard()), | |
| # sharding_strategy="FULL_SHARD", | |
| sharding_strategy="HYBRID_SHARD", | |
| device_mesh=device_mesh, | |
| ) | |
| else: | |
| return self.cfg.strategy | |
| def download_dataset(self): | |
| dataset = self._build_dataset("training") | |