File size: 11,737 Bytes
4b07c53
 
 
 
f0d6854
 
 
 
 
4b07c53
 
 
 
da237e1
 
b6de2a7
 
 
da237e1
 
 
b6de2a7
da237e1
 
 
 
 
 
 
4b07c53
da237e1
4b07c53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da237e1
 
 
 
 
 
 
4b07c53
 
 
 
 
 
 
 
da237e1
4b07c53
 
da237e1
4b07c53
 
 
da237e1
4b07c53
 
da237e1
f0d6854
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os
import cv2
import numpy as np
import torch
from typing import Union, List
import torch.nn.functional as F
from einops import rearrange
import shutil
import os.path as osp

from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding


def load_all_model(
    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
    vae_type="sd-vae",
    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
    device=None,
):
    """Instantiate the VAE, the UNet and the positional encoding.

    Args:
        unet_model_path: Path handed to ``UNet`` as ``model_path``.
        vae_type: Sub-directory of ``models`` handed to ``VAE`` as its
            ``model_path``.
        unet_config: Path handed to ``UNet`` as ``unet_config``.
        device: Forwarded to ``UNet`` only; ``VAE`` and
            ``PositionalEncoding`` are built without it.

    Returns:
        The tuple ``(vae, unet, pe)``.
    """
    vae_dir = os.path.join("models", vae_type)
    vae = VAE(model_path=vae_dir)

    print(f"load unet model from {unet_model_path}")
    unet = UNet(unet_config=unet_config, model_path=unet_model_path, device=device)

    # d_model is fixed at 384 here; it is not configurable through this helper.
    positional_encoding = PositionalEncoding(d_model=384)
    return vae, unet, positional_encoding

def get_file_type(video_path):
    """Classify a path by its (case-insensitive) file extension.

    Args:
        video_path: File name or path; only the extension is inspected.

    Returns:
        ``'image'``, ``'video'`` or ``'unsupported'``.
    """
    suffix = os.path.splitext(video_path)[1].lower()
    known_kinds = (
        ('image', ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')),
        ('video', ('.avi', '.mp4', '.mov', '.flv', '.mkv')),
    )
    for kind, extensions in known_kinds:
        if suffix in extensions:
            return kind
    return 'unsupported'

def get_video_fps(video_path):
    """Return the ``CAP_PROP_FPS`` value OpenCV reports for ``video_path``.

    The capture handle is released before returning.
    """
    capture = cv2.VideoCapture(video_path)
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    capture.release()
    return frame_rate

def datagen(
    whisper_chunks,
    vae_encode_latents,
    batch_size=8,
    delay_frame=0,
    device="cuda:0",
):
    """Yield paired (audio, latent) batches moved to ``device``.

    Each audio chunk ``i`` is paired with latent
    ``(i + delay_frame) % len(vae_encode_latents)``, so the latent sequence
    is reused cyclically when there are more audio chunks than latents.

    Args:
        whisper_chunks: Iterable of tensors; each batch stacks them along a
            new leading dimension.
        vae_encode_latents: Indexable, non-empty sequence of tensors; each
            batch concatenates them along dim 0.
        batch_size: Maximum number of items per batch.
        delay_frame: Offset added to the chunk index before the modulo.
        device: Device every yielded tensor is moved to.

    Yields:
        ``(whisper_batch, latent_batch)`` tensors on ``device``. The final
        batch may hold fewer than ``batch_size`` items.
    """
    num_latents = len(vae_encode_latents)
    whisper_batch, latent_batch = [], []

    def _collate():
        # Every batch goes through the same path so that full and partial
        # batches end up on the same device.
        stacked = torch.stack(whisper_batch).to(device)
        joined = torch.cat(latent_batch, dim=0).to(device)
        return stacked, joined

    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % num_latents
        whisper_batch.append(w)
        latent_batch.append(vae_encode_latents[idx])

        if len(latent_batch) >= batch_size:
            yield _collate()
            whisper_batch, latent_batch = [], []

    # The last batch may be smaller than batch_size.
    if latent_batch:
        yield _collate()

def cast_training_params(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype=torch.float32,
):
    """Cast the trainable parameters of one or more modules to ``dtype``.

    Parameters with ``requires_grad=False`` are left untouched. The cast is
    done in place by rebinding each parameter's ``.data``.

    Args:
        model: A module or a list of modules.
        dtype: Target dtype for the trainable parameters.
    """
    modules = model if isinstance(model, list) else [model]
    for module in modules:
        trainable = (p for p in module.parameters() if p.requires_grad)
        for weight in trainable:
            weight.data = weight.to(dtype)

def rand_log_normal(
    shape,
    loc=0.,
    scale=1.,
    device='cpu',
    dtype=torch.float32,
    generator=None
):
    """Draw samples from a log-normal distribution.

    Computes ``exp(z * scale + loc)`` where ``z`` is standard normal noise
    of the requested ``shape``, ``device`` and ``dtype``.

    Args:
        shape: Shape of the returned tensor.
        loc: Mean of the underlying normal distribution.
        scale: Standard deviation of the underlying normal distribution.
        device: Device on which the noise is created.
        dtype: dtype of the noise.
        generator: Optional ``torch.Generator`` for reproducible draws.

    Returns:
        Tensor of strictly positive samples.
    """
    gaussian = torch.randn(
        shape, generator=generator, dtype=dtype, device=device)
    return torch.exp(gaussian * scale + loc)

def get_mouth_region(frames, image_pred, pixel_values_face_mask):
    """Crop the masked region from real and predicted frames, resized to 256x256.

    For every batch item the bounding box of the non-zero mask entries is
    computed and the same box is cropped from ``frames`` and ``image_pred``.
    Items whose mask is entirely zero are skipped, so the outputs may hold
    fewer items than the inputs.

    Args:
        frames: Real frames, indexed as ``[b, :, y, x]``.
        image_pred: Predicted frames, indexed the same way as ``frames``.
        pixel_values_face_mask: Per-item mask. Columns 1 and 2 of its
            non-zero indices are used as y and x, i.e. each item is assumed
            to be laid out as (channel, height, width) - TODO confirm
            against the data loader.

    Returns:
        ``(mouth_real, mouth_generated)``, each of shape ``(n, C, 256, 256)``
        with ``n`` the number of items that had a non-empty mask, or
        ``(None, None)`` when every mask was empty.
    """
    mouth_real_list = []
    mouth_generated_list = []

    for b in range(frames.shape[0]):
        non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
        # Nothing to crop for an all-zero mask.
        if non_zero_indices.numel() == 0:
            continue

        ys = non_zero_indices[:, 1]
        xs = non_zero_indices[:, 2]
        # Slice ends are exclusive, so add 1 to keep the last masked row and
        # column inside the crop. Without it a one-pixel-tall or -wide mask
        # produces an empty crop that F.interpolate rejects.
        min_y, max_y = int(ys.min()), int(ys.max()) + 1
        min_x, max_x = int(xs.min()), int(xs.max()) + 1

        frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
        image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]

        # Bring every crop to a common 256x256 resolution.
        frames_resized = F.interpolate(
            frames_cropped.unsqueeze(0), size=(256, 256),
            mode='bilinear', align_corners=False)
        image_pred_resized = F.interpolate(
            image_pred_cropped.unsqueeze(0), size=(256, 256),
            mode='bilinear', align_corners=False)

        mouth_real_list.append(frames_resized)
        mouth_generated_list.append(image_pred_resized)

    mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
    mouth_generated = torch.cat(
        mouth_generated_list, dim=0) if mouth_generated_list else None

    return mouth_real, mouth_generated

