| # model name to class mapping | |
| from wm.model.dit.dit import DiT | |
| from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper | |
# Maps a DiT model name to the class that implements it.
DIT_CLASS_MAP = dict(VideoDiT=DiT)

# Maps a VAE/tokenizer model name to the wrapper class that implements it.
VAE_CLASS_MAP = dict(WanVAE=WanVAEWrapper)
def get_dynamics_class(name):
    """Return the dynamics class registered under ``name``.

    Each implementation module is imported inside its own branch, so only
    the module for the requested name is imported by a given call.

    Args:
        name: Registry name of the dynamics class. Supported values are
            ``'Bidirectional_FullTrajectory'`` and ``'DiffusionForcing_WM'``.

    Returns:
        The object imported for ``name`` (the dynamics class itself, not an
        instance).

    Raises:
        ValueError: If ``name`` is not one of the supported values. The
            message starts with ``"Unknown dynamics class: <name>"`` and
            then lists the supported names.
    """
    if name == 'Bidirectional_FullTrajectory':
        from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory
        return Bidirectional_FullTrajectory
    elif name == 'DiffusionForcing_WM':
        from wm.dynamics.diffusion_forcing_wm import DiffusionForcing_WM
        return DiffusionForcing_WM
    # Keep the original message prefix so callers that match on it still
    # work, then list the valid names so a misspelled config is easy to fix.
    supported = ('Bidirectional_FullTrajectory', 'DiffusionForcing_WM')
    raise ValueError(
        f"Unknown dynamics class: {name}. "
        f"Supported: {', '.join(supported)}"
    )