| | import torch |
| | import os |
| | import json |
| | from safetensors.torch import load_file |
| |
|
| | from modules.fusion import FusionModel |
| | from modules.t5 import T5EncoderModel |
| | from modules.vae2_2 import Wan2_2_VAE |
| | from modules.mmaudio.features_utils import FeaturesUtils |
| | |
def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Construct the Wan2.2 VAE from weights under *ckpt_dir*.

    Args:
        ckpt_dir: root checkpoint directory containing ``Wan2.2-TI2V-5B``.
        rank: device index passed through to the VAE constructor.

    Returns:
        A ``Wan2_2_VAE`` instance backed by the on-disk weights.
    """
    weights_path = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
    return Wan2_2_VAE(device=rank, vae_pth=weights_path)
| |
|
def init_mmaudio_vae(ckpt_dir, rank=0):
    """Construct the MMAudio feature/VAE utilities and move them to *rank*.

    Args:
        ckpt_dir: root checkpoint directory containing ``MMAudio/ext_weights``.
        rank: target device for the returned module.

    Returns:
        A ``FeaturesUtils`` instance (16 kHz mode, encoder enabled) on *rank*.
    """
    config = dict(
        mode='16k',
        need_vae_encoder=True,
        tod_vae_ckpt=os.path.join(ckpt_dir, "MMAudio/ext_weights/v1-16.pth"),
        bigvgan_vocoder_ckpt=os.path.join(ckpt_dir, "MMAudio/ext_weights/best_netG.pt"),
    )
    return FeaturesUtils(**config).to(rank)
| |
|
def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Build the fusion score model from the repo-local DiT config files.

    Args:
        rank: only rank 0 prints the parameter count.
        meta_init: if True, construct on the ``meta`` device (no real
            parameter storage is allocated; weights must be assigned later).

    Returns:
        Tuple of (fusion model, video config dict, audio config dict).
    """
    video_path = "configs/model/dit/video.json"
    audio_path = "configs/model/dit/audio.json"
    assert os.path.exists(video_path), f"{video_path} does not exist"
    assert os.path.exists(audio_path), f"{audio_path} does not exist"

    with open(video_path) as f:
        video_config = json.load(f)
    with open(audio_path) as f:
        audio_config = json.load(f)

    if meta_init:
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    params_all = sum(p.numel() for p in fusion_model.parameters())
    if rank == 0:
        print(f"Score model (Fusion) all parameters:{params_all}")

    return fusion_model, video_config, audio_config
| |
|
def init_text_model(ckpt_dir, rank, cpu_offload=False):
    """Construct the UMT5-XXL text encoder from weights under *ckpt_dir*.

    Args:
        ckpt_dir: root checkpoint directory containing ``Wan2.2-TI2V-5B``.
        rank: device index for the encoder.
        cpu_offload: forwarded to ``T5EncoderModel`` to keep weights on CPU
            between uses.

    Returns:
        A ``T5EncoderModel`` in bfloat16 with a 512-token context.
    """
    wan_root = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    return T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=os.path.join(wan_root, "models_t5_umt5-xxl-enc-bf16.pth"),
        tokenizer_path=os.path.join(wan_root, "google/umt5-xxl"),
        cpu_offload=cpu_offload,
        shard_fn=None,
    )