def get_image_pred(pixel_values,
                   ref_pixel_values,
                   audio_prompts,
                   vae,
                   net,
                   weight_dtype):
    """Run one gradient-free prediction and decode it to images.

    The lower half (along the height axis) of ``pixel_values`` is overwritten
    with -1, both the masked frames and the reference frames are encoded with
    ``vae``, their latents are concatenated along dim 1 and passed to ``net``
    together with ``audio_prompts`` at timestep 0. The predicted latents are
    decoded back by ``vae``.

    Args:
        pixel_values: 5-D tensor laid out as ``b f c h w`` (per the rearrange
            pattern used below). It is cloned, not modified.
        ref_pixel_values: Reference frames with the same ``b f c h w`` layout.
        audio_prompts: Passed through to ``net`` unchanged.
        vae: Object exposing ``encode``, ``decode`` and
            ``config.scaling_factor``.
        net: Callable taking ``(latents, timesteps, audio_prompts)``.
        weight_dtype: dtype the network input is cast to.

    Returns:
        The decoded images as a float32 tensor.
    """
    with torch.no_grad():
        # Unpacking all five dims also rejects inputs that are not 5-D.
        _, _, _, height, _ = pixel_values.shape

        lower_half_hidden = pixel_values.clone()
        lower_half_hidden[:, :, :, height // 2:, :] = -1

        scale = vae.config.scaling_factor

        def _encode(flat_frames):
            # Mode of the latent distribution, scaled, always in float32.
            latents = vae.encode(flat_frames).latent_dist.mode()
            return (latents * scale).float()

        masked_latents = _encode(
            rearrange(lower_half_hidden, 'b f c h w -> (b f) c h w'))
        ref_latents = _encode(
            rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w'))

        net_input = torch.cat([masked_latents, ref_latents], dim=1)
        net_input = net_input.to(weight_dtype)
        timesteps = torch.tensor([0], device=net_input.device)

        latents_pred = net(net_input, timesteps, audio_prompts)

        # Undo the scaling applied at encode time before decoding.
        decoded = vae.decode((1 / scale) * latents_pred).sample
        return decoded.float()

def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    """Encode audio features and cut one fixed-length window per video frame.

    The hidden states returned by ``wav2vec.encoder`` are stacked along dim 2,
    zero-padded along dim 1 (``2 * audio_padding_length_left`` steps in front,
    ``2 * audio_padding_length_right`` behind), and then sliced per sample and
    per frame. Runs without gradient tracking.

    Args:
        cfg: Config exposing ``data.audio_padding_length_left`` and
            ``data.audio_padding_length_right``.
        batch: Mapping with ``'audio_feature'``, ``'audio_offset'`` and
            ``'audio_step'`` entries; the latter two are indexed per sample.
        wav2vec: Object whose ``encoder(x, output_hidden_states=True)`` returns
            something with a ``hidden_states`` sequence of tensors.
        bsz: Number of samples to process.
        num_frames: Number of windows to cut per sample.
        weight_dtype: dtype used for the encoder input and the stacked output.

    Returns:
        Tensor whose first two dims are ``(bsz, num_frames)``; the third is
        the window length ``2 * (left + right + 1)`` whenever the slice stays
        inside the padded sequence.
    """
    with torch.no_grad():
        left = cfg.data.audio_padding_length_left
        right = cfg.data.audio_padding_length_right
        window = 2 * (left + right + 1)

        encoder_input = batch['audio_feature'].to(weight_dtype)
        hidden_states = wav2vec.encoder(
            encoder_input, output_hidden_states=True).hidden_states
        # presumably [B, T, 10, 5, 384] per the original note - TODO confirm
        feats = torch.stack(hidden_states, dim=2).to(weight_dtype)

        offsets = batch['audio_offset']
        steps = batch['audio_step']

        # Zero-pad the time axis so windows near either end stay in range.
        front_pad = torch.zeros_like(feats[:, :2 * left])
        back_pad = torch.zeros_like(feats[:, :2 * right])
        feats = torch.cat([front_pad, feats, back_pad], 1)

        per_sample = []
        for sample_idx in range(bsz):
            clips = []
            for frame_idx in range(num_frames):
                begin = (offsets[sample_idx] + frame_idx * steps[sample_idx]) * 2
                clips.append(
                    feats[sample_idx:sample_idx + 1, begin: begin + window])
            per_sample.append(torch.stack(clips, 1))

        return torch.cat(per_sample)

def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
    """Save ``model.state_dict()`` as ``<save_dir>/<name>-<ckpt_num>.pth``.

    When ``total_limit`` is given, the oldest existing checkpoints of the same
    ``name`` are deleted first so that at most ``total_limit`` remain after
    the new file is written.

    Args:
        model: Object exposing ``state_dict()``.
        save_dir: Target directory; created if it does not exist.
        ckpt_num: Step number embedded in the file name.
        name: File-name prefix identifying the model.
        total_limit: Maximum number of checkpoints to keep, or ``None`` to
            keep everything.
        logger: Optional logger with an ``info`` method; removals are only
            reported when one is supplied.
    """
    prefix = f"{name}-"
    suffix = ".pth"
    save_path = os.path.join(save_dir, f"{prefix}{ckpt_num}{suffix}")

    os.makedirs(save_dir, exist_ok=True)

    if total_limit is not None:
        # Only files named exactly "<name>-<digits>.pth" take part in the
        # rotation. Files such as "<name>-best.pth" or checkpoints of another
        # model whose name merely contains ``name`` are left alone.
        steps = {}
        for entry in os.listdir(save_dir):
            if entry.startswith(prefix) and entry.endswith(suffix):
                step = entry[len(prefix):-len(suffix)]
                if step.isdecimal():
                    steps[entry] = int(step)
        checkpoints = sorted(steps, key=steps.get)

        if len(checkpoints) >= total_limit:
            # +1 makes room for the checkpoint about to be written.
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            if logger is not None:
                logger.info(
                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                )
                logger.info(
                    f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                os.remove(os.path.join(save_dir, removing_checkpoint))

    torch.save(model.state_dict(), save_path)

def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
    """Save the ``unet`` attribute of the unwrapped ``net`` as a checkpoint.

    Args:
        accelerator: Object providing ``unwrap_model``.
        net: Wrapped model whose unwrapped form exposes ``.unet``.
        save_dir: Directory the checkpoint is written to.
        global_step: Step number used in the checkpoint file name.
        cfg: Config providing ``total_limit`` for checkpoint rotation.
        logger: Optional logger forwarded to ``save_checkpoint``.
    """
    unet = accelerator.unwrap_model(net).unet
    save_checkpoint(
        model=unet,
        save_dir=save_dir,
        ckpt_num=global_step,
        name="unet",
        total_limit=cfg.total_limit,
        logger=logger,
    )

def delete_additional_ckpt(base_path, num_keep):
    """Remove the oldest ``checkpoint-<n>`` entries, keeping ``num_keep``.

    Entries of ``base_path`` whose names start with ``"checkpoint-"`` are
    ordered by the integer after the last ``"-"``; the lowest-numbered ones
    are removed with ``shutil.rmtree`` until only ``num_keep`` remain.

    Args:
        base_path: Directory containing the checkpoint directories.
        num_keep: Number of most recent checkpoints to keep.
    """
    ckpt_dirs = [
        entry for entry in os.listdir(base_path)
        if entry.startswith("checkpoint-")
    ]
    surplus = len(ckpt_dirs) - num_keep
    if surplus <= 0:
        return

    # Sort numerically so that the earliest checkpoints are deleted first.
    ckpt_dirs.sort(key=lambda entry: int(entry.split("-")[-1]))
    for stale in ckpt_dirs[:surplus]:
        target = osp.join(base_path, stale)
        if osp.exists(target):
            shutil.rmtree(target)

def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed. NumPy receives ``seed % 2**32``.
    """
    import random

    import numpy as np

    numpy_seed = seed % (2 ** 32)

    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
    np.random.seed(numpy_seed)
    random.seed(seed)

def process_and_save_images(
    batch,
    image_pred,
    image_pred_infer,
    save_dir,
    global_step,
    accelerator,
    num_images_to_keep=10,
    syncnet_score=1
):
    """Write a side-by-side JPEG comparing inputs, references and predictions.

    For up to ``num_images_to_keep`` frames, five images are stacked along
    the height axis (masked input, reference, ``image_pred``, ground truth,
    ``image_pred_infer``); the per-frame columns are then laid out left to
    right and saved under ``<save_dir>/samples/``. Saving is best-effort:
    failures are printed, never raised.

    Args:
        batch: Mapping with ``'pixel_values_ref_img'`` and
            ``'pixel_values_vid'`` tensors laid out as ``b f c h w``.
        image_pred: Predicted images, indexed ``[n, c, h, w]``.
        image_pred_infer: Second set of predictions, same layout.
        save_dir: Root output directory.
        global_step: Step number used in the file name.
        accelerator: Object whose ``device`` is used in the file name.
        num_images_to_keep: Maximum number of frames shown.
        syncnet_score: Value embedded in the file name.
    """
    print("image_pred.shape: ", image_pred.shape)
    pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
    pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')

    # Same lower-half masking (along the height axis) the model input receives.
    masked_pixel_values = batch["pixel_values_vid"].clone()
    _, _, _, h, _ = batch["pixel_values_vid"].shape
    masked_pixel_values[:, :, :, h//2:, :] = -1
    masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')

    # Keep only the requested number of frames.
    pixel_values = pixel_values[:num_images_to_keep]
    masked_pixel_values = masked_pixel_values[:num_images_to_keep]
    pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep]
    image_pred = image_pred.detach()[:num_images_to_keep]
    image_pred_infer = image_pred_infer.detach()[:num_images_to_keep]

    # x * 0.5 + 0.5 maps values from [-1, 1] to [0, 1]; dim=2 is the height
    # axis, so each frame becomes one tall column.
    concat = torch.cat([
        masked_pixel_values * 0.5 + 0.5,
        pixel_values_ref_img * 0.5 + 0.5,
        image_pred * 0.5 + 0.5,
        pixel_values * 0.5 + 0.5,
        image_pred_infer * 0.5 + 0.5,
    ], dim=2)
    print("concat.shape: ", concat.shape)

    os.makedirs(f'{save_dir}/samples/', exist_ok=True)

    try:
        save_path = f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg'

        # Lay the per-frame columns out left to right, move channels last and
        # go through float32 so half-precision inputs can reach NumPy.
        strip = torch.cat(list(concat), dim=-1).permute(1, 2, 0).float().cpu().numpy()
        # Swap the first and last channels (the [2, 1, 0] index) and quantise
        # explicitly to the 8-bit range JPEG stores.
        final_image = np.clip(strip[:, :, [2, 1, 0]] * 255, 0, 255).round().astype(np.uint8)
        final_image = np.ascontiguousarray(final_image)

        # cv2.imwrite reports most failures through its return value instead
        # of raising, so the result has to be checked before claiming success.
        if cv2.imwrite(save_path, final_image):
            print(f"Image saved successfully: {save_path}")
        else:
            print(f"Failed to save image: cv2.imwrite returned False for {save_path}")
    except Exception as e:
        print(f"Failed to save image: {e}")