File size: 5,037 Bytes
31112ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
from shared.utils.hf import build_hf_url

class family_handler:
    """Family handler for the SkyReels Diffusion-Forcing ("sky_df") Wan models.

    A stateless namespace of static methods implementing the model-family
    interface used by the loader: step-skip cache setup, capability queries,
    LoRA/file delegation to the generic wan handler, model instantiation and
    UI defaults.
    """

    @staticmethod
    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
        """Attach the TeaCache polynomial coefficients to *skip_steps_cache*.

        The fitted polynomial differs by model size: the 1.3B checkpoint has
        its own coefficient set; every other type uses the 14B coefficients.
        `cache_type`, `model_def` and `inputs` are accepted for interface
        compatibility but unused here.
        """
        if base_model_type == "sky_df_1.3B":
            coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
        else:
            coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
        skip_steps_cache.coefficients = coefficients

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return a dict describing this family's capabilities and defaults."""
        extra_model_def = {}
        text_encoder_folder = "umt5-xxl"
        # Full-precision (bf16) and int8-quantized variants of the UMT5-XXL
        # text encoder, fetched from the shared Wan2.1 repository.
        extra_model_def["text_encoder_URLs"] = [
            build_hf_url("DeepBeepMeep/Wan2.1", text_encoder_folder, "models_t5_umt5-xxl-enc-bf16.safetensors"),
            build_hf_url("DeepBeepMeep/Wan2.1", text_encoder_folder, "models_t5_umt5-xxl-enc-quanto_int8.safetensors"),
        ]
        extra_model_def["text_encoder_folder"] = text_encoder_folder
        # The 14B diffusion-forcing model runs at 24 fps; the others at 16 fps.
        extra_model_def["fps"] = 24 if base_model_type in ["sky_df_14B"] else 16
        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 20
        extra_model_def["latent_size"] = 4
        extra_model_def["sliding_window"] = True
        extra_model_def["skip_layer_guidance"] = True
        extra_model_def["tea_cache"] = True
        extra_model_def["guidance_max_phases"] = 1
        extra_model_def["flow_shift"] = True
        extra_model_def["model_modes"] = {
            "choices": [
                ("Synchronous", 0),
                ("Asynchronous (better quality but around 50% extra steps added)", 5),
            ],
            "default": 0,
            "label": "Generation Type",
        }
        extra_model_def["image_prompt_types_allowed"] = "TSV"
        return extra_model_def

    @staticmethod
    def query_supported_types():
        """List the base model types handled by this family."""
        return ["sky_df_1.3B", "sky_df_14B"]

    @staticmethod
    def query_family_maps():
        """Return (equivalence map, compatibility map) between model types."""
        models_eqv_map = {
            "sky_df_1.3B": "sky_df_14B",
        }
        models_comp_map = {
            "sky_df_14B": ["sky_df_1.3B"],
        }
        return models_eqv_map, models_comp_map

    @staticmethod
    def query_model_family():
        """These models belong to the generic "wan" family."""
        return "wan"

    @staticmethod
    def query_family_infos():
        """No standalone family entry; models are listed under the wan family."""
        return {}

    @staticmethod
    def register_lora_cli_args(parser):
        """Delegate LoRA CLI argument registration to the wan handler."""
        from .wan_handler import family_handler as wan_family_handler
        return wan_family_handler.register_lora_cli_args(parser)

    @staticmethod
    def get_lora_dir(base_model_type, args):
        """Delegate LoRA directory resolution to the wan handler."""
        from .wan_handler import family_handler as wan_family_handler
        return wan_family_handler.get_lora_dir(base_model_type, args)

    @staticmethod
    def get_rgb_factors(base_model_type):
        """Return (latent_rgb_factors, bias) for latent preview, from the shared wan tables."""
        from shared.RGB_factors import get_rgb_factors
        return get_rgb_factors("wan", base_model_type)

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Delegate model file resolution to the wan handler."""
        from .wan_handler import family_handler
        return family_handler.query_model_files(computeList, base_model_type, model_def)

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None, text_encoder_filename = None):
        """Instantiate the DTT2V pipeline for a sky_df checkpoint.

        Returns ``(wan_model, pipe)`` where *pipe* maps component names to
        the raw torch modules (used downstream for offload/quantization).

        NOTE(review): `text_encoder_quantization` and `submodel_no_list` are
        accepted for interface compatibility but unused in this handler.
        """
        from .configs import WAN_CONFIGS
        from . import DTT2V
        # Both sky_df sizes reuse the t2v-14B architecture configuration.
        cfg = WAN_CONFIGS['t2v-14B']
        wan_model = DTT2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=text_encoder_filename,
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )
        pipe = {"transformer": wan_model.model, "text_encoder": wan_model.text_encoder.model, "vae": wan_model.vae.model}
        return wan_model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Fill *ui_defaults* in place with this family's generation defaults.

        Fix: the original dict literal listed "guidance_scale" (6.0, then 6)
        and "flow_shift" (8, twice) as duplicate keys; Python keeps only the
        last value, so the duplicates were dead entries. They are collapsed
        here to single keys with the effective values — behavior unchanged.
        """
        is_720p = "720" in base_model_type
        ui_defaults.update({
            "guidance_scale": 6,
            "flow_shift": 8,
            "sliding_window_discard_last_frames": 0,
            "resolution": "1280x720" if is_720p else "960x544",
            "sliding_window_size": 121 if is_720p else 97,
            "RIFLEx_setting": 2,
        })