|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import json |
|
|
import time |
|
|
import decord |
|
|
import einops |
|
|
import librosa |
|
|
import torch |
|
|
import random |
|
|
import argparse |
|
|
import traceback |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from PIL import Image |
|
|
from einops import rearrange |
|
|
|
|
|
|
|
|
|
|
|
def get_facemask(ref_image, align_instance, area=1.25):
    """Rasterize per-frame binary face masks for a batch of video frames.

    For each frame, runs the face detector, takes its top box, enlarges the
    box by ``area`` around its center, clamps it to the frame, and writes a
    single-channel {0, 1} mask.

    Args:
        ref_image: tensor of shape (b, f, c, h, w) with uint8-range values;
            channels are RGB-ordered (flipped to BGR for the detector).
            Requires c == 3 (the detector consumes 3-channel frames).
        align_instance: callable ``align(img_bgr, maxface=True)`` returning a
            3-tuple whose last element is a list of ``[x, y, w, h]`` boxes
            (may be empty when no face is found).
        area: scale factor applied to the detected box around its center
            (1.25 == 25% padding). Defaults to 1.25.

    Returns:
        Tensor of shape (b, 1, f, h, w) on ``ref_image``'s device/dtype:
        1 inside the expanded face box, 0 elsewhere. Frames with no
        detection fall back to a full-frame mask.
    """
    bsz, f, c, h, w = ref_image.shape
    # (b, f, c, h, w) -> (b*f, h, w, c) uint8 frames for the detector.
    images = (
        ref_image.detach()
        .permute(0, 1, 3, 4, 2)
        .reshape(bsz * f, h, w, c)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )

    face_masks = []
    for image in images:
        # Detector expects BGR input; flip the channel order.
        _, _, bboxes_list = align_instance(image[:, :, [2, 1, 0]], maxface=True)
        if bboxes_list:
            x1, y1, ww, hh = bboxes_list[0]
        else:
            # No face detected: mask the whole frame rather than nothing.
            x1, y1, ww, hh = 0, 0, w, h

        # Expand the box by `area` around its center, then clamp to the frame.
        x2, y2 = x1 + ww, y1 + hh
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        cx, cy = (x2 + x1) // 2, (y2 + y1) // 2
        x1 = max(cx - ww // 2, 0)
        y1 = max(cy - hh // 2, 0)
        x2 = min(cx + ww // 2, w)
        y2 = min(cy + hh // 2, h)

        # Single-channel mask; uint8 matches the source frame dtype, the
        # final .to(...) converts to the caller's dtype.
        face_mask = np.zeros((h, w, 1), dtype=np.uint8)
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1
        face_masks.append(torch.from_numpy(face_mask))

    # (b*f, h, w, 1) -> (b, 1, f, h, w), the layout callers expect.
    face_masks = torch.stack(face_masks, dim=0)
    face_masks = face_masks.reshape(bsz, f, h, w, 1).permute(0, 4, 1, 2, 3)
    return face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
|
|
|
|
|
|
|
|
def encode_audio(wav2vec, audio_feats, fps, num_frames=129): |
|
|
if fps == 25: |
|
|
start_ts = [0] |
|
|
step_ts = [1] |
|
|
elif fps == 12.5: |
|
|
start_ts = [0] |
|
|
step_ts = [2] |
|
|
else: |
|
|
start_ts = [0] |
|
|
step_ts = [1] |
|
|
|
|
|
num_frames = min(num_frames, 400) |
|
|
audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states |
|
|
audio_feats = torch.stack(audio_feats, dim=2) |
|
|
audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1) |
|
|
|
|
|
audio_prompts = [] |
|
|
for bb in range(1): |
|
|
audio_feats_list = [] |
|
|
for f in range(num_frames): |
|
|
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2 |
|
|
audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10] |
|
|
audio_feats_list.append(audio_clip) |
|
|
audio_feats_list = torch.stack(audio_feats_list, 1) |
|
|
audio_prompts.append(audio_feats_list) |
|
|
audio_prompts = torch.cat(audio_prompts) |
|
|
return audio_prompts |