File size: 618 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# model name to class mapping
from wm.model.dit.dit import DiT
from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper

# Maps a backbone name string to its diffusion-transformer (DiT) class.
DIT_CLASS_MAP = dict(VideoDiT=DiT)

# Maps a VAE name string to its wrapper class.
VAE_CLASS_MAP = dict(WanVAE=WanVAEWrapper)

def get_dynamics_class(name):
    """Resolve a dynamics class from its registered name.

    Each dynamics module is imported lazily, only when its name is
    requested, so selecting one dynamics variant never imports the other.

    Args:
        name: Registered dynamics name; currently
            ``'Bidirectional_FullTrajectory'`` or ``'DiffusionForcing_WM'``.

    Returns:
        The matching dynamics class object (not an instance).

    Raises:
        ValueError: If ``name`` matches no registered dynamics class.
    """

    def _bidirectional_full_trajectory():
        from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory
        return Bidirectional_FullTrajectory

    def _diffusion_forcing_wm():
        from wm.dynamics.diffusion_forcing_wm import DiffusionForcing_WM
        return DiffusionForcing_WM

    # Ordered (name, loader) pairs compared with ``==`` in sequence rather
    # than looked up in a dict, so any input -- even an unhashable one --
    # that matches nothing falls through to the ValueError below.
    candidates = (
        ('Bidirectional_FullTrajectory', _bidirectional_full_trajectory),
        ('DiffusionForcing_WM', _diffusion_forcing_wm),
    )
    for registered_name, load in candidates:
        if name == registered_name:
            return load()
    raise ValueError(f"Unknown dynamics class: {name}")