|
|
import einops |
|
|
import torch |
|
|
from torch.utils.data.dataset import Dataset |
|
|
import torchvision.transforms as transforms |
|
|
from typing import Optional |
|
|
import os |
|
|
import random |
|
|
from PIL import Image |
|
|
import traceback |
|
|
import subprocess |
|
|
from tqdm import tqdm |
|
|
import av |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import math |
|
|
from torchvision.io import read_video |
|
|
import pickle |
|
|
def vae_encode(vae, latents):
    """Encode images or videos into scaled VAE latents.

    Accepts a 4-D image batch (N, C, H, W) or a 5-D video batch
    (N, T, C, H, W); video input is flattened to frames before encoding and
    restored afterwards. Sampled latents are multiplied by the Stable
    Diffusion scaling factor 0.18215.
    """
    if len(latents.shape) == 5:
        n, t, c, h, w = latents.shape
        is_video = True
        # Fold time into the batch axis so the 2-D VAE can process frames.
        latents = latents.reshape(n * t, c, h, w)
    else:
        n, c, h, w = latents.shape
        is_video = False

    with torch.no_grad():
        posterior = vae.encode(latents).latent_dist
        encoded = posterior.sample() * 0.18215

    if is_video:
        encoded = encoded.reshape(n, t, *encoded.shape[1:])
    return encoded
|
|
|
|
|
def vae_decode(vae, latents):
    """Decode scaled VAE latents back into images or videos.

    Accepts 4-D (N, C, H, W) or 5-D video (N, T, C, H, W) latents; video
    input is flattened to frames before decoding and restored afterwards.
    The Stable Diffusion scaling factor 0.18215 is undone before decoding.
    """
    if len(latents.shape) == 5:
        n, t, c, h, w = latents.shape
        is_video = True
        # Fold time into the batch axis so the 2-D VAE can process frames.
        latents = latents.reshape(n * t, c, h, w)
    else:
        n, c, h, w = latents.shape
        is_video = False

    # Undo the latent scaling applied at encode time.
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        decoded = vae.decode(latents).sample

    if is_video:
        decoded = decoded.reshape(n, t, *decoded.shape[1:])
    return decoded
|
|
|
|
|
def lsdir(dir):
    """Return the full path of every entry directly inside *dir*."""
    return [os.path.join(dir, name) for name in os.listdir(dir)]
|
|
|
|
|
class A2MEvalDataset(Dataset):
    """Audio-to-motion evaluation dataset.

    Pairs a reference image with an audio embedding and a DWpose
    conditioning image (each either matched by filename or sampled at
    random). Images are resized/cropped to 256x256 and normalized to
    [-1, 1].
    """

    def __init__(
        self,
        audio_emb_dir: str,
        dwpose_dir: str,
        ref_img_dir: str,
        num_frames: int,
        random_audio: bool,
        random_dwpose: bool,
        audio_dir: Optional[str] = None,
        num_evals: Optional[int] = None,
        audio_suffix: str = 'wav'
    ):
        super().__init__()
        self.audio_emb_dir = audio_emb_dir
        self.audio_emb_paths = lsdir(audio_emb_dir)
        self.ref_img_paths = sorted(lsdir(ref_img_dir))
        self.dwpose_dir = dwpose_dir
        self.dwpose_paths = lsdir(dwpose_dir)
        # Fixed: `audio_dir if not audio_dir is None else None` was a no-op.
        self.audio_dir = audio_dir
        self.num_evals = num_evals
        # Pairing is positional, so only the shorter list's length is usable.
        self.availables = min(len(self.ref_img_paths), len(self.audio_emb_paths))
        self.num_frames = num_frames
        self.random_audio = random_audio
        self.random_dwpose = random_dwpose
        self.audio_suffix = audio_suffix
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables

    def __getitem__(self, index):
        # Best-effort retry: skip to the next sample on failure rather than
        # aborting the whole evaluation run.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:  # narrowed from bare except; lets KeyboardInterrupt through
                traceback.print_exc()
                # Wrap around so retries can never run past the end of the
                # path list (the old `index + 1` looped forever on IndexError).
                index = (index + 1) % len(self)
        return sample

    def get(self, index: int):
        """Load one sample: (audio_emb, ref_img, dwpose, audio_path, ref_img_path)."""
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]

        if self.random_audio:
            audio_emb_path = random.sample(self.audio_emb_paths, 1)[0]
        else:
            audio_emb_path = os.path.join(self.audio_emb_dir, filename + '.pt')
            if not os.path.exists(audio_emb_path):
                raise ValueError("audio emb path not exists")
        audio_filename = os.path.basename(audio_emb_path).split(".")[0]
        if self.random_dwpose:
            dwpose_path = random.sample(self.dwpose_paths, 1)[0]
        else:
            dwpose_path = os.path.join(self.dwpose_dir, filename + '.jpg')
            if not os.path.exists(dwpose_path):
                raise ValueError("dwpose path not exists")

        # Optional raw-audio path matching the embedding; None when absent.
        if self.audio_dir is not None:
            audio_path = os.path.join(self.audio_dir, audio_filename + '.' + self.audio_suffix)
            if not os.path.exists(audio_path):
                audio_path = None
        else:
            audio_path = None

        audio_emb = torch.load(audio_emb_path)
        if audio_emb.shape[0] < self.num_frames:
            raise ValueError(f"audio too short, {audio_emb.shape}")
        audio_emb = audio_emb[:self.num_frames]

        ref_img = self.transforms(Image.open(ref_img_path))
        dwpose = self.transforms(Image.open(dwpose_path))
        return audio_emb, ref_img, dwpose, audio_path, ref_img_path

    @staticmethod
    def collate(batch):
        """Stack a list of samples into a batch dict (tensors stacked, paths listed)."""
        return dict(
            audio_emb=torch.stack([b[0] for b in batch]),
            ref_img=torch.stack([b[1] for b in batch]),
            dwpose=torch.stack([b[2] for b in batch]),
            audio_path=[b[3] for b in batch],
            refimg_path=[b[4] for b in batch],
        )
