# world_model/wm/model/interface.py
# Uploaded by t1an using huggingface_hub (commit f17ae24, verified).
# model name to class mapping
from wm.model.dit.dit import DiT
from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper
# Registry of DiT backbones: model name string -> implementation class.
DIT_CLASS_MAP = dict(
    VideoDiT=DiT,
)
# Registry of VAE tokenizers: model name string -> implementation class.
VAE_CLASS_MAP = dict(
    WanVAE=WanVAEWrapper,
)
# Dynamics name -> (module path, attribute name). Stored as import paths
# rather than classes so that a module is only imported when its class is
# requested (presumably to avoid heavy or circular imports at load time —
# confirm). To register a new dynamics class, add one entry here.
_DYNAMICS_CLASS_PATHS = {
    'Bidirectional_FullTrajectory': (
        'wm.dynamics.bi_fulltrajectory',
        'Bidirectional_FullTrajectory',
    ),
    'DiffusionForcing_WM': (
        'wm.dynamics.diffusion_forcing_wm',
        'DiffusionForcing_WM',
    ),
}


def get_dynamics_class(name):
    """Return the dynamics class registered under ``name``.

    The defining module is imported on demand, the first time a class is
    requested.

    Args:
        name: Registered dynamics name, e.g. ``'DiffusionForcing_WM'``.

    Returns:
        The dynamics class object (not an instance).

    Raises:
        ValueError: If ``name`` is not a registered dynamics name. The
            message lists the supported names.
    """
    import importlib

    try:
        module_path, attr_name = _DYNAMICS_CLASS_PATHS[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names. The previous `==` chain
        # treated them as simply "unknown", so they still raise ValueError.
        available = ', '.join(sorted(_DYNAMICS_CLASS_PATHS))
        raise ValueError(
            f"Unknown dynamics class: {name}. Available: {available}"
        ) from None
    return getattr(importlib.import_module(module_path), attr_name)