File size: 618 Bytes
f17ae24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | # model name to class mapping
from wm.model.dit.dit import DiT
from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper
# Model name -> DiT backbone class.
DIT_CLASS_MAP = dict(
    VideoDiT=DiT,
)

# Model name -> tokenizer / VAE wrapper class.
VAE_CLASS_MAP = dict(
    WanVAE=WanVAEWrapper,
)
def get_dynamics_class(name):
    """Return the dynamics class registered under ``name``.

    Each class's module is imported lazily, only when that name is
    requested, so resolving one dynamics variant does not import the
    others.

    Args:
        name: Registered dynamics name: ``'Bidirectional_FullTrajectory'``
            or ``'DiffusionForcing_WM'``.

    Returns:
        The class object itself (not an instance).

    Raises:
        ValueError: If ``name`` is not a registered dynamics name. The
            message lists the registered names so config typos are easy
            to spot.
        ImportError: If the backing module, or the class inside it,
            cannot be imported.
    """
    import importlib

    # Registry of name -> (module path, class attribute). Keeping the import
    # targets as data means a new dynamics variant needs exactly one new
    # entry, and the error message below cannot drift out of sync with what
    # is actually supported.
    registry = {
        'Bidirectional_FullTrajectory': (
            'wm.dynamics.bi_fulltrajectory',
            'Bidirectional_FullTrajectory',
        ),
        'DiffusionForcing_WM': (
            'wm.dynamics.diffusion_forcing_wm',
            'DiffusionForcing_WM',
        ),
    }

    try:
        module_path, class_name = registry[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list from a malformed
        # config) so they are reported as an unknown name, not a hash error.
        known = ', '.join(registry)
        raise ValueError(
            f"Unknown dynamics class: {name}. Known dynamics classes: {known}"
        ) from None

    module = importlib.import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        # Mirror `from module import Name`, which raises ImportError when the
        # module exists but does not define the requested name.
        raise ImportError(
            f"cannot import name {class_name!r} from {module_path!r}"
        ) from err
|