|
|
|
|
|
class P2MEvalDataset(Dataset):
    """Pose-to-motion evaluation dataset.

    Pairs a reference image with a rendered source pose (first frame of its
    own DWpose sequence) and a sequence of driving poses aligned into the
    source pose's coordinate frame.
    """

    def __init__(
        self,
        ref_img_dir: str,
        dwpose_dict_dir: str,
        num_frames: int,
        random_dwpose: bool,
        num_evals: Optional[int] = None,
    ):
        super().__init__()
        self.ref_img_paths = sorted(lsdir(ref_img_dir))
        self.dwpose_dict_dir = dwpose_dict_dir
        self.dwpose_dict_paths = lsdir(dwpose_dict_dir)
        self.num_evals = num_evals
        # Pairing is positional, so only the shorter list's length is usable.
        self.availables = min(len(self.ref_img_paths), len(self.dwpose_dict_paths))
        self.num_frames = num_frames
        self.random_dwpose = random_dwpose
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        # Canvas size used when rendering poses.
        self.w = 256
        self.h = 256

    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables

    def __getitem__(self, index):
        # Best-effort retry: skip to the next sample on failure rather than
        # aborting the whole evaluation run.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:  # narrowed from bare except; lets KeyboardInterrupt through
                traceback.print_exc()
                # Wrap around so retries can never run past the end of the
                # path list (the old `index + 1` looped forever on IndexError).
                index = (index + 1) % len(self)
        return sample

    def get(self, index: int):
        """Load one sample: (ref_img, source_pose, driven_poses, driven_dwpose_path)."""
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]

        if self.random_dwpose:
            driven_dwpose_path = random.sample(self.dwpose_dict_paths, 1)[0]
        else:
            driven_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
            if not os.path.exists(driven_dwpose_path):
                raise ValueError("driven dwpose path not exists")
        source_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
        if not os.path.exists(source_dwpose_path):
            # Fixed typo in the message ("sourec" -> "source").
            raise ValueError("source dwpose path not exists")

        driven_pose = np.load(driven_dwpose_path, allow_pickle=True)
        if driven_pose.shape[0] < self.num_frames:
            raise ValueError(f"driven pose too short. Total frames = {driven_pose.shape[0]}")
        driven_pose = driven_pose[:self.num_frames]
        # First frame of the source sequence is the canonical source pose.
        source_pose = np.load(source_dwpose_path, allow_pickle=True)[0]
        ref_img = self.transforms(Image.open(ref_img_path))

        driven_poses = []
        for pose in driven_pose:
            # Warp each driving pose into the source pose's frame and render it.
            driven_pose_frame = align_pose(source_pose, pose, self.h, self.w)
            driven_poses.append(self.transforms(driven_pose_frame))
        driven_poses = torch.stack(driven_poses, dim=0)

        # Fixed: np.zeros(self.w, self.h, 3) passed self.h as dtype and always
        # raised TypeError; draw on a proper uint8 (h, w, 3) canvas instead,
        # and index faces[0] consistently with align_pose's usage.
        source_pose = draw_facebody(
            np.zeros(shape=(self.h, self.w, 3), dtype=np.uint8),
            source_pose["faces"][0],
            source_pose["bodies"]
        )
        source_pose = self.transforms(source_pose)

        return ref_img, source_pose, driven_poses, driven_dwpose_path

    @staticmethod
    def collate(batch):
        """Stack a list of samples into a batch dict (tensors stacked, paths listed)."""
        return dict(
            ref_img=torch.stack([b[0] for b in batch]),
            source_pose=torch.stack([b[1] for b in batch]),
            driven_poses=torch.stack([b[2] for b in batch]),
            driven_pose_paths=[b[3] for b in batch],
        )
