# ltx2 / Wan2GP / models / kandinsky5 / kandinsky_handler.py
# Uploaded by vidfom
# Commit message: Upload folder using huggingface_hub
# Commit: 31112ad (verified)
import os
import torch
from omegaconf import OmegaConf
from shared.utils.hf import build_hf_url
# Per-process memo of MagCache ratio lists, keyed by config file name.
# A stored ``None`` records that the lookup already ran and produced nothing.
_MAGCACHE_RATIOS_CACHE = {}


def _load_magcache_ratios(config_name):
    """Return the ``magcache.mag_ratios`` list from a Kandinsky 5 config file.

    The file is looked up at ``models/kandinsky5/configs/<config_name>``
    relative to the current working directory and parsed with
    ``OmegaConf.load``. Results are memoized in ``_MAGCACHE_RATIOS_CACHE``,
    including negative results, so each config name touches the filesystem
    at most once per process.

    Args:
        config_name: File name of the YAML config (not a full path).

    Returns:
        A list built from ``conf.magcache.mag_ratios``, or ``None`` when the
        file does not exist or carries no such entry. The returned list is
        the cached object itself; callers should not mutate it.
    """
    # Membership test rather than ``.get(...) is not None``: a cached ``None``
    # (missing file / missing ratios) must be honored too, otherwise the
    # negative cache entry written below would never be used.
    if config_name in _MAGCACHE_RATIOS_CACHE:
        return _MAGCACHE_RATIOS_CACHE[config_name]

    config_path = os.path.join("models", "kandinsky5", "configs", config_name)
    ratios = None
    if os.path.isfile(config_path):
        conf = OmegaConf.load(config_path)
        # ``getattr`` with a default covers both "attribute missing" and
        # "attribute resolves to None", so the ``in`` test below cannot hit None.
        magcache_section = getattr(conf, "magcache", None)
        if magcache_section is not None and "mag_ratios" in magcache_section:
            ratios = list(magcache_section.mag_ratios)

    _MAGCACHE_RATIOS_CACHE[config_name] = ratios
    return ratios
def _select_k5_bucket(width, height, is_video):
    """Pick the resolution bucket (512 or 1024) nearest to ``width * height``.

    The 512 bucket is represented by an area of 512x768 when ``is_video`` is
    truthy and 512x512 otherwise; the 1024 bucket is always 1024x1024.

    Args:
        width: Frame width, or ``None`` when unknown.
        height: Frame height, or ``None`` when unknown.
        is_video: Selects which reference area stands for the 512 bucket.

    Returns:
        ``512`` or ``1024``. Falls back to ``512`` when either dimension is
        ``None``; an exact tie between the two buckets also resolves to ``512``.
    """
    if width is None or height is None:
        return 512

    pixels = width * height
    small_reference = 512 * 768 if is_video else 512 * 512
    large_reference = 1024 * 1024

    small_gap = abs(pixels - small_reference)
    large_gap = abs(pixels - large_reference)
    # Strict comparison keeps 512 on ties.
    return 1024 if large_gap < small_gap else 512
def _is_k5_sparse(model_type, model_def):
    """Tell whether a model variant is flagged as sparse.

    A variant counts as sparse when its ``model_type`` name contains
    ``"sparse"`` (case-insensitive), or when
    ``model_def["k5_config_overrides"]["model"]["attention"]["type"]`` equals
    ``"nabla"``.

    Args:
        model_type: Variant name; may be ``None`` or empty.
        model_def: Mapping with the model definition; may be ``None``.

    Returns:
        ``True`` if either condition holds, ``False`` otherwise.
    """
    if model_type and "sparse" in model_type.lower():
        return True

    # Walk the nested override path, treating a missing key and an explicit
    # ``None`` value alike, so a definition such as
    # ``{"k5_config_overrides": None}`` does not raise AttributeError.
    section = model_def or {}
    for key in ("k5_config_overrides", "model", "attention"):
        section = section.get(key) or {}
    return section.get("type") == "nabla"
