File size: 3,343 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# from algorithms.cogvideo import CogVideoXImageToVideo, CogVideoXVAE
from algorithms.wan import WanImageToVideo, WanTextToVideo
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.ego4d import Ego4DVideoDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.mixture import MixtureDataset
from .exp_base import BaseLightningExperiment


class VideoPredictionExperiment(BaseLightningExperiment):
    """
    A video prediction experiment.

    Declares which algorithm and dataset implementations this experiment can
    be configured with via the two class-level registries below; the base
    class presumably resolves config keys against them (TODO confirm in
    ``BaseLightningExperiment``).
    """

    # Config key -> algorithm class. Note `wan_toy` maps to the same
    # image-to-video implementation as `wan_i2v`.
    compatible_algorithms = dict(
        wan_i2v=WanImageToVideo,
        wan_t2v=WanTextToVideo,
        wan_toy=WanImageToVideo,
    )

    # Config key -> dataset class. Several keys share OpenXVideoDataset;
    # the commented-out entries are additional OpenX subsets kept here as a
    # reference for re-enabling.
    compatible_datasets = dict(
        mixture=MixtureDataset,
        mixture_robot=MixtureDataset,
        dummy=DummyVideoDataset,
        something_something=SomethingSomethingDataset,
        epic_kitchen=EpicKitchenDataset,
        pandas=PandasVideoDataset,
        ego4d=Ego4DVideoDataset,
        bridge=OpenXVideoDataset,
        droid=DroidVideoDataset,
        agibot_world=AgibotWorldDataset,
        language_table=OpenXVideoDataset,
        # austin_buds=OpenXVideoDataset,
        # austin_sailor=OpenXVideoDataset,
        # austin_sirius=OpenXVideoDataset,
        # bc_z=OpenXVideoDataset,
        # berkeley_autolab=OpenXVideoDataset,
        # berkeley_cable=OpenXVideoDataset,
        # berkeley_fanuc=OpenXVideoDataset,
        # cmu_stretch=OpenXVideoDataset,
        # dlr_edan=OpenXVideoDataset,
        # dobbe=OpenXVideoDataset,
        # fmb=OpenXVideoDataset,
        # fractal=OpenXVideoDataset,
        # iamlab_cmu=OpenXVideoDataset,
        # jaco_play=OpenXVideoDataset,
        # nyu_franka=OpenXVideoDataset,
        # roboturk=OpenXVideoDataset,
        # stanford_hydra=OpenXVideoDataset,
        # taco_play=OpenXVideoDataset,
        # toto=OpenXVideoDataset,
        # ucsd_kitchen=OpenXVideoDataset,
        # utaustin_mutex=OpenXVideoDataset,
        # viola=OpenXVideoDataset,
    )

    def _build_strategy(self):
        """Build the Lightning training strategy from the experiment config.

        Returns:
            The base-class strategy when ``cfg.strategy == "ddp"``, a
            configured ``FSDPStrategy`` when it is ``"fsdp"``, and otherwise
            the raw ``cfg.strategy`` string (left for Lightning to resolve).
        """
        if self.cfg.strategy == "ddp":
            return super()._build_strategy()

        if self.cfg.strategy == "fsdp":
            # Imported lazily, and only on this branch, so non-FSDP runs
            # never pay for (or depend on) the FSDP strategy module.
            from lightning.pytorch.strategies.fsdp import FSDPStrategy

            # Device mesh is (replica_groups, devices_per_shard_group) for
            # HYBRID_SHARD. NOTE(review): this assumes 4 GPUs per node and
            # shards across groups of 8 nodes (8 nodes * 4 GPUs = 32
            # devices); for num_nodes >= 8 not divisible by 8 the remainder
            # nodes are dropped from the mesh (num_nodes // 8) -- confirm
            # this matches the target cluster topology.
            if self.cfg.num_nodes >= 8:
                device_mesh = (self.cfg.num_nodes // 8, 32)
            else:
                device_mesh = (1, self.cfg.num_nodes * 4)

            return FSDPStrategy(
                # Keep params, gradient reductions, and buffers in bf16.
                mixed_precision=MixedPrecision(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.bfloat16,
                    buffer_dtype=torch.bfloat16,
                ),
                # The algorithm reports which submodule classes become
                # individual FSDP units.
                auto_wrap_policy=ModuleWrapPolicy(self.algo.classes_to_shard()),
                # HYBRID_SHARD: fully shard within each mesh group,
                # replicate across groups.
                sharding_strategy="HYBRID_SHARD",
                device_mesh=device_mesh,
            )

        # Unrecognized value: hand the string to Lightning to interpret.
        return self.cfg.strategy

    def download_dataset(self):
        dataset = self._build_dataset("training")