|
|
|
|
|
class RecEvalDataset(Dataset):
    """Video reconstruction evaluation dataset.

    Samples num_frames + 1 frames from each video: the first frame becomes
    the reference image, the remainder the target clip; one additional
    randomly chosen frame is returned alongside. Frames are scaled to
    [0, 1], resized/cropped to 256x256, then normalized to [-1, 1].
    """

    def __init__(
        self,
        video_dir: str,
        num_frames: int,
        num_evals: Optional[int] = None,
    ):
        super().__init__()
        # video_dir may be a pickled list of video paths instead of a directory.
        if video_dir.endswith(".pkl"):
            self.video_paths = pickle.load(open(video_dir, "rb"))
        else:
            self.video_paths = lsdir(video_dir)
        self.num_frames = num_frames
        # No ToTensor here: read_video already yields tensors; get() scales
        # them to [0, 1] before this pipeline runs.
        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.length = min(num_evals, len(self.video_paths)) if num_evals is not None else len(self.video_paths)
        self.w = 256
        self.h = 256

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Best-effort retry: skip to the next sample on failure rather than
        # aborting the whole evaluation run.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:  # narrowed from bare except; lets KeyboardInterrupt through
                traceback.print_exc()
                # Wrap around so retries can never run past the end of the
                # path list (the old `index + 1` looped forever on IndexError).
                index = (index + 1) % len(self)
        return sample

    def get(self, index: int):
        """Load one sample: (ref_img, video, filename, fps, random_frame)."""
        video_path = self.video_paths[index]
        filename = os.path.basename(video_path).split(".")[0]
        video, _, aux = read_video(video_path, pts_unit="sec", output_format="TCHW")
        fps = aux["video_fps"]
        sample_frames = self.num_frames + 1  # +1: the first frame is the reference
        video_length = len(video)
        clip_length = min(video_length, sample_frames)
        start_idx = random.randint(0, video_length - clip_length)
        # If the video is shorter than requested, linspace repeats frames to
        # still produce sample_frames indices.
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)
        random_idx = random.randint(0, video_length - 1)
        random_frame = video[random_idx] / 255.0
        video = video[batch_index] / 255.0
        video = self.transforms(video)
        random_frame = self.transforms(random_frame)
        ref_img = video[0]
        video = video[1:]
        return ref_img, video, filename, fps, random_frame

    @staticmethod
    def collate(batch):
        """Stack a list of samples into a batch dict (tensors stacked, metadata listed)."""
        return dict(
            ref_img=torch.stack([b[0] for b in batch]),
            video=torch.stack([b[1] for b in batch]),
            filename=[b[2] for b in batch],
            fps=[b[3] for b in batch],
            random_frame=torch.stack([b[4] for b in batch]),
        )
|
|
def read_frames(video_path):
    """Decode every frame of *video_path* into a list of RGB PIL images."""
    container = av.open(video_path)
    stream = next(s for s in container.streams if s.type == "video")

    images = []
    for packet in container.demux(stream):
        for frame in packet.decode():
            rgb = frame.to_rgb().to_ndarray()
            images.append(
                Image.frombytes("RGB", (frame.width, frame.height), rgb)
            )
    return images