| |
|
| |
|
| | def load_fusion_checkpoint(model, checkpoint_path, from_meta=False, strict=False): |
| | if checkpoint_path and os.path.exists(checkpoint_path): |
| | if checkpoint_path.endswith(".safetensors"): |
| | df = load_file(checkpoint_path, device="cpu") |
| | elif checkpoint_path.endswith(".pt"): |
| | try: |
| | df = torch.load(checkpoint_path, map_location="cpu", weights_only=False) |
| | df = df['module'] if 'module' in df else df |
| | except Exception as e: |
| | df = torch.load(checkpoint_path, map_location="cpu", weights_only=True) |
| | df = df['app']['model'] |
| | else: |
| | raise RuntimeError("We only support .safetensors and .pt checkpoints") |
| |
|
| | missing, unexpected = model.load_state_dict(df, strict=strict, assign=from_meta) |
| | |
| | |
| | |
| | del df |
| | import gc |
| | gc.collect() |
| | print(f"Successfully loaded fusion checkpoint from {checkpoint_path}") |
| | else: |
| | raise RuntimeError("{checkpoint=} does not exists'") |
| | |
| | |
def _collect_lora_params(submodel, key_prefix, label, strict):
    """Map checkpoint key -> live parameter tensor for one sub-model.

    Covers the optional ``ip_projection`` head, every self-attention LoRA
    (q/k/v/o and their ``s_*`` variants), and each block's optional
    ``ip_embedding`` layer.

    Args:
        submodel: ``fusion.audio_model`` or ``fusion.video_model``.
        key_prefix: checkpoint key prefix ("audio_model" / "video_model").
        label: human-readable tag used only in log output.
        strict: if True, raise when an expected submodule is missing.

    Raises:
        KeyError: in strict mode, when an expected submodule is absent.
    """
    params = {}

    if hasattr(submodel, "ip_projection"):
        print(f"[{label}] Loading IP_PROJECTION")
        projection = submodel.ip_projection
        for sub in ["0", "2"]:
            if hasattr(projection, sub):
                layer = getattr(projection, sub)
                params[f"{key_prefix}.ip_projection.{sub}.weight"] = layer.weight
                params[f"{key_prefix}.ip_projection.{sub}.bias"] = layer.bias
            elif strict:
                # BUGFIX: original raised KeyError(f"Missing module: {key}")
                # with `key` undefined in this scope (NameError when hit).
                raise KeyError(f"Missing module: {key_prefix}.ip_projection.{sub}")

    # NOTE(review): the original file's indentation is mangled; LoRA loading
    # is done unconditionally here (not gated on ip_projection existing).
    print(f"[{label}] Loading LoRAs & IP_EMBEDDING Layer")
    lora_names = ["q_loras", "k_loras", "v_loras", "o_loras",
                  "s_q_loras", "s_k_loras", "s_v_loras", "s_o_loras"]
    for i, block in enumerate(submodel.blocks):
        prefix = f"{key_prefix}.blocks.{i}.self_attn."
        attn = block.self_attn
        for name in lora_names:
            if not hasattr(attn, name):
                continue  # this attention has no such LoRA pair; skip silently
            lora = getattr(attn, name)
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if hasattr(lora, sub):
                    params[key] = getattr(lora, sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")

        if hasattr(attn, "ip_embedding"):
            embedding = attn.ip_embedding
            params[f"{prefix}ip_embedding.weight"] = embedding.weight
            params[f"{prefix}ip_embedding.bias"] = embedding.bias

    return params


def load_fusion_lora(fusion, ckpt_path, from_meta=False, strict=True):
    """Load LoRA / IP-adapter weights into *fusion* in place.

    Args:
        fusion: fusion model exposing ``.audio_model`` and ``.video_model``.
        ckpt_path: path to a ``.safetensors`` or ``.pt`` LoRA checkpoint.
        from_meta: unused; kept for signature compatibility.
        strict: if True, raise on missing modules, shape mismatches and
            unexpected checkpoint keys (``pipe.speaker_extractor`` keys are
            always tolerated).

    Raises:
        RuntimeError: if the path does not exist or has an unsupported
            extension.
        KeyError: in strict mode, on missing modules or unexpected keys.
        ValueError: in strict mode, on a shape mismatch.
    """
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)
    if not (ckpt_path and os.path.exists(ckpt_path)):
        # BUGFIX: original raised RuntimeError("{checkpoint=} does not exists'")
        # — a non-f-string referencing a nonexistent name.
        raise RuntimeError(f"{ckpt_path=} does not exist")

    if ckpt_path.endswith(".safetensors"):
        ckpt = load_file(ckpt_path, device="cpu")
    elif ckpt_path.endswith(".pt"):
        try:
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            ckpt = ckpt['module'] if 'module' in ckpt else ckpt
        except Exception:
            # Fallback for checkpoints that only load with weights_only=True
            # (same layout as in load_fusion_checkpoint).
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            ckpt = ckpt['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")

    state_dict = ckpt.get("state_dict", ckpt)

    # Checkpoint key -> live parameter tensor that should receive the value.
    target_params = {}
    target_params.update(
        _collect_lora_params(fusion.audio_model, "audio_model", "Audio Model", strict))
    target_params.update(
        _collect_lora_params(fusion.video_model, "video_model", "Video Model", strict))

    for key, value in state_dict.items():
        if key not in target_params:
            if strict and "pipe.speaker_extractor" not in key:
                raise KeyError(f"Unexpected key in ckpt: {key}")
            continue
        target = target_params[key]
        if target.shape != value.shape:
            if strict:
                raise ValueError(
                    f"Shape mismatch: {key} | {target.shape} vs {value.shape}"
                )
            continue  # non-strict: skip mismatched tensors
        target.data.copy_(value)

    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)