# OmniCustom/utils/model_loading_utils.py
# Initial commit (sunnyday1307, 0f4bcb8)
import torch
import os
import json
from safetensors.torch import load_file
from modules.fusion import FusionModel
from modules.t5 import T5EncoderModel
from modules.vae2_2 import Wan2_2_VAE
from modules.mmaudio.features_utils import FeaturesUtils
def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Construct the Wan2.2 VAE from its checkpoint under ``ckpt_dir``.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: Device index the VAE is placed on (default 0).

    Returns:
        A ``Wan2_2_VAE`` instance.
    """
    weights_path = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
    return Wan2_2_VAE(device=rank, vae_pth=weights_path)
def init_mmaudio_vae(ckpt_dir, rank=0):
    """Build the MMAudio 16kHz feature/VAE utilities and move them to ``rank``.

    Args:
        ckpt_dir: Root checkpoint directory containing ``MMAudio/ext_weights/``.
        rank: Device index the module is moved to (default 0).

    Returns:
        A ``FeaturesUtils`` instance on the requested device.
    """
    vae = FeaturesUtils(
        mode='16k',
        need_vae_encoder=True,
        tod_vae_ckpt=os.path.join(ckpt_dir, "MMAudio/ext_weights/v1-16.pth"),
        bigvgan_vocoder_ckpt=os.path.join(ckpt_dir, "MMAudio/ext_weights/best_netG.pt"),
    )
    return vae.to(rank)
def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Create the fusion (video+audio DiT) score model from the JSON configs.

    Args:
        rank: Process rank; only rank 0 prints the parameter count.
        meta_init: When True, build the model on the ``meta`` device so no
            real weight memory is allocated (weights are loaded later).

    Returns:
        Tuple of ``(fusion_model, video_config, audio_config)`` where the
        configs are the parsed JSON dicts.
    """
    video_cfg_path = "configs/model/dit/video.json"
    audio_cfg_path = "configs/model/dit/audio.json"
    assert os.path.exists(video_cfg_path), f"{video_cfg_path} does not exist"
    assert os.path.exists(audio_cfg_path), f"{audio_cfg_path} does not exist"

    with open(video_cfg_path) as fh:
        video_config = json.load(fh)
    with open(audio_cfg_path) as fh:
        audio_config = json.load(fh)

    if meta_init:
        # Meta tensors: structure only, no storage allocated.
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    if rank == 0:
        total_params = sum(p.numel() for p in fusion_model.parameters())
        print(
            f"Score model (Fusion) all parameters:{total_params}"
        )

    return fusion_model, video_config, audio_config
def init_text_model(ckpt_dir, rank, cpu_offload=False):
    """Construct the UMT5-XXL text encoder from the Wan2.2 checkpoint dir.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B/``.
        rank: Device index for the encoder.
        cpu_offload: When True, the encoder offloads weights to CPU.

    Returns:
        A ``T5EncoderModel`` instance (bfloat16, 512-token context).
    """
    base = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    encoder = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=os.path.join(base, "models_t5_umt5-xxl-enc-bf16.pth"),
        tokenizer_path=os.path.join(base, "google/umt5-xxl"),
        cpu_offload=cpu_offload,
        shard_fn=None)
    return encoder
def load_fusion_checkpoint(model, checkpoint_path, from_meta=False, strict=False):
if checkpoint_path and os.path.exists(checkpoint_path):
if checkpoint_path.endswith(".safetensors"):
df = load_file(checkpoint_path, device="cpu")
elif checkpoint_path.endswith(".pt"):
try:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
df = df['module'] if 'module' in df else df
except Exception as e:
df = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
df = df['app']['model']
else:
raise RuntimeError("We only support .safetensors and .pt checkpoints")
missing, unexpected = model.load_state_dict(df, strict=strict, assign=from_meta)
#print(f"Missing Keys: [{missing}]")
#print(f"Unexpected Keys: [{unexpected}]")
del df
import gc
gc.collect()
print(f"Successfully loaded fusion checkpoint from {checkpoint_path}")
else:
raise RuntimeError("{checkpoint=} does not exists'")
def _collect_ip_projection(component, attr_prefix, label, params, strict):
    """Register the ip_projection MLP's Linear layers ("0" and "2") of one
    model component into ``params`` keyed by their checkpoint names."""
    if not hasattr(component, "ip_projection"):
        return
    print(f"[{label}] Loading IP_PROJECTION")
    projection = getattr(component, "ip_projection")
    for sub in ["0", "2"]:
        weight_key = f"{attr_prefix}.ip_projection.{sub}.weight"
        bias_key = f"{attr_prefix}.ip_projection.{sub}.bias"
        if hasattr(projection, sub):
            layer = getattr(projection, sub)
            params[weight_key] = layer.weight
            params[bias_key] = layer.bias
        elif strict:
            # BUG FIX: original raised with an undefined name `key` here
            # (NameError); report the actually-missing module instead.
            raise KeyError(f"Missing module: {attr_prefix}.ip_projection.{sub}")