|
|
|
|
|
def save_videos_from_pil(pil_images, path, fps=8):
    """Save a list of PIL images to *path* as a video.

    The path's suffix selects the container: ".mp4" (H.264 via PyAV) or
    ".gif" (animated GIF via PIL).

    Args:
        pil_images: non-empty list of equally-sized PIL images.
        path: output file path; parent directories are created if needed.
        fps: frames per second.

    Raises:
        ValueError: for any suffix other than .mp4 or .gif.
    """
    save_fmt = Path(path).suffix
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # Fixed: os.makedirs("") raises FileNotFoundError when path is a bare
        # filename in the current directory.
        os.makedirs(parent_dir, exist_ok=True)
    width, height = pil_images[0].size

    if save_fmt == ".mp4":
        codec = "libx264"
        container = av.open(path, "w")
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        for pil_image in pil_images:
            av_frame = av.VideoFrame.from_image(pil_image)
            container.mux(stream.encode(av_frame))
        # Flush frames still buffered inside the encoder.
        container.mux(stream.encode())
        container.close()

    elif save_fmt == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),
            loop=0,
        )
    else:
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
|
|
|
|
|
def align_face(face_1: np.ndarray, face_2: np.ndarray):
    """Least-squares align face_1 landmarks onto face_2.

    Fits a per-axis scale + translation (s_x, t_x, s_y, t_y) from the
    landmarks visible in both faces, applies it to every point of face_1,
    and keeps non-visible points (-1) marked as -1.

    Input:
        face_1, face_2: np.ndarray [68, 2]; -1 marks a non-visible point.
    Return:
        aligned copy of face_1, [68, 2].
    """
    face1_non_vis = face_1 == -1
    # Only landmarks visible in both faces contribute to the fit.
    face_vis = (face_1 > 0) * (face_2 > 0)
    face_vis = face_vis[:, 0] * face_vis[:, 1]
    face_1_vis = face_1[face_vis]
    face_2_vis = face_2[face_vis]
    num_vis = face_1_vis.shape[0]

    x_1 = face_1_vis.copy()
    x_1[:, 1] = 1  # design-matrix rows [x, 1]
    # Fixed: the ones column must match the visible-point count; the old
    # hard-coded (68, 1) crashed whenever any landmark was occluded.
    y_1 = np.concatenate([face_1_vis[:, 1:], np.ones(shape=(num_vis, 1))], axis=1)
    x_2, y_2 = face_2_vis[:, 0], face_2_vis[:, 1]
    # Normal-equation solve of [v, 1] @ [s, t] = v'.
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2

    # Fixed: the old homogeneous multiply (face @ trans) used the transposed
    # matrix and silently dropped the translation terms; apply the affine
    # map explicitly instead.
    face_ret = face_1.astype(np.float64).copy()
    face_ret[:, 0] = face_ret[:, 0] * s_x + t_x
    face_ret[:, 1] = face_ret[:, 1] * s_y + t_y
    face_ret[face1_non_vis] = -1
    return face_ret
|
|
|
|
|
def draw_facepose(canvas, lmks):
    """Draw face landmarks as filled white dots on *canvas*.

    Landmark coordinates are normalized to [0, 1]; points whose scaled
    pixel coordinates fall at or below the epsilon threshold (e.g. the -1
    "not visible" markers) are skipped.
    """
    eps = 0.01
    height, width = canvas.shape[:2]
    for point in lmks:
        px = int(point[0] * width)
        py = int(point[1] * height)
        if px > eps and py > eps:
            cv2.circle(canvas, (px, py), 3, (255, 255, 255), thickness=-1)
    return canvas
