# Source: ColabWan/models/longcat/longcat_handler.py (Hugging Face repository file view)
# Uploaded by 1ripon1 with commit message "Upload folder using huggingface_hub"
# Revision 7344bef (verified), file size 9.67 kB
import os
import torch
from shared.utils.hf import build_hf_url
# Base model types of the LongCat Avatar variants. The handler in this file
# tests membership in this set to apply avatar-specific model definitions and
# UI defaults.
LONGCAT_AVATAR_TYPES = {"longcat_avatar", "longcat_avatar_v1_5"}
class family_handler:
    """Static hooks describing the LongCat model family.

    Covers the base model types ``longcat_video``, ``longcat_avatar`` and
    ``longcat_avatar_v1_5``. Every method is a ``@staticmethod``; the class is
    only used as a namespace (presumably looked up by the host application's
    model registry -- confirm against the caller).
    """

    @staticmethod
    def query_supported_types():
        """Return the base model types handled by this family."""
        supported = ["longcat_video", "longcat_avatar", "longcat_avatar_v1_5"]
        return supported

    @staticmethod
    def query_family_maps():
        """Return two empty mappings.

        NOTE(review): the meaning of each mapping is defined by the caller and
        is not visible in this file.
        """
        return ({}, {})

    @staticmethod
    def query_model_family():
        """Return the family identifier, ``"longcat"``."""
        return "longcat"

    @staticmethod
    def query_family_infos():
        """Return the family info table: ``{"longcat": (60, "LongCat")}``.

        NOTE(review): 60 is presumably a display/sort order and "LongCat" the
        display name -- confirm against the caller.
        """
        infos = {"longcat": (60, "LongCat")}
        return infos

    @staticmethod
    def register_lora_cli_args(parser, lora_root):
        """Register one ``--lora-dir-*`` option per LongCat variant.

        Args:
            parser: object exposing ``add_argument`` (e.g. an
                ``argparse.ArgumentParser``).
            lora_root: root LoRA directory, used only to render the default
                path shown in each option's help text.
        """
        # (CLI flag, label used in the help text, default sub-folder)
        options = (
            ("--lora-dir-longcat", "LongCat Video", "longcat"),
            ("--lora-dir-longcat-avatar", "LongCat Avatar", "longcat_avatar"),
            ("--lora-dir-longcat-avatar-v1-5", "LongCat Avatar 1.5", "longcat_avatar_v1_5"),
        )
        for flag, label, folder in options:
            fallback = os.path.join(lora_root, folder)
            parser.add_argument(
                flag,
                type=str,
                default=None,
                help=f"Path to a directory that contains {label} LoRAs (default: {fallback})",
            )

    @staticmethod
    def get_lora_dir(base_model_type, args, lora_root):
        """Return the LoRA directory for ``base_model_type``.

        A truthy value of the matching ``lora_dir_*`` attribute on ``args``
        wins; otherwise the variant's sub-folder below ``lora_root`` is
        returned. Any type other than the two avatar variants uses the
        ``longcat`` attribute and folder.
        """
        # (base model type / sub-folder, attribute holding the CLI override)
        avatar_variants = (
            ("longcat_avatar", "lora_dir_longcat_avatar"),
            ("longcat_avatar_v1_5", "lora_dir_longcat_avatar_v1_5"),
        )
        attr_name, folder = "lora_dir_longcat", "longcat"
        for variant, variant_attr in avatar_variants:
            if base_model_type == variant:
                attr_name, folder = variant_attr, variant
                break

        override = getattr(args, attr_name, None)
        if override:
            return override
        # The default path is only built when no override was supplied.
        return os.path.join(lora_root, folder)

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Build the extra model-definition entries for ``base_model_type``.

        Args:
            base_model_type: one of the supported base model types.
            model_def: accepted for interface compatibility; not read here.

        Returns:
            A new dict holding the settings shared by every LongCat model,
            plus the video-specific or avatar-specific entries. Avatar v1.5
            additionally receives its distill-only overrides.
        """
        encoder_repo = "DeepBeepMeep/Wan2.1"
        encoder_dir = "umt5-xxl"
        encoder_files = (
            "models_t5_umt5-xxl-enc-bf16.safetensors",
            "models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        )

        settings = {
            "frames_minimum": 5,
            "frames_steps": 4,
            "sliding_window": True,
            "guidance_max_phases": 1,
            "image_prompt_types_allowed": "TSVL",
            "video_continuation": True,
            "sample_solvers": [
                ("Auto (Continuation = Enhanced HF)", "auto"),
                ("Default", ""),
                ("Enhanced HF", "enhance_hf"),
                ("Distill", "distill"),
            ],
            "text_encoder_URLs": [
                build_hf_url(encoder_repo, encoder_dir, file_name)
                for file_name in encoder_files
            ],
            "text_encoder_folder": encoder_dir,
        }

        if base_model_type == "longcat_video":
            settings["fps"] = 15
            settings["profiles_dir"] = ["longcat_video"]
            return settings

        if base_model_type not in LONGCAT_AVATAR_TYPES:
            # Unknown type: only the shared entries apply.
            return settings

        distill_variant = base_model_type == "longcat_avatar_v1_5"
        if distill_variant:
            fps, audio_stride = 25, 1
        else:
            fps, audio_stride = 16, 2

        settings.update(
            {
                "fps": fps,
                "profiles_dir": [base_model_type],
                "audio_guide_label": "Voice to follow",
                "audio_guide2_label": "Voice to follow #2",
                "audio_guidance": True,
                "any_audio_prompt": True,
                "audio_prompt_choices": True,
                "audio_prompt_type_sources": {
                    "selection": ["A", "XA", "CAB", "PAB"],
                    "default": "A",
                    "label": "Voices",
                    "scale": 3,
                },
                "speaker_locations": True,
                "image_ref_choices": {
                    "choices": [("Anchor Reference Image", "KI")],
                    "letters_filter": "KI",
                    "visible": False,
                    "label": "Anchor Reference Image",
                },
                "reference_image_enabled": True,
                "no_background_removal": True,
                # Same value as the shared entry above; kept to mirror the
                # avatar definition exactly.
                "image_prompt_types_allowed": "TSVL",
                "audio_stride": audio_stride,
            }
        )

        if not distill_variant:
            return settings

        # Avatar v1.5 exposes only the distill solver and carries its own
        # scheduler/transformer configs and audio encoder entries.
        settings.update(
            {
                "sample_solvers": [("Distill", "distill")],
                "sample_solver": "distill",
                "distill_only": True,
                "num_distill_sample_steps": 8,
                "scheduler_config": "models/longcat/configs/longcat_avatar_v1_5_scheduler.json",
                "transformer_config": "models/longcat/configs/longcat_avatar_v1_5.json",
                "audio_encoder_type": "whisper-large-v3",
                "audio_encoder_folder": "whisper-large-v3",
                "distill_lora_URL": build_hf_url("DeepBeepMeep/LongCat", "longcat_avatar_v1_5", "dmd_lora.safetensors"),
            }
        )
        return settings

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        """Force ``video_prompt_type`` to ``"KI"`` for avatar model types.

        ``ui_defaults`` is modified in place; other model types are left
        untouched. ``settings_version`` and ``model_def`` are not read.
        """
        if base_model_type not in LONGCAT_AVATAR_TYPES:
            return
        ui_defaults["video_prompt_type"] = "KI"

    @staticmethod
    def get_rgb_factors(base_model_type):
        """Return the RGB factors that ``shared.RGB_factors`` provides for ``"wan"``.

        ``base_model_type`` is ignored: every LongCat type uses the same lookup.
        """
        from shared.RGB_factors import get_rgb_factors as lookup_rgb_factors

        return lookup_rgb_factors("wan")

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """List the auxiliary files to download for ``base_model_type``.

        Returns a list of three download descriptors, in this order: the
        umt5-xxl tokenizer files, the audio encoder (whisper-large-v3 for
        avatar v1.5, chinese-wav2vec2-base for every other type) and the
        Wan 2.1 VAE. ``computeList`` and ``model_def`` are not read.
        """
        wan_repo = "DeepBeepMeep/Wan2.1"

        tokenizer_files = {
            "repoId": wan_repo,
            "sourceFolderList": ["umt5-xxl"],
            "fileList": [["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"]],
        }

        if base_model_type == "longcat_avatar_v1_5":
            audio_encoder_files = {
                "repoId": "DeepBeepMeep/LongCat",
                "sourceFolderList": ["whisper-large-v3"],
                "fileList": [["config.json", "generation_config.json", "model.safetensors", "preprocessor_config.json"]],
            }
        else:
            audio_encoder_files = {
                "repoId": wan_repo,
                "sourceFolderList": ["chinese-wav2vec2-base"],
                "fileList": [["config.json", "preprocessor_config.json", "pytorch_model.bin", "readme.txt"]],
            }

        vae_files = {
            "repoId": wan_repo,
            "sourceFolderList": [""],
            "fileList": [["Wan2.1_VAE_bf16.safetensors"]],
        }

        return [tokenizer_files, audio_encoder_files, vae_files]

    @staticmethod
    def load_model(
        model_filename,
        model_type,
        base_model_type,
        model_def,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        **kwargs,
    ):
        """Instantiate a ``LongCatModel`` and collect its sub-modules.

        ``text_encoder_quantization``, ``submodel_no_list`` and any extra
        keyword arguments are accepted but not forwarded to the model.

        Returns:
            ``(model, pipe)`` where ``pipe`` maps ``"transformer"``, ``"vae"``
            and ``"text_encoder"`` to the model's components. When the model
            has an audio encoder it is added under ``"whisper"`` (if the
            model's ``audio_encoder_name`` is ``"whisper-large-v3"``) or
            ``"wav2vec"`` otherwise.
        """
        # Imported lazily, as in the rest of this handler's model-specific code.
        from .longcat_main import LongCatModel

        model = LongCatModel(
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=text_encoder_filename,
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )

        pipe = {
            "transformer": model.transformer,
            "vae": model.vae,
            "text_encoder": model.text_encoder.model,
        }

        audio_encoder = model.audio_encoder
        if audio_encoder is None:
            return model, pipe

        encoder_name = getattr(model, "audio_encoder_name", "")
        if encoder_name == "whisper-large-v3":
            pipe["whisper"] = audio_encoder
        else:
            pipe["wav2vec"] = audio_encoder
        return model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Write the family's default generation settings into ``ui_defaults``.

        ``ui_defaults`` is modified in place and nothing is returned. Avatar
        v1.5 receives its distill defaults and the function stops there. All
        other types receive the standard defaults, a 93-frame length for the
        video and avatar types, and ``sample_solver`` is set to ``"auto"``
        when it is missing or equal to the empty string. ``model_def`` is not
        read.
        """
        window = {"sliding_window_overlap": 13, "sliding_window_size": 93}

        if base_model_type == "longcat_avatar_v1_5":
            ui_defaults.update(
                {
                    "guidance_scale": 1.0,
                    "num_inference_steps": 8,
                    "audio_guidance_scale": 1.0,
                    **window,
                    "video_length": 93,
                    "video_prompt_type": "",
                    "sample_solver": "distill",
                }
            )
            return

        ui_defaults.update(
            {
                "guidance_scale": 4.0,
                "num_inference_steps": 50,
                "audio_guidance_scale": 4.0,
                **window,
            }
        )

        if base_model_type == "longcat_video":
            ui_defaults["video_length"] = 93
        if base_model_type in LONGCAT_AVATAR_TYPES:
            ui_defaults["video_length"] = 93
            ui_defaults["video_prompt_type"] = ""

        current_solver = ui_defaults.get("sample_solver", "")
        if current_solver == "":
            ui_defaults["sample_solver"] = "auto"