File size: 738 Bytes
5c93746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Inference-facing model registry for StreamDiffusionV2."""

from .wan.wan_wrapper import (
    CausalWanDiffusionWrapper,
    WanDiffusionWrapper,
    WanTextEncoder,
    WanVAEWrapper,
)


# Registries mapping a model-family name to the wrapper class used for it.
# The "wan" and "causal_wan" families share the same text-encoder and VAE
# wrappers; only the diffusion wrapper differs between them.
DIFFUSION_NAME_TO_CLASS = dict(
    wan=WanDiffusionWrapper,
    causal_wan=CausalWanDiffusionWrapper,
)


TEXT_ENCODER_NAME_TO_CLASS = dict(
    wan=WanTextEncoder,
    causal_wan=WanTextEncoder,
)


VAE_NAME_TO_CLASS = dict(
    wan=WanVAEWrapper,
    causal_wan=WanVAEWrapper,
)


def get_diffusion_wrapper(model_name):
    """Return the diffusion wrapper class registered under *model_name*.

    Args:
        model_name: A key of ``DIFFUSION_NAME_TO_CLASS``
            (currently ``"wan"`` or ``"causal_wan"``).

    Returns:
        The registered class itself (not an instance).

    Raises:
        KeyError: If *model_name* is not registered. The message lists the
            supported names.
    """
    try:
        return DIFFUSION_NAME_TO_CLASS[model_name]
    except KeyError:
        # Keep the KeyError type for existing callers, but make the message
        # actionable. ``from None`` drops the redundant bare-KeyError context.
        supported = ", ".join(repr(name) for name in DIFFUSION_NAME_TO_CLASS)
        raise KeyError(
            f"Unknown diffusion model {model_name!r}; expected one of: {supported}"
        ) from None


def get_text_encoder_wrapper(model_name):
    """Return the text-encoder wrapper class registered under *model_name*.

    Args:
        model_name: A key of ``TEXT_ENCODER_NAME_TO_CLASS``
            (currently ``"wan"`` or ``"causal_wan"``).

    Returns:
        The registered class itself (not an instance).

    Raises:
        KeyError: If *model_name* is not registered. The message lists the
            supported names.
    """
    try:
        return TEXT_ENCODER_NAME_TO_CLASS[model_name]
    except KeyError:
        # Keep the KeyError type for existing callers, but make the message
        # actionable. ``from None`` drops the redundant bare-KeyError context.
        supported = ", ".join(repr(name) for name in TEXT_ENCODER_NAME_TO_CLASS)
        raise KeyError(
            f"Unknown text encoder model {model_name!r}; expected one of: {supported}"
        ) from None


def get_vae_wrapper(model_name):
    """Return the VAE wrapper class registered under *model_name*.

    Args:
        model_name: A key of ``VAE_NAME_TO_CLASS``
            (currently ``"wan"`` or ``"causal_wan"``).

    Returns:
        The registered class itself (not an instance).

    Raises:
        KeyError: If *model_name* is not registered. The message lists the
            supported names.
    """
    try:
        return VAE_NAME_TO_CLASS[model_name]
    except KeyError:
        # Keep the KeyError type for existing callers, but make the message
        # actionable. ``from None`` drops the redundant bare-KeyError context.
        supported = ", ".join(repr(name) for name in VAE_NAME_TO_CLASS)
        raise KeyError(
            f"Unknown VAE model {model_name!r}; expected one of: {supported}"
        ) from None