|
|
|
|
|
def align_body(body_1, body_2):
    """Least-squares align body_1 keypoints onto body_2.

    Fits a per-axis scale + translation from the keypoints visible in both
    bodies (subset entry >= 0) and applies it to all of body_1's candidate
    points. Returns body_2 unchanged when fewer than 3 shared points exist.
    """
    src_pts = body_1["candidate"].copy()
    dst_pts = body_2["candidate"].copy()
    src_sub = body_1["subset"][0].copy()
    dst_sub = body_2["subset"][0].copy()

    src_visible = src_sub >= 0
    shared = src_visible * (dst_sub >= 0)
    n_shared = np.sum(shared)
    if n_shared < 3:
        # Not enough correspondences for a stable fit.
        return body_2

    src_shared = src_pts[shared]
    dst_shared = dst_pts[shared]

    design_x = src_shared.copy()
    design_x[:, 1] = 1  # design-matrix rows [x, 1]
    design_y = np.concatenate([src_shared[:, 1:], np.ones(shape=(n_shared, 1))], axis=1)
    target_x, target_y = dst_shared[:, 0], dst_shared[:, 1]

    # Normal-equation solve of [v, 1] @ [s, t] = v'.
    s_x, t_x = np.linalg.inv(design_x.T @ design_x) @ design_x.T @ target_x
    s_y, t_y = np.linalg.inv(design_y.T @ design_y) @ design_y.T @ target_y

    src_pts[:, 0] = src_pts[:, 0] * s_x + t_x
    src_pts[:, 1] = src_pts[:, 1] * s_y + t_y

    # Sign bookkeeping: entries visible in body_1 but not shared with body_2
    # flip negative (invisible); already-negative entries stay negative.
    new_subset = src_sub * (src_visible * 2 - 1) * (shared * 2 - 1)
    return dict(candidate=src_pts, subset=new_subset[np.newaxis, :])
|
|
|
|
|
def draw_bodypose(canvas, candidate, subset):
    """Draw an OpenPose-style 18-keypoint body skeleton onto *canvas*.

    Args:
        canvas: H x W x 3 image array, drawn on in place via cv2 calls
            (the joint pass returns a dimmed copy).
        candidate: keypoint coordinates normalized to [0, 1]; column 0 is x
            (scaled by W), column 1 is y (scaled by H).
        subset: one row per detected person; each of the 18 entries is an
            index into *candidate*, with -1 marking a missing keypoint.

    Returns:
        The canvas with limbs and joint dots drawn.
    """
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    # Half-thickness of the elliptical limb segments, in pixels.
    stickwidth = 4

    # 1-based keypoint index pairs forming each limb (OpenPose convention).
    limbSeq = [
        [2, 3],
        [2, 6],
        [3, 4],
        [4, 5],
        [6, 7],
        [7, 8],
        [2, 9],
        [9, 10],
        [10, 11],
        [2, 12],
        [12, 13],
        [13, 14],
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],
        [3, 17],
        [6, 18],
    ]

    # One BGR-rainbow color per limb / joint index.
    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
    ]

    # Pass 1: draw the first 17 limbs as filled rotated ellipses.
    # NOTE(review): limbSeq has 19 entries but only 17 are drawn here —
    # presumably intentional (matches common OpenPose drawing code); confirm.
    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                # One endpoint missing: skip this limb for this person.
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly(
                (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
            )
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    # Dim the limb layer so the joint dots drawn next stand out.
    canvas = (canvas * 0.6).astype(np.uint8)

    # Pass 2: draw every visible joint as a filled colored circle.
    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)

    return canvas
|
|
def draw_facebody(canvas, face, body):
    """Render body limbs/joints, then face landmarks, onto *canvas*."""
    with_body = draw_bodypose(canvas, body["candidate"], body["subset"])
    return draw_facepose(with_body, face)
