import torch
import os
import json
from safetensors.torch import load_file
from modules.fusion import FusionModel
from modules.t5 import T5EncoderModel
from modules.vae2_2 import Wan2_2_VAE
from modules.mmaudio.features_utils import FeaturesUtils
def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Construct the Wan 2.2 VAE, loading its weights from under ``ckpt_dir``.

    Args:
        ckpt_dir: root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: device index the VAE is placed on.

    Returns:
        A ``Wan2_2_VAE`` instance.
    """
    weights_path = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
    return Wan2_2_VAE(device=rank, vae_pth=weights_path)
def init_mmaudio_vae(ckpt_dir, rank=0):
    """Build the MMAudio feature/VAE utilities (16 kHz mode) and move them to ``rank``.

    Args:
        ckpt_dir: root checkpoint directory containing ``MMAudio/ext_weights/``.
        rank: device index the module is moved to.

    Returns:
        A ``FeaturesUtils`` instance on the target device.
    """
    config = {
        'mode': '16k',
        'need_vae_encoder': True,
        'tod_vae_ckpt': os.path.join(ckpt_dir, "MMAudio/ext_weights/v1-16.pth"),
        'bigvgan_vocoder_ckpt': os.path.join(ckpt_dir, "MMAudio/ext_weights/best_netG.pt"),
    }
    return FeaturesUtils(**config).to(rank)
def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Build the fusion (audio + video DiT) score model from the JSON configs.

    Args:
        rank: process rank; only rank 0 prints the parameter count.
        meta_init: when True, construct the model on the meta device so no
            weight memory is allocated (weights loaded later with assign=True).

    Returns:
        Tuple of (fusion_model, video_config dict, audio_config dict).
    """
    video_cfg_path = "configs/model/dit/video.json"
    audio_cfg_path = "configs/model/dit/audio.json"
    assert os.path.exists(video_cfg_path), f"{video_cfg_path} does not exist"
    assert os.path.exists(audio_cfg_path), f"{audio_cfg_path} does not exist"

    with open(video_cfg_path) as fh:
        video_config = json.load(fh)
    with open(audio_cfg_path) as fh:
        audio_config = json.load(fh)

    if meta_init:
        # Meta-device construction: shapes only, no real weight allocation.
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    if rank == 0:
        params_all = sum(p.numel() for p in fusion_model.parameters())
        print(
            f"Score model (Fusion) all parameters:{params_all}"
        )
    return fusion_model, video_config, audio_config
def init_text_model(ckpt_dir, rank, cpu_offload=False):
    """Create the UMT5-XXL T5 text encoder in bfloat16 on device ``rank``.

    Args:
        ckpt_dir: root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: device index for the encoder.
        cpu_offload: forwarded to ``T5EncoderModel`` to keep weights on CPU
            between uses.

    Returns:
        A ``T5EncoderModel`` instance.
    """
    wan_root = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    return T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=os.path.join(wan_root, "models_t5_umt5-xxl-enc-bf16.pth"),
        tokenizer_path=os.path.join(wan_root, "google/umt5-xxl"),
        cpu_offload=cpu_offload,
        shard_fn=None,
    )
