|
|
| import os |
| import cv2 |
| import json |
| import time |
| import einops |
| import librosa |
| import torch |
| import random |
| import argparse |
| import traceback |
| import numpy as np |
| from tqdm import tqdm |
| from PIL import Image |
| from einops import rearrange |
|
|
|
|
|
|
def get_facemask(ref_image, align_instance, area=1.25):
    """Build a rectangular face mask for every frame of ``ref_image``.

    Args:
        ref_image: Tensor unpacked as (batch, frames, channels, height, width).
            Values are cast to ``uint8`` before being handed to the detector.
        align_instance: Callable invoked as
            ``align_instance(bgr_image, maxface=True)``. The third item it
            returns is indexed for its first bounding box, which is read as
            ``(x, y, width, height)``.
        area: Factor by which the detected box is enlarged around its centre
            before being clipped to the image bounds.

    Returns:
        Tensor of shape (batch, 1, frames, height, width) on the device and
        with the dtype of ``ref_image``; 1 inside the enlarged box and 0
        elsewhere.
    """
    bsz, f, _, h, w = ref_image.shape
    images = (
        rearrange(ref_image, "b f c h w -> (b f) h w c")
        .data.cpu()
        .numpy()
        .astype(np.uint8)
    )

    face_masks = []
    for image in images:
        # Convert once and reuse; the detector receives the channels reversed
        # (B, G, R order).
        rgb = np.array(Image.fromarray(image).convert("RGB"))
        _, _, bboxes_list = align_instance(rgb[:, :, [2, 1, 0]], maxface=True)

        try:
            bbox = bboxes_list[0]
        except (IndexError, KeyError, TypeError):
            # No usable detection (empty or missing result): fall back to the
            # whole frame instead of failing.
            bbox = [0, 0, w, h]

        x1, y1, box_w, box_h = bbox
        x2, y2 = x1 + box_w, y1 + box_h
        center_x, center_y = (x2 + x1) // 2, (y2 + y1) // 2

        # Enlarge the box around its centre, then clip to the image.
        box_w, box_h = box_w * area, box_h * area
        x1 = max(center_x - box_w // 2, 0)
        y1 = max(center_y - box_h // 2, 0)
        x2 = min(center_x + box_w // 2, w)
        y2 = min(center_y + box_h // 2, h)

        # Only a single mask channel is ever returned, so allocate just that.
        face_mask = np.zeros((h, w, 1), dtype=np.uint8)
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1
        face_masks.append(torch.from_numpy(face_mask))

    face_masks = torch.stack(face_masks, dim=0)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    return face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
|
|
|
|
def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    """Encode audio features and cut one fixed-size window per video frame.

    Args:
        wav2vec: Model exposing ``encoder(inputs, output_hidden_states=True)``
            whose result has a ``hidden_states`` sequence of tensors.
        audio_feats: 2-D feature tensor; a batch axis is prepended and the
            last axis is truncated to its first 3000 entries before encoding.
        fps: Video frame rate. ``12.5`` advances two steps per frame; every
            other value (including ``25``) advances one step per frame.
        num_frames: Number of windows to produce, capped at 400.

    Returns:
        Tensor of shape (1, num_frames, 10, num_hidden_states, ...) holding,
        for each frame, a 10-step window over the stacked hidden states.
        Windows that reach past the end of the encoded sequence are
        zero-filled.
    """
    window = 10
    start_t = 0
    step_t = 2 if fps == 12.5 else 1
    num_frames = min(num_frames, 400)

    hidden_states = wav2vec.encoder(
        audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True
    ).hidden_states
    # Stack all hidden states on a new axis right after the time axis.
    audio_feats = torch.stack(hidden_states, dim=2)
    # Prepend four zero steps along the time axis.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)

    # Each frame index is doubled to address the feature sequence
    # (presumably two feature steps per 25 fps frame -- TODO confirm).
    offsets = [(start_t + idx * step_t) * 2 for idx in range(num_frames)]

    if offsets:
        # Zero-pad the tail so the last window is always full; otherwise the
        # slices come out with unequal lengths and torch.stack raises.
        shortfall = offsets[-1] + window - audio_feats.shape[1]
        if shortfall > 0:
            tail_shape = (audio_feats.shape[0], shortfall) + tuple(audio_feats.shape[2:])
            audio_feats = torch.cat([audio_feats, audio_feats.new_zeros(tail_shape)], 1)

    clips = [audio_feats[0:1, t:t + window] for t in offsets]
    return torch.stack(clips, 1)
|
|