# LoRA adapter attribute names probed on every self-attention module.
_LORA_NAMES = ["q_loras", "k_loras", "v_loras", "o_loras",
               "s_q_loras", "s_k_loras", "s_v_loras", "s_o_loras"]


def _collect_block_params(blocks, attr_prefix, label, params, strict):
    """Register per-block LoRA down/up weights and the optional ip_embedding
    layer of each self-attention module into ``params``."""
    print(f"[{label}] Loading LoRAs & IP_EMBEDDING Layer")
    for i, block in enumerate(blocks):
        prefix = f"{attr_prefix}.blocks.{i}.self_attn."
        attn = block.self_attn
        for name in _LORA_NAMES:
            if not hasattr(attn, name):
                continue
            lora = getattr(attn, name)
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if hasattr(lora, sub):
                    params[key] = getattr(lora, sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")
        if hasattr(attn, "ip_embedding"):
            embedding = getattr(attn, "ip_embedding")
            params[f"{prefix}ip_embedding.weight"] = embedding.weight
            params[f"{prefix}ip_embedding.bias"] = embedding.bias


def load_fusion_lora(fusion, ckpt_path, from_meta=False, strict=True):
    """Copy LoRA / IP-adapter weights from a checkpoint into ``fusion``.

    Walks ``fusion.audio_model`` and ``fusion.video_model`` collecting the
    live parameter tensors of ip_projection, per-block LoRA adapters, and
    ip_embedding layers, then copies matching entries of the checkpoint's
    state dict into them in place.

    Args:
        fusion: Fusion model with ``audio_model`` and ``video_model`` parts.
        ckpt_path: Path to a ``.safetensors`` or ``.pt`` LoRA checkpoint.
        from_meta: Unused; kept for interface compatibility with
            ``load_fusion_checkpoint``.
        strict: When True, raise on missing modules, unexpected checkpoint
            keys (except ``pipe.speaker_extractor``), or shape mismatches.

    Raises:
        RuntimeError: If the path is missing or has an unsupported extension.
        KeyError / ValueError: Under ``strict`` as described above.
    """
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)
    if not (ckpt_path and os.path.exists(ckpt_path)):
        # BUG FIX: original message was the literal "{checkpoint=} does not
        # exists'" (missing f-prefix, wrong variable name, typo).
        raise RuntimeError(f"checkpoint {ckpt_path!r} does not exist")

    if ckpt_path.endswith(".safetensors"):
        df = load_file(ckpt_path, device="cpu")
    elif ckpt_path.endswith(".pt"):
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            df = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            # DeepSpeed-style checkpoints nest weights under 'module'.
            df = df['module'] if 'module' in df else df
        except Exception:
            # Fallback layout: torch-distributed app state under ['app']['model'].
            df = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            df = df['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")
    state_dict = df.get("state_dict", df)

    # Map checkpoint key -> live parameter tensor in the fusion model.
    params = {}
    _collect_ip_projection(fusion.audio_model, "audio_model", "Audio Model", params, strict)
    _collect_block_params(fusion.audio_model.blocks, "audio_model", "Audio Model", params, strict)
    _collect_ip_projection(fusion.video_model, "video_model", "Video Model", params, strict)
    _collect_block_params(fusion.video_model.blocks, "video_model", "Video Model", params, strict)

    for k, param in state_dict.items():
        if k in params:
            if params[k].shape != param.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {k} | {params[k].shape} vs {param.shape}"
                    )
                continue  # non-strict: skip mismatched entries
            params[k].data.copy_(param)
        elif strict and "pipe.speaker_extractor" not in k:
            raise KeyError(f"Unexpected key in ckpt: {k}")
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)