File size: 8,791 Bytes
0f4bcb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import gc
import json
import os

import torch
from safetensors.torch import load_file

from modules.fusion import FusionModel
from modules.mmaudio.features_utils import FeaturesUtils
from modules.t5 import T5EncoderModel
from modules.vae2_2 import Wan2_2_VAE
    
def init_wan_vae_2_2(ckpt_dir, rank=0):
    """Build the Wan 2.2 video VAE from weights under *ckpt_dir*.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B``.
        rank: Device index handed to the VAE constructor as ``device``.

    Returns:
        A constructed ``Wan2_2_VAE`` instance.
    """
    weights_path = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
    return Wan2_2_VAE(device=rank, vae_pth=weights_path)

def init_mmaudio_vae(ckpt_dir, rank=0):
    """Build the MMAudio feature/VAE utility (16 kHz mode) and move it to *rank*.

    Args:
        ckpt_dir: Root checkpoint directory containing ``MMAudio/ext_weights``.
        rank: Target device for the returned module.

    Returns:
        A ``FeaturesUtils`` instance (encoder included) on device *rank*.
    """
    kwargs = {
        'mode': '16k',
        'need_vae_encoder': True,
        'tod_vae_ckpt': os.path.join(ckpt_dir, "MMAudio/ext_weights/v1-16.pth"),
        'bigvgan_vocoder_ckpt': os.path.join(ckpt_dir, "MMAudio/ext_weights/best_netG.pt"),
    }
    return FeaturesUtils(**kwargs).to(rank)

def init_fusion_score_model_ovi(rank: int = 0, meta_init=False):
    """Construct the fusion (video + audio) score model from the DiT JSON configs.

    Args:
        rank: Process rank; the parameter count is printed only on rank 0.
        meta_init: If True, build the model on the ``meta`` device (shape/dtype
            only, no real allocation) so weights can later be attached via
            ``load_state_dict(..., assign=True)``.

    Returns:
        Tuple of ``(fusion_model, video_config, audio_config)`` where the
        configs are the parsed JSON dicts.
    """
    # Config paths are relative to the current working directory.
    # NOTE: these asserts are stripped under ``python -O``.
    video_config_path = "configs/model/dit/video.json"
    audio_config_path = "configs/model/dit/audio.json"
    assert os.path.exists(video_config_path), f"{video_config_path} does not exist"
    assert os.path.exists(audio_config_path), f"{audio_config_path} does not exist"

    # Separate names for paths and parsed dicts (the original rebound the
    # path variables to the parsed configs, which obscured the return value).
    with open(video_config_path) as f:
        video_config = json.load(f)
    with open(audio_config_path) as f:
        audio_config = json.load(f)

    if meta_init:
        with torch.device("meta"):
            fusion_model = FusionModel(video_config, audio_config)
    else:
        fusion_model = FusionModel(video_config, audio_config)

    params_all = sum(p.numel() for p in fusion_model.parameters())
    if rank == 0:
        print(
            f"Score model (Fusion) all parameters:{params_all}"
        )

    return fusion_model, video_config, audio_config

def init_text_model(ckpt_dir, rank, cpu_offload=False):
    """Instantiate the UMT5-XXL T5 text encoder from the Wan 2.2 checkpoint dir.

    Args:
        ckpt_dir: Root checkpoint directory containing ``Wan2.2-TI2V-5B``.
        rank: Device the encoder is placed on.
        cpu_offload: Forwarded to ``T5EncoderModel`` to enable CPU offloading.

    Returns:
        A ``T5EncoderModel`` configured for bf16 with a 512-token text length.
    """
    wan_root = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
    return T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=rank,
        checkpoint_path=os.path.join(wan_root, "models_t5_umt5-xxl-enc-bf16.pth"),
        tokenizer_path=os.path.join(wan_root, "google/umt5-xxl"),
        cpu_offload=cpu_offload,
        shard_fn=None,
    )