|
|
|
|
|
def align_pose(source_pose_dict, target_pose_dict, height: int = 256, width: int = 256):
    """Warp *target_pose_dict* into *source_pose_dict*'s coordinate frame and render it.

    A single per-axis scale + translation is least-squares fitted over the
    body keypoints and face landmarks visible in both poses, applied to the
    target pose, and the aligned pose is drawn onto a black
    (height, width, 3) canvas.

    Returns:
        PIL.Image of the rendered, aligned pose.
    """
    face_1 = target_pose_dict["faces"][0]
    face_2 = source_pose_dict["faces"][0]
    body_1 = target_pose_dict["bodies"]
    body_2 = source_pose_dict["bodies"]
    cdd_1, cdd_2 = body_1["candidate"].copy(), body_2["candidate"].copy()
    sub_1, sub_2 = body_1["subset"][0].copy(), body_2["subset"][0].copy()

    # Body keypoints visible in both poses (subset entry >= 0).
    sub1_vis = sub_1 >= 0
    sub2_vis = sub_2 >= 0
    vis = sub1_vis * sub2_vis
    num_points = np.sum(vis)
    cdd1_vis, cdd2_vis = cdd_1[vis], cdd_2[vis]
    body_x_1, body_y_1 = cdd1_vis.copy(), cdd1_vis[:, 1:]
    body_x_1[:, 1] = 1  # design-matrix rows [x, 1]
    body_y_1 = np.concatenate([body_y_1, np.ones(shape=(num_points, 1))], axis=1)
    body_x_2, body_y_2 = cdd2_vis[:, 0], cdd2_vis[:, 1]

    # Sign bookkeeping: points not shared with the source flip negative
    # (invisible); already-invisible entries stay negative.
    subset_vis = sub_1 * (sub1_vis * 2 - 1) * (vis * 2 - 1)

    # Face landmarks visible in both poses (-1 marks "not visible").
    face1_non_vis = face_1 == -1
    face2_vis = face_2 > 0
    face_vis = (face_1 > 0) * face2_vis
    face_vis = face_vis[:, 0] * face_vis[:, 1]
    face_1_vis = face_1[face_vis]
    face_2_vis = face_2[face_vis]
    num_face_points = face_1_vis.shape[0]

    face_x_1, face_y_1 = face_1_vis.copy(), face_1_vis[:, 1:]
    face_x_1[:, 1] = 1
    # Fixed: the ones column must match the visible-landmark count; the old
    # hard-coded (68, 1) crashed whenever any landmark was occluded.
    face_y_1 = np.concatenate([face_y_1, np.ones(shape=(num_face_points, 1))], axis=1)
    face_x_2, face_y_2 = face_2_vis[:, 0], face_2_vis[:, 1]

    # Fit one shared transform over the combined body + face correspondences.
    x_1 = np.concatenate([body_x_1, face_x_1], axis=0)
    y_1 = np.concatenate([body_y_1, face_y_1], axis=0)
    x_2 = np.concatenate([body_x_2, face_x_2], axis=0)
    y_2 = np.concatenate([body_y_2, face_y_2], axis=0)

    # Normal-equation solve of [v, 1] @ [s, t] = v'.
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2

    cdd_1[:, 0] = cdd_1[:, 0] * s_x + t_x
    cdd_1[:, 1] = cdd_1[:, 1] * s_y + t_y
    face_ret = np.copy(face_1)
    face_ret[:, 0] = face_ret[:, 0] * s_x + t_x
    face_ret[:, 1] = face_ret[:, 1] * s_y + t_y
    face_ret[face1_non_vis] = -1
    body_ret = dict(candidate=cdd_1, subset=subset_vis[np.newaxis, :])
    pil = draw_facebody(
        np.zeros(shape=(height, width, 3), dtype=np.uint8),
        face_ret, body_ret
    )
    return Image.fromarray(pil)
|
|
class first_frame_extractor:
    """Extract the first frame of every video in *video_dir* as a 256x256
    JPEG written to *output_dir*, using ffmpeg."""

    def __init__(self, video_dir: str, output_dir: str):
        self.video_dir = video_dir
        self.output_dir = output_dir
        self.video_paths = lsdir(video_dir)
        # Robustness: ensure the destination exists before ffmpeg writes to it.
        os.makedirs(output_dir, exist_ok=True)

    def extract(self):
        """Run ffmpeg once per video, writing <output_dir>/<basename>.jpg."""
        for vp in tqdm(self.video_paths):
            out = os.path.join(self.output_dir, os.path.basename(vp).split('.')[0] + ".jpg")
            # List-form argv with shell=False semantics: safe even when
            # paths contain spaces or shell metacharacters.
            command = [
                'ffmpeg',
                '-hide_banner',
                '-y',                      # overwrite existing outputs
                '-i', vp,
                '-vf', 'scale=256:256',
                '-vframes', '1',           # first frame only
                '-ss', '00:00:00',
                out
            ]
            subprocess.run(command)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load one batch from RecEvalDataset and print its contents.
    from torch.utils.data import DataLoader

    evalset = RecEvalDataset(
        video_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos",
        num_frames=96,
    )
    evalloader = DataLoader(
        evalset,
        4,
        shuffle=False,
        drop_last=True,
        collate_fn=evalset.collate,
        num_workers=0,
    )
    for batch in evalloader:
        print(batch["ref_img"].shape)
        print(batch["video"].shape)
        print(batch["filename"])
        print(batch["fps"])
        break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|