def _select_k5_magcache_config(base_model_type, model_type, bucket, is_sparse):
    """Map a Kandinsky 5 model variant to the name of its MagCache config file.

    Args:
        base_model_type: Architecture identifier, compared verbatim against
            ``k5_pro_t2v``, ``k5_pro_i2v``, ``k5_lite_t2v`` and ``k5_lite_i2v``.
        model_type: Variant name; a case-insensitive ``"10s"`` substring picks
            the 10-second config for the T2V architectures. May be ``None``.
        bucket: Resolution bucket; ``1024`` selects the ``hd`` file for the
            Pro architectures, anything else selects ``sd``.
        is_sparse: Accepted for interface compatibility; not consulted here.

    Returns:
        The YAML file name, or ``None`` for an unrecognized architecture.
        Lite architectures always resolve to an ``sd`` file.
    """
    quality = "hd" if bucket == 1024 else "sd"
    variant = (model_type or "").lower()
    long_clip = "10s" in variant

    if base_model_type == "k5_pro_t2v":
        return (
            f"k5_pro_t2v_10s_sft_{quality}.yaml"
            if long_clip
            else f"k5_pro_t2v_5s_sft_{quality}.yaml"
        )
    if base_model_type == "k5_pro_i2v":
        return f"k5_pro_i2v_5s_sft_{quality}.yaml"
    if base_model_type == "k5_lite_t2v":
        return "k5_lite_t2v_10s_sft_sd.yaml" if long_clip else "k5_lite_t2v_5s_sft_sd.yaml"
    if base_model_type == "k5_lite_i2v":
        return "k5_lite_i2v_5s_sft_sd.yaml"
    return None
def _infer_task(base_model_type):
    """Derive the task tag from an architecture identifier.

    The identifier is lower-cased and scanned for the substrings ``i2v``,
    ``t2v``, ``i2i`` and ``t2i``, in that priority order.

    Args:
        base_model_type: Architecture identifier; may be ``None`` or empty.

    Returns:
        The first matching tag, or ``"t2v"`` when nothing matches or the
        identifier is falsy.
    """
    lowered = base_model_type.lower() if base_model_type else ""
    for tag in ("i2v", "t2v", "i2i", "t2i"):
        if tag in lowered:
            return tag
    return "t2v"