def load_fusion_checkpoint(model, checkpoint_path, from_meta=False, strict=False):
if checkpoint_path and os.path.exists(checkpoint_path):
if checkpoint_path.endswith(".safetensors"):
df = load_file(checkpoint_path, device="cpu")
elif checkpoint_path.endswith(".pt"):
try:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
df = df['module'] if 'module' in df else df
except Exception as e:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
df = df['app']['model']
else:
raise RuntimeError("We only support .safetensors and .pt checkpoints")
missing, unexpected = model.load_state_dict(df, strict=strict, assign=from_meta)
#print(f"Missing Keys: [{missing}]")
#print(f"Unexpected Keys: [{unexpected}]")
del df
import gc
gc.collect()
print(f"Successfully loaded fusion checkpoint from {checkpoint_path}")
else:
raise RuntimeError("{checkpoint=} does not exists'")
def load_fusion_lora(fusion, ckpt_path, from_meta=False, strict=True):
    """Load LoRA / IP-adapter weights into ``fusion``'s audio and video submodels.

    Collects the live LoRA (``*_loras.down/up``), ``ip_projection`` and
    ``ip_embedding`` parameters of both submodels, then copies matching
    tensors from the checkpoint into them in place.

    Args:
        fusion: model exposing ``audio_model`` and ``video_model``, each with
            ``blocks[i].self_attn`` attention modules.
        ckpt_path: path to a .safetensors or .pt checkpoint.
        from_meta: accepted for API symmetry with ``load_fusion_checkpoint``;
            not used here (copies are always in place).
        strict: raise on missing modules, shape mismatches, or unexpected
            checkpoint keys (keys containing "pipe.speaker_extractor" are
            always tolerated).

    Raises:
        RuntimeError: if the path does not exist or has an unsupported extension.
        KeyError / ValueError: in strict mode, on missing/unexpected keys or
            shape mismatches.
    """
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)
    if not (ckpt_path and os.path.exists(ckpt_path)):
        # BUG FIX: original raised a literal "{checkpoint=}" template string.
        raise RuntimeError(f"{ckpt_path=} does not exist")

    if ckpt_path.endswith(".safetensors"):
        df = load_file(ckpt_path, device="cpu")
    elif ckpt_path.endswith(".pt"):
        try:
            df = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            # DeepSpeed-style checkpoints nest the weights under 'module'.
            df = df['module'] if 'module' in df else df
        except Exception:
            # Fallback layout: weights nested under ['app']['model'].
            df = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            df = df['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")
    state_dict = df.get("state_dict", df)

    # Checkpoint key -> live parameter tensor, for both submodels.
    model = {}
    model.update(_collect_adapter_params(fusion.audio_model, "audio_model", strict))
    model.update(_collect_adapter_params(fusion.video_model, "video_model", strict))

    for k, param in state_dict.items():
        if k in model:
            if model[k].shape != param.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
                    )
                continue  # non-strict: skip mismatched tensors
            model[k].data.copy_(param)
        elif strict and "pipe.speaker_extractor" not in k:
            raise KeyError(f"Unexpected key in ckpt: {k}")
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)


def _collect_adapter_params(submodel, tag, strict):
    """Map checkpoint key -> live parameter for one submodel's adapter tensors.

    Args:
        submodel: ``fusion.audio_model`` or ``fusion.video_model``.
        tag: key prefix, "audio_model" or "video_model".
        strict: raise KeyError when an expected submodule is missing.

    Returns:
        Dict of checkpoint key -> parameter tensor.
    """
    label = tag.split("_")[0].capitalize()  # "Audio" / "Video"
    params = {}

    # Optional top-level IP projection MLP (layers "0" and "2").
    if hasattr(submodel, "ip_projection"):
        print(f"[{label} Model] Loading IP_PROJECTION")
        projection = submodel.ip_projection
        for sub in ["0", "2"]:
            if hasattr(projection, sub):
                layer = getattr(projection, sub)
                params[f"{tag}.ip_projection.{sub}.weight"] = layer.weight
                params[f"{tag}.ip_projection.{sub}.bias"] = layer.bias
            elif strict:
                # BUG FIX: original referenced an undefined name `key` here,
                # which raised NameError instead of the intended KeyError.
                raise KeyError(f"Missing module: {tag}.ip_projection.{sub}")

    print(f"[{label} Model] Loading LoRAs & IP_EMBEDDING Layer")
    lora_names = ["q_loras", "k_loras", "v_loras", "o_loras",
                  "s_q_loras", "s_k_loras", "s_v_loras", "s_o_loras"]
    for i, block in enumerate(submodel.blocks):
        prefix = f"{tag}.blocks.{i}.self_attn."
        attn = block.self_attn
        for name in lora_names:
            if not hasattr(attn, name):
                continue
            lora = getattr(attn, name)
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if hasattr(lora, sub):
                    params[key] = getattr(lora, sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")
        # Optional per-block IP embedding layer.
        if hasattr(attn, "ip_embedding"):
            embedding = attn.ip_embedding
            params[f"{prefix}ip_embedding.weight"] = embedding.weight
            params[f"{prefix}ip_embedding.bias"] = embedding.bias
    return params