import os

import torch

from shared.utils.hf import build_hf_url


class family_handler:
    """Static hooks describing the LongCat model family.

    Covers the ``longcat_video`` and ``longcat_avatar`` base model types:
    LoRA CLI options and directories, model-definition extras, files to
    download, model loading and UI default settings.
    """

    @staticmethod
    def query_supported_types():
        """Return the base model types served by this handler."""
        return ["longcat_video", "longcat_avatar"]

    @staticmethod
    def query_family_maps():
        """Return a pair of empty dicts: this family declares no maps."""
        return {}, {}

    @staticmethod
    def query_model_family():
        """Return the family identifier ``"longcat"``."""
        return "longcat"

    @staticmethod
    def query_family_infos():
        """Return ``{family_id: (60, display_name)}`` for LongCat.

        NOTE(review): the meaning of ``60`` is defined by the caller
        (presumably an ordering weight) -- confirm.
        """
        return {"longcat": (60, "LongCat")}

    @staticmethod
    def register_lora_cli_args(parser):
        """Register one ``--lora-dir-*`` string option per LongCat variant.

        Each option defaults to a sub-folder of ``loras``; the Video option
        is registered before the Avatar one.
        """
        lora_options = (
            (
                "--lora-dir-longcat",
                "longcat",
                "Path to a directory that contains LongCat Video LoRAs",
            ),
            (
                "--lora-dir-longcat-avatar",
                "longcat_avatar",
                "Path to a directory that contains LongCat Avatar LoRAs",
            ),
        )
        for flag, subfolder, help_text in lora_options:
            parser.add_argument(
                flag,
                type=str,
                default=os.path.join("loras", subfolder),
                help=help_text,
            )

    @staticmethod
    def get_lora_dir(base_model_type, args):
        """Return the LoRA directory parsed into *args* for *base_model_type*.

        ``longcat_avatar`` has a dedicated directory; every other type falls
        back to the LongCat Video directory.
        """
        is_avatar = base_model_type == "longcat_avatar"
        return args.lora_dir_longcat_avatar if is_avatar else args.lora_dir_longcat

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return the extra model-definition entries for *base_model_type*.

        The entries shared by both variants are always present: frame
        constraints, sliding window, the solver list and the UMT5-XXL
        text-encoder files. ``longcat_video`` and ``longcat_avatar`` then add
        their own fps and profile folder. The avatar variant also adds its
        audio-guide and anchor-reference-image options. *model_def* is not
        read here.
        """
        encoder_folder = "umt5-xxl"
        wan_repo = "DeepBeepMeep/Wan2.1"
        definition = {
            "frames_minimum": 5,
            "frames_steps": 4,
            "sliding_window": True,
            "guidance_max_phases": 1,
            "image_prompt_types_allowed": "TSVL",
            "video_continuation": True,
            "sample_solvers": [
                ("Auto (Continuation = Enhanced HF)", "auto"),
                ("Default", ""),
                ("Enhanced HF", "enhance_hf"),
                ("Distill", "distill"),
            ],
            # bf16 weights first, then the int8 quantized variant.
            "text_encoder_URLs": [
                build_hf_url(wan_repo, encoder_folder, "models_t5_umt5-xxl-enc-bf16.safetensors"),
                build_hf_url(wan_repo, encoder_folder, "models_t5_umt5-xxl-enc-quanto_int8.safetensors"),
            ],
            "text_encoder_folder": encoder_folder,
        }

        if base_model_type == "longcat_video":
            definition["fps"] = 15
            definition["profiles_dir"] = ["longcat_video"]
        elif base_model_type == "longcat_avatar":
            # ``image_prompt_types_allowed`` keeps the shared "TSVL" value.
            avatar_extras = {
                "fps": 16,
                "profiles_dir": [base_model_type],
                "audio_guide_label": "Voice to follow",
                "audio_guide2_label": "Voice to follow #2",
                "audio_guidance": True,
                "any_audio_prompt": True,
                "audio_prompt_choices": True,
                "image_ref_choices": {
                    "choices": [("None", ""), ("Anchor Reference Image", "KI")],
                    "letters_filter": "KI",
                    "visible": True,
                    "label": "Anchor Reference Image",
                },
                "reference_image_enabled": True,
                "no_background_removal": True,
            }
            definition.update(avatar_extras)

        return definition

    @staticmethod
    def get_rgb_factors(base_model_type):
        """Delegate to ``shared.RGB_factors`` using the ``"wan"`` key.

        *base_model_type* is ignored: every LongCat type uses the same factors.
        """
        from shared.RGB_factors import get_rgb_factors as lookup_rgb_factors

        return lookup_rgb_factors("wan")

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Describe the files to fetch from the ``DeepBeepMeep/Wan2.1`` repo.

        Returns two download entries:
        - the UMT5-XXL tokenizer files together with the
          ``chinese-wav2vec2-base`` files;
        - the bf16 Wan2.1 VAE stored at the repo root.

        The same list is returned for every base model type. *computeList*
        and *model_def* are not used.
        """
        wan_repo = "DeepBeepMeep/Wan2.1"
        tokenizer_and_audio = {
            "repoId": wan_repo,
            "sourceFolderList": ["umt5-xxl", "chinese-wav2vec2-base"],
            # Presumably fileList[i] belongs to sourceFolderList[i] -- verify
            # against the downloader.
            "fileList": [
                ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"],
                ["config.json", "preprocessor_config.json", "pytorch_model.bin", "readme.txt"],
            ],
        }
        vae = {
            "repoId": wan_repo,
            "sourceFolderList": [""],
            "fileList": [["Wan2.1_VAE_bf16.safetensors"]],
        }
        return [tokenizer_and_audio, vae]

    @staticmethod
    def load_model(
        model_filename,
        model_type,
        base_model_type,
        model_def,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        **kwargs,
    ):
        """Build a ``LongCatModel`` from ``ckpts`` and expose its sub-modules.

        Returns ``(model, pipe)``. *pipe* maps these keys to the loaded modules:
        - ``"transformer"``
        - ``"vae"``
        - ``"text_encoder"`` (the encoder's ``.model``)
        - ``"wav2vec"``, only when the model has an audio encoder.

        *text_encoder_quantization*, *submodel_no_list* and any extra keyword
        arguments are accepted for interface compatibility but not forwarded.
        """
        from .longcat_main import LongCatModel

        model = LongCatModel(
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=text_encoder_filename,
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )

        pipe = {
            "transformer": model.transformer,
            "vae": model.vae,
            "text_encoder": model.text_encoder.model,
        }
        # The audio encoder is optional; only expose it when it was loaded.
        if model.audio_encoder is not None:
            pipe["wav2vec"] = model.audio_encoder
        return model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Write LongCat UI defaults into *ui_defaults* in place.

        Guidance, step-count and sliding-window values are always
        overwritten. Both LongCat types get a ``video_length`` of 93, and the
        avatar type also clears ``video_prompt_type``. ``sample_solver`` is
        set to ``"auto"`` only when it is absent or the empty string.
        """
        overrides = {
            "guidance_scale": 4.0,
            "num_inference_steps": 50,
            "audio_guidance_scale": 4.0,
            "sliding_window_overlap": 13,
            "sliding_window_size": 93,
        }
        if base_model_type == "longcat_video":
            overrides["video_length"] = 93
        elif base_model_type == "longcat_avatar":
            overrides["video_length"] = 93
            overrides["video_prompt_type"] = ""
        ui_defaults.update(overrides)

        # Compare with "" explicitly: a stored None must stay untouched.
        if ui_defaults.get("sample_solver", "") == "":
            ui_defaults["sample_solver"] = "auto"