class family_handler:
    """Static hooks describing the Kandinsky 5 model family.

    The class holds no instance state; every hook is a ``@staticmethod``.
    """

    @staticmethod
    def query_supported_types():
        """Return the architecture identifiers handled by this family."""
        supported = ("k5_lite_t2v", "k5_lite_i2v", "k5_pro_t2v", "k5_pro_i2v")
        return list(supported)

    @staticmethod
    def query_family_maps():
        """Return a pair of empty dicts (this family declares no maps)."""
        return dict(), dict()

    @staticmethod
    def query_model_family():
        """Return the family identifier, ``"kandinsky5"``."""
        return "kandinsky5"

    @staticmethod
    def query_family_infos():
        """Return the family info table.

        Maps the family identifier to a ``(50, "Kandinsky 5")`` tuple.
        NOTE(review): presumably (sort order, display label) - confirm
        against the caller.
        """
        return {"kandinsky5": (50, "Kandinsky 5")}

    @staticmethod
    def register_lora_cli_args(parser):
        """Add the Kandinsky 5 LoRA directory options to ``parser``.

        Registers one family-wide option plus one option per architecture.
        Each is a string option defaulting to a sub-folder of ``loras``.

        Args:
            parser: Object exposing ``add_argument`` (argparse-style).
        """
        # (flag, default sub-folder under "loras", help text)
        option_table = (
            (
                "--lora-dir-kandinsky5",
                "kandinsky5",
                "Base path for Kandinsky 5 loras (per-architecture subfolders are used).",
            ),
            (
                "--lora-dir-k5-lite-t2v",
                "k5_lite_t2v",
                "Path to a directory that contains Kandinsky 5 Lite T2V loras.",
            ),
            (
                "--lora-dir-k5-lite-i2v",
                "k5_lite_i2v",
                "Path to a directory that contains Kandinsky 5 Lite I2V loras.",
            ),
            (
                "--lora-dir-k5-pro-t2v",
                "k5_pro_t2v",
                "Path to a directory that contains Kandinsky 5 Pro T2V loras.",
            ),
            (
                "--lora-dir-k5-pro-i2v",
                "k5_pro_i2v",
                "Path to a directory that contains Kandinsky 5 Pro I2V loras.",
            ),
        )
        for flag, subfolder, help_text in option_table:
            parser.add_argument(
                flag,
                type=str,
                default=os.path.join("loras", subfolder),
                help=help_text,
            )

    @staticmethod
    def get_lora_dir(base_model_type, args):
        """Resolve the LoRA directory for an architecture.

        Args:
            base_model_type: Architecture identifier; may be ``None``.
            args: Object whose ``lora_dir_*`` attributes hold configured
                paths. Missing or falsy attributes fall back to the
                ``loras/<name>`` default.

        Returns:
            The per-architecture directory for the four known architectures;
            ``<family dir>/<base_model_type>`` for any other non-empty
            identifier except ``"kandinsky5"``; otherwise the family
            directory itself.
        """

        def configured(attr_name, fallback_subfolder):
            # A falsy setting (None, "") counts as "not configured".
            return getattr(args, attr_name, None) or os.path.join("loras", fallback_subfolder)

        family_root = configured("lora_dir_kandinsky5", "kandinsky5")
        arch_dirs = {
            arch: configured(f"lora_dir_{arch}", arch)
            for arch in ("k5_lite_t2v", "k5_lite_i2v", "k5_pro_t2v", "k5_pro_i2v")
        }

        if base_model_type in arch_dirs:
            return arch_dirs[base_model_type]
        if not base_model_type or base_model_type == "kandinsky5":
            return family_root
        return os.path.join(family_root, base_model_type)

    @staticmethod
    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
        """Configure ``skip_steps_cache`` for MagCache.

        Does nothing unless ``cache_type`` is ``"mag"``. Otherwise it sets
        ``magcache_thresh`` and ``magcache_K`` through
        ``skip_steps_cache.update`` and, when a ratio list can be loaded for
        the selected config, stores it as ``skip_steps_cache.def_mag_ratios``.

        Args:
            cache_type: Cache kind requested by the caller.
            base_model_type: Architecture identifier.
            model_def: Model definition mapping (used for sparse detection).
            inputs: Mapping read for ``"resolution"`` (a ``"<W>x<H>"`` string)
                and ``"model_type"``.
            skip_steps_cache: Object providing ``update(dict)`` and accepting
                attribute assignment.
        """
        if cache_type != "mag":
            return

        skip_steps_cache.update({"magcache_thresh": 0, "magcache_K": 2})

        # Only a "<digits>x<digits>" string yields dimensions; anything else
        # leaves them unknown, so the bucket helper uses its fallback.
        width = height = None
        requested = inputs.get("resolution")
        if isinstance(requested, str) and "x" in requested:
            width_text, _, height_text = requested.partition("x")
            if width_text.isdigit() and height_text.isdigit():
                width, height = int(width_text), int(height_text)
        bucket = _select_k5_bucket(width, height, is_video=True)

        variant = inputs.get("model_type")
        sparse = _is_k5_sparse(variant, model_def)
        config_name = _select_k5_magcache_config(base_model_type, variant, bucket, sparse)
        if not config_name:
            return

        ratios = _load_magcache_ratios(config_name)
        if ratios:
            skip_steps_cache.def_mag_ratios = ratios

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Build the extra model-definition entries for an architecture.

        Args:
            base_model_type: Architecture identifier; the task is inferred
                from it and it also names the profiles directory
                (``"kandinsky5"`` when falsy).
            model_def: Not read by this method.

        Returns:
            A new dict holding task flags, text-encoder URLs and folder,
            frame/fps settings and the allowed image prompt types.
        """
        task = _infer_task(base_model_type)
        video_task = task in ("t2v", "i2v")
        image_task = task in ("t2i", "i2i")
        takes_start_image = task in ("i2v", "i2i")

        encoder_dir = "Qwen2.5-VL-7B-Instruct"
        encoder_files = (
            "Qwen2.5-VL-7B-Instruct_bf16.safetensors",
            "Qwen2.5-VL-7B-Instruct_quanto_bf16_int8.safetensors",
        )

        return {
            "i2v_class": task == "i2v",
            "t2v_class": task == "t2v",
            "image_outputs": image_task,
            "guidance_max_phases": 1,
            "sliding_window": False,
            "flow_shift": True,
            "mag_cache": True,
            "profiles_dir": [base_model_type or "kandinsky5"],
            "text_encoder_URLs": [
                build_hf_url("DeepBeepMeep/Qwen_image", encoder_dir, file_name)
                for file_name in encoder_files
            ],
            "text_encoder_folder": encoder_dir,
            "fps": 24 if video_task else 1,
            "frames_minimum": 5 if video_task else 1,
            "frames_steps": 4 if video_task else 1,
            "image_prompt_types_allowed": "S" if takes_start_image else "",
        }

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """List the auxiliary files this family needs, grouped by repository.

        None of the arguments influence the result.

        Returns:
            A fresh list of dicts, each with ``repoId``, ``sourceFolderList``
            and a ``fileList`` holding one list of file names per source
            folder.
        """
        qwen_root_files = ["qwen_vae.safetensors", "qwen_vae_config.json"]
        qwen_encoder_files = [
            "merges.txt",
            "tokenizer_config.json",
            "config.json",
            "vocab.json",
            "video_preprocessor_config.json",
            "preprocessor_config.json",
            "chat_template.json",
        ]
        clip_files = [
            "config.json",
            "merges.txt",
            "model.safetensors",
            "preprocessor_config.json",
            "special_tokens_map.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "vocab.json",
        ]
        hunyuan_root_files = [
            "hunyuan_video_VAE_fp32.safetensors",
            "hunyuan_video_VAE_config.json",
        ]

        qwen_entry = {
            "repoId": "DeepBeepMeep/Qwen_image",
            "sourceFolderList": ["", "Qwen2.5-VL-7B-Instruct"],
            "fileList": [qwen_root_files, qwen_encoder_files],
        }
        hunyuan_entry = {
            "repoId": "DeepBeepMeep/HunyuanVideo",
            "sourceFolderList": ["clip_vit_large_patch14", ""],
            "fileList": [clip_files, hunyuan_root_files],
        }
        return [qwen_entry, hunyuan_entry]

    @staticmethod
    def load_model(
        model_filename,
        model_type=None,
        base_model_type=None,
        model_def=None,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        **kwargs,
    ):
        """Instantiate the Kandinsky 5 model and describe its components.

        The model is built by ``model_factory`` from ``kandinsky_main`` with
        ``checkpoint_dir="ckpts"``. ``text_encoder_quantization``,
        ``submodel_no_list`` and any extra keyword arguments are accepted but
        not forwarded.

        Returns:
            ``(model, pipe)`` where ``pipe`` maps ``"transformer"``,
            ``"text_encoder"``, ``"text_encoder_2"`` and ``"vae"`` to the
            corresponding components. Components that are ``torch.nn.Module``
            instances are moved to the CPU before returning.
        """
        from .kandinsky_main import model_factory

        factory_options = dict(
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=text_encoder_filename,
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )
        kandinsky = model_factory(**factory_options)

        pipe = {
            "transformer": kandinsky.transformer,
            "text_encoder": kandinsky.text_embedder.embedder.model,
            "text_encoder_2": kandinsky.text_embedder.clip_embedder.model,
            "vae": kandinsky.vae,
        }
        for component in pipe.values():
            if not isinstance(component, torch.nn.Module):
                continue
            component.to("cpu")
        return kandinsky, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Write task-dependent defaults into ``ui_defaults`` in place.

        Sets ``image_mode`` to 1 for image tasks, ``image_prompt_type`` to
        ``"S"`` for tasks that take an input image, and always sets
        ``skip_steps_start_step_perc`` to 20. ``model_def`` is not read.
        """
        task = _infer_task(base_model_type)
        produces_image = task in ("t2i", "i2i")
        takes_start_image = task in ("i2v", "i2i")

        if produces_image:
            ui_defaults["image_mode"] = 1
        if takes_start_image:
            ui_defaults["image_prompt_type"] = "S"
        ui_defaults["skip_steps_start_step_perc"] = 20

    @staticmethod
    def get_rgb_factors(base_model_type):
        """Return the shared RGB factors registered under ``"hunyuan"``.

        ``base_model_type`` does not affect the lookup.
        """
        from shared.RGB_factors import get_rgb_factors as shared_rgb_factors

        return shared_rgb_factors("hunyuan")