def load_fusion_checkpoint(model, checkpoint_path, from_meta=False, strict=False):
    """Load fusion-model weights from a ``.safetensors`` or ``.pt`` checkpoint.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path to the checkpoint file.
        from_meta: Passed as ``assign=`` to ``load_state_dict``; required when
            *model* was built on the ``meta`` device.
        strict: Forwarded to ``load_state_dict``.

    Raises:
        RuntimeError: If the path is missing/empty or has an unsupported
            extension.
    """
    # Original raised a literal "{checkpoint=} does not exists'" (missing f-prefix
    # and a nonexistent variable name); fixed to a real f-string.
    if not (checkpoint_path and os.path.exists(checkpoint_path)):
        raise RuntimeError(f"checkpoint_path={checkpoint_path!r} does not exist")

    if checkpoint_path.endswith(".safetensors"):
        state = load_file(checkpoint_path, device="cpu")
    elif checkpoint_path.endswith(".pt"):
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            state = state['module'] if 'module' in state else state
        except Exception:
            # Fall back to the tensor-only layout some trainers emit.
            state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
            state = state['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")

    missing, unexpected = model.load_state_dict(state, strict=strict, assign=from_meta)
    #print(f"Missing Keys: [{missing}]")
    #print(f"Unexpected Keys: [{unexpected}]")

    # Drop the (potentially large) CPU state dict before continuing.
    del state
    gc.collect()
    print(f"Successfully loaded fusion checkpoint from {checkpoint_path}")
    
    
def load_fusion_lora(fusion, ckpt_path, from_meta=False, strict=True):
    """Copy LoRA / IP-adapter weights from a checkpoint into *fusion* in place.

    Collects live references to every LoRA (``*_loras.{down,up}.weight``),
    ``ip_embedding`` and ``ip_projection`` parameter of both sub-models, then
    copies matching tensors from the checkpoint into them.

    Args:
        fusion: Fusion model exposing ``audio_model`` and ``video_model``, each
            with ``blocks[i].self_attn`` submodules.
        ckpt_path: ``.safetensors`` or ``.pt`` checkpoint path.
        from_meta: Unused; kept for signature compatibility with
            ``load_fusion_checkpoint``.
        strict: If True, raise on missing target modules, shape mismatches, or
            unexpected checkpoint keys.

    Raises:
        RuntimeError: If the path is missing or has an unsupported extension.
        KeyError: Strict mode, missing module or unexpected checkpoint key.
        ValueError: Strict mode, shape mismatch.
    """
    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)
    # Original raised a literal "{checkpoint=} does not exists'" here; fixed.
    if not (ckpt_path and os.path.exists(ckpt_path)):
        raise RuntimeError(f"ckpt_path={ckpt_path!r} does not exist")

    if ckpt_path.endswith(".safetensors"):
        df = load_file(ckpt_path, device="cpu")
    elif ckpt_path.endswith(".pt"):
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            df = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            df = df['module'] if 'module' in df else df
        except Exception:
            df = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            df = df['app']['model']
    else:
        raise RuntimeError("We only support .safetensors and .pt checkpoints")

    state_dict = df.get("state_dict", df)

    # checkpoint key -> live parameter tensor inside `fusion`
    params = {}
    _collect_adapter_params(fusion.audio_model, "audio_model", "Audio Model", params, strict)
    _collect_adapter_params(fusion.video_model, "video_model", "Video Model", params, strict)

    for k, tensor in state_dict.items():
        if k in params:
            if params[k].shape != tensor.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {k} | {params[k].shape} vs {tensor.shape}"
                    )
                continue
            params[k].data.copy_(tensor)
        elif strict and "pipe.speaker_extractor" not in k:
            raise KeyError(f"Unexpected key in ckpt: {k}")

    print("=" * 45 + " Loading LoRA Weights " + "=" * 45)


def _collect_adapter_params(sub_model, prefix, label, params, strict):
    """Gather LoRA/IP-adapter parameter references from one sub-model into *params*.

    The audio and video branches of the original were byte-for-byte duplicates;
    this helper factors them. Also fixes the strict-mode ip_projection branch,
    which raised NameError (`key` was undefined there) instead of KeyError.
    """
    lora_names = ["q_loras", "k_loras", "v_loras", "o_loras",
                  "s_q_loras", "s_k_loras", "s_v_loras", "s_o_loras"]

    if hasattr(sub_model, "ip_projection"):
        print(f"[{label}] Loading IP_PROJECTION")
        projection = sub_model.ip_projection
        # ip_projection is indexed like an nn.Sequential: layers "0" and "2".
        for sub in ["0", "2"]:
            if hasattr(projection, sub):
                layer = getattr(projection, sub)
                params[f"{prefix}.ip_projection.{sub}.weight"] = layer.weight
                params[f"{prefix}.ip_projection.{sub}.bias"] = layer.bias
            elif strict:
                raise KeyError(f"Missing module: {prefix}.ip_projection.{sub}")

    print(f"[{label}] Loading LoRAs & IP_EMBEDDING Layer")
    for i, block in enumerate(sub_model.blocks):
        attn_prefix = f"{prefix}.blocks.{i}.self_attn."
        attn = block.self_attn
        for name in lora_names:
            if not hasattr(attn, name):
                continue
            lora = getattr(attn, name)
            for sub in ["down", "up"]:
                key = f"{attn_prefix}{name}.{sub}.weight"
                if hasattr(lora, sub):
                    params[key] = getattr(lora, sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")
        # Optional per-block IP embedding layer (silently skipped if absent,
        # matching the original behavior even in strict mode).
        if hasattr(attn, "ip_embedding"):
            embedding = attn.ip_embedding
            params[f"{attn_prefix}ip_embedding.weight"] = embedding.weight
            params[f"{attn_prefix}ip_embedding.bias"] = embedding.bias