File size: 5,037 Bytes
31112ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
from shared.utils.hf import build_hf_url
class family_handler():
    """Static hooks describing the ``sky_df_1.3B`` / ``sky_df_14B`` model types.

    Every method is a ``@staticmethod``; the class is used as a namespace and
    is never instantiated in this file. Several hooks simply delegate to the
    handler defined in the sibling ``wan_handler`` module.
    """

    @staticmethod
    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
        """Store the model-specific coefficient list on ``skip_steps_cache``.

        Only ``base_model_type`` and ``skip_steps_cache`` are read; the other
        parameters are accepted for interface compatibility and ignored.

        Args:
            cache_type: Unused here.
            base_model_type: Base model identifier; ``"sky_df_1.3B"`` selects
                its own coefficient set, anything else gets the other set.
            model_def: Unused here.
            inputs: Unused here.
            skip_steps_cache: Object that receives a ``coefficients``
                attribute (mutated in place).
        """
        # NOTE(review): presumably rescaling coefficients for the step-skipping
        # cache -- their derivation is not visible in this file.
        if base_model_type == "sky_df_1.3B":
            coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
        else:
            coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
        skip_steps_cache.coefficients = coefficients

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Build the extra model-definition entries for a base model type.

        Args:
            base_model_type: Base model identifier; only affects ``"fps"``
                (24 for ``"sky_df_14B"``, 16 otherwise).
            model_def: Unused here.

        Returns:
            A new dict of capability flags, text-encoder download URLs and
            UI choices.
        """
        extra_model_def = {}

        text_encoder_folder = "umt5-xxl"
        extra_model_def["text_encoder_URLs"] = [
            build_hf_url("DeepBeepMeep/Wan2.1", text_encoder_folder, "models_t5_umt5-xxl-enc-bf16.safetensors"),
            build_hf_url("DeepBeepMeep/Wan2.1", text_encoder_folder, "models_t5_umt5-xxl-enc-quanto_int8.safetensors"),
        ]
        extra_model_def["text_encoder_folder"] = text_encoder_folder

        if base_model_type in ["sky_df_14B"]:
            fps = 24
        else:
            fps = 16
        extra_model_def["fps"] = fps

        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 20
        extra_model_def["latent_size"] = 4
        extra_model_def["sliding_window"] = True
        extra_model_def["skip_layer_guidance"] = True
        extra_model_def["tea_cache"] = True
        extra_model_def["guidance_max_phases"] = 1
        extra_model_def["flow_shift"] = True

        extra_model_def["model_modes"] = {
            "choices": [
                ("Synchronous", 0),
                ("Asynchronous (better quality but around 50% extra steps added)", 5),
            ],
            "default": 0,
            "label" : "Generation Type"
        }

        extra_model_def["image_prompt_types_allowed"] = "TSV"

        return extra_model_def

    @staticmethod
    def query_supported_types():
        """Return the list of base model types handled by this class."""
        return ["sky_df_1.3B", "sky_df_14B"]

    @staticmethod
    def query_family_maps():
        """Return the ``(models_eqv_map, models_comp_map)`` pair of dicts.

        NOTE(review): the names suggest an "equivalent model" map and a
        "compatible models" map; how callers use them is not visible here.
        """
        models_eqv_map = {
            "sky_df_1.3B" : "sky_df_14B",
        }

        models_comp_map = {
            "sky_df_14B": ["sky_df_1.3B"],
        }
        return models_eqv_map, models_comp_map

    @staticmethod
    def query_model_family():
        """Return the family name these models belong to (``"wan"``)."""
        return "wan"

    @staticmethod
    def query_family_infos():
        """Return family-level info; this handler contributes an empty dict."""
        return {}

    @staticmethod
    def register_lora_cli_args(parser):
        """Delegate LoRA CLI argument registration to the Wan handler."""
        # Imported lazily inside the method, as in the sibling delegating hooks.
        from .wan_handler import family_handler as wan_family_handler
        return wan_family_handler.register_lora_cli_args(parser)

    @staticmethod
    def get_lora_dir(base_model_type, args):
        """Delegate LoRA directory resolution to the Wan handler."""
        from .wan_handler import family_handler as wan_family_handler
        return wan_family_handler.get_lora_dir(base_model_type, args)

    @staticmethod
    def get_rgb_factors(base_model_type ):
        """Return ``(latent_rgb_factors, latent_rgb_factors_bias)`` for the "wan" family."""
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Delegate the model-file query to the Wan handler."""
        # Aliased so the import does not shadow this class's own name and
        # matches the other delegating methods.
        from .wan_handler import family_handler as wan_family_handler
        return wan_family_handler.query_model_files(computeList, base_model_type, model_def)

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None, text_encoder_filename = None):
        """Instantiate the ``DTT2V`` model and expose its sub-models.

        The ``WAN_CONFIGS['t2v-14B']`` configuration and the ``"ckpts"``
        checkpoint directory are used regardless of ``base_model_type``.
        ``text_encoder_quantization`` and ``submodel_no_list`` are accepted
        but not forwarded to ``DTT2V``.

        Returns:
            A ``(wan_model, pipe)`` tuple where ``pipe`` maps
            ``"transformer"``, ``"text_encoder"`` and ``"vae"`` to the
            corresponding ``.model`` objects of ``wan_model``.
        """
        from .configs import WAN_CONFIGS
        from . import DTT2V

        cfg = WAN_CONFIGS['t2v-14B']
        wan_model = DTT2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type = model_type,
            model_def = model_def,
            base_model_type=base_model_type,
            text_encoder_filename= text_encoder_filename,
            quantizeTransformer = quantizeTransformer,
            dtype = dtype,
            VAE_dtype = VAE_dtype,
            mixed_precision_transformer = mixed_precision_transformer,
            save_quantized = save_quantized
        )

        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
        return wan_model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Overwrite entries of ``ui_defaults`` with this family's defaults.

        Args:
            base_model_type: Base model identifier; a ``"720"`` substring
                selects the larger resolution and sliding-window size.
            model_def: Unused here.
            ui_defaults: Dict updated in place; nothing is returned.
        """
        # NOTE(review): neither type returned by query_supported_types()
        # contains "720", so the 960x544 / 97 branch is the one taken for them.
        is_720 = "720" in base_model_type
        # The original literal listed "guidance_scale" and "flow_shift" twice;
        # each key now appears once with the value that previously won.
        ui_defaults.update({
            "guidance_scale": 6,
            "flow_shift": 8,
            "sliding_window_discard_last_frames" : 0,
            "resolution": "1280x720" if is_720 else "960x544",
            "sliding_window_size" : 121 if is_720 else 97,
            "RIFLEx_setting": 2,
        })
|