# LVP / experiments/exp_video.py
# Last commit 142a1ac by kiwhansong: "add demo"
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
# from algorithms.cogvideo import CogVideoXImageToVideo, CogVideoXVAE
from algorithms.wan import WanImageToVideo, WanTextToVideo
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.ego4d import Ego4DVideoDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.mixture import MixtureDataset
from .exp_base import BaseLightningExperiment
class VideoPredictionExperiment(BaseLightningExperiment):
"""
A video prediction experiment
"""
compatible_algorithms = dict(
wan_i2v=WanImageToVideo,
wan_t2v=WanTextToVideo,
wan_toy=WanImageToVideo,
)
# Config name -> dataset class. Several names intentionally share one class
# (e.g. "mixture"/"mixture_robot" -> MixtureDataset,
# "bridge"/"language_table" -> OpenXVideoDataset).
compatible_datasets = {
    "mixture": MixtureDataset,
    "mixture_robot": MixtureDataset,
    "dummy": DummyVideoDataset,
    "something_something": SomethingSomethingDataset,
    "epic_kitchen": EpicKitchenDataset,
    "pandas": PandasVideoDataset,
    "ego4d": Ego4DVideoDataset,
    "bridge": OpenXVideoDataset,
    "droid": DroidVideoDataset,
    "agibot_world": AgibotWorldDataset,
    "language_table": OpenXVideoDataset,
    # Further OpenX subsets, currently disabled; each would map to
    # OpenXVideoDataset:
    # "austin_buds", "austin_sailor", "austin_sirius", "bc_z",
    # "berkeley_autolab", "berkeley_cable", "berkeley_fanuc", "cmu_stretch",
    # "dlr_edan", "dobbe", "fmb", "fractal", "iamlab_cmu", "jaco_play",
    # "nyu_franka", "roboturk", "stanford_hydra", "taco_play", "toto",
    # "ucsd_kitchen", "utaustin_mutex", "viola"
}
def _build_strategy(self):
    """Build the training strategy selected by ``self.cfg.strategy``.

    - ``"ddp"``: delegate to the base experiment's ``_build_strategy``.
    - ``"fsdp"``: return a Lightning ``FSDPStrategy`` using HYBRID_SHARD,
      bf16 parameters/reductions/buffers, and auto-wrapping of the module
      classes returned by ``self.algo.classes_to_shard()``.
    - anything else: return ``self.cfg.strategy`` unchanged (presumably a
      strategy name the Trainer resolves itself).

    Raises:
        ValueError: for ``"fsdp"`` when ``num_nodes`` is at least 8 but not a
            multiple of 8, because the hybrid-shard device mesh would not
            cover every rank.
    """
    if self.cfg.strategy == "ddp":
        return super()._build_strategy()
    if self.cfg.strategy != "fsdp":
        return self.cfg.strategy

    # Imported only here so non-FSDP runs never load the FSDP strategy.
    from lightning.pytorch.strategies.fsdp import FSDPStrategy

    # NOTE(review): hard-coded assumption of 4 GPUs per node — TODO confirm
    # against the cluster / trainer device config.
    gpus_per_node = 4
    # Parameters are sharded within groups of this many nodes and replicated
    # across groups.
    nodes_per_shard_group = 8

    num_nodes = self.cfg.num_nodes
    if num_nodes >= nodes_per_shard_group:
        if num_nodes % nodes_per_shard_group != 0:
            # Otherwise the mesh below would span fewer ranks than exist and
            # fail later with an opaque device-mesh size error.
            raise ValueError(
                f"FSDP hybrid sharding with num_nodes >= {nodes_per_shard_group} "
                f"requires num_nodes to be a multiple of {nodes_per_shard_group}; "
                f"got num_nodes={num_nodes}"
            )
        # (replication groups, GPUs per shard group)
        device_mesh = (
            num_nodes // nodes_per_shard_group,
            nodes_per_shard_group * gpus_per_node,
        )
    else:
        # Small jobs: one shard group spanning every GPU, no replication.
        device_mesh = (1, num_nodes * gpus_per_node)

    return FSDPStrategy(
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        auto_wrap_policy=ModuleWrapPolicy(self.algo.classes_to_shard()),
        sharding_strategy="HYBRID_SHARD",
        device_mesh=device_mesh,
    )
def download_dataset(self):
dataset = self._build_dataset("training")