import einops
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
from typing import Optional
import os
import random
from PIL import Image
import traceback
import subprocess
from tqdm import tqdm
import av
from pathlib import Path
import numpy as np
import cv2
import math
from torchvision.io import read_video
import pickle


def vae_encode(vae, pixels):
    # pixels: N,T,C,H,W (video) or N,C,H,W (image batch)
    if len(pixels.shape) == 5:
        N, T, C, H, W = pixels.shape
        pixels_type = 'video'
        # Fold time into the batch dimension so the VAE sees a plain image batch.
        pixels = einops.rearrange(pixels, 'n t c h w -> (n t) c h w')
    else:
        N, C, H, W = pixels.shape
        pixels_type = 'image'
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist
        latents = latents.sample()
        # 0.18215 is the Stable Diffusion VAE scaling factor.
        latents = latents * 0.18215
    if pixels_type == 'video':
        latents = einops.rearrange(latents, '(n t) c h w -> n t c h w', n=N, t=T)
    return latents


def vae_decode(vae, latents):
    # latents: N,T,C,H,W (video) or N,C,H,W (image batch)
    if len(latents.shape) == 5:
        N, T, C, H, W = latents.shape
        latents_type = 'video'
        latents = einops.rearrange(latents, 'n t c h w -> (n t) c h w')
    else:
        N, C, H, W = latents.shape
        latents_type = 'image'
    # Undo the Stable Diffusion VAE scaling before decoding.
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        samples = vae.decode(latents).sample  # (n t) c h w
    if latents_type == 'video':
        samples = einops.rearrange(samples, '(n t) c h w -> n t c h w', n=N, t=T)
    return samples


def lsdir(dir):
    # Return full paths of all entries directly under `dir`.
    filenames = os.listdir(dir)
    paths = [os.path.join(dir, x) for x in filenames]
    return paths
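# A minimal round-trip sketch for the VAE helpers above, assuming a diffusers
# AutoencoderKL checkpoint (the model name below is an example, not a project
# requirement):
#
#   from diffusers import AutoencoderKL
#   vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()
#   video = torch.randn(1, 8, 3, 256, 256)   # N,T,C,H,W, stand-in for [-1, 1] frames
#   latents = vae_encode(vae, video)         # -> [1, 8, 4, 32, 32]
#   recon = vae_decode(vae, latents)         # -> [1, 8, 3, 256, 256]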
class A2MEvalDataset(Dataset):
    def __init__(
        self,
        audio_emb_dir: str,
        dwpose_dir: str,
        ref_img_dir: str,
        num_frames: int,
        random_audio: bool,
        random_dwpose: bool,
        audio_dir: Optional[str] = None,
        num_evals: Optional[int] = None,
        audio_suffix: str = 'wav'
    ):
        super().__init__()
        self.audio_emb_dir = audio_emb_dir
        self.audio_emb_paths = lsdir(audio_emb_dir)
        ref_img_paths = lsdir(ref_img_dir)
        self.ref_img_paths = sorted(ref_img_paths)
        self.dwpose_dir = dwpose_dir
        self.dwpose_paths = lsdir(dwpose_dir)
        self.audio_dir = audio_dir
        self.num_evals = num_evals
        self.availables = min(len(self.ref_img_paths), len(self.audio_emb_paths))
        self.num_frames = num_frames
        self.random_audio = random_audio
        self.random_dwpose = random_dwpose
        self.audio_suffix = audio_suffix
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables

    def __getitem__(self, index):
        # Skip over broken samples instead of crashing the whole evaluation;
        # wrap the index so retries cannot run past the end of the file lists.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % self.availables
        return sample

    def get(self, index: int):
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]
        if self.random_audio:
            audio_emb_path = random.sample(self.audio_emb_paths, 1)[0]
        else:
            audio_emb_path = os.path.join(self.audio_emb_dir, filename + '.pt')
        if not os.path.exists(audio_emb_path):
            raise ValueError("audio emb path does not exist")
        audio_filename = os.path.basename(audio_emb_path).split(".")[0]
        if self.random_dwpose:
            dwpose_path = random.sample(self.dwpose_paths, 1)[0]
        else:
            dwpose_path = os.path.join(self.dwpose_dir, filename + '.jpg')
        if not os.path.exists(dwpose_path):
            raise ValueError("dwpose path does not exist")
        if self.audio_dir is not None:
            audio_path = os.path.join(self.audio_dir, audio_filename + '.' + self.audio_suffix)
            if not os.path.exists(audio_path):
                audio_path = None
        else:
            audio_path = None
        audio_emb = torch.load(audio_emb_path)
        if audio_emb.shape[0] < self.num_frames:
            raise ValueError(f"audio too short, {audio_emb.shape}")
        audio_emb = audio_emb[:self.num_frames]
        ref_img = Image.open(ref_img_path)
        ref_img = self.transforms(ref_img)
        dwpose = Image.open(dwpose_path)
        dwpose = self.transforms(dwpose)
        return audio_emb, ref_img, dwpose, audio_path, ref_img_path

    @staticmethod
    def collate(batch):
        audio_emb = torch.stack([b[0] for b in batch])
        ref_img = torch.stack([b[1] for b in batch])
        dwpose = torch.stack([b[2] for b in batch])
        audio_path = [b[3] for b in batch]
        refimg_path = [b[4] for b in batch]
        return dict(
            audio_emb=audio_emb,
            ref_img=ref_img,
            dwpose=dwpose,
            audio_path=audio_path,
            refimg_path=refimg_path
        )
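# A2MEvalDataset pairs files across directories by basename (the names below
# are illustrative, not real files):
#   ref_img_dir/clip0001.jpg    reference frame
#   audio_emb_dir/clip0001.pt   per-frame audio embedding, shape [T, ...]
#   dwpose_dir/clip0001.jpg     rendered DWPose image
#   audio_dir/clip0001.wav      optional raw audio, returned as a path only
# With random_audio / random_dwpose the pairing is deliberately broken and a
# random file from the corresponding directory is used instead.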
class P2MEvalDataset(Dataset):
    def __init__(
        self,
        ref_img_dir: str,
        dwpose_dict_dir: str,
        num_frames: int,
        random_dwpose: bool,
        num_evals: Optional[int] = None,
    ):
        super().__init__()
        ref_img_paths = lsdir(ref_img_dir)
        self.ref_img_paths = sorted(ref_img_paths)
        self.dwpose_dict_dir = dwpose_dict_dir
        self.dwpose_dict_paths = lsdir(dwpose_dict_dir)
        self.num_evals = num_evals
        self.availables = min(len(self.ref_img_paths), len(self.dwpose_dict_paths))
        self.num_frames = num_frames
        self.random_dwpose = random_dwpose
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.w = 256
        self.h = 256

    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables

    def __getitem__(self, index):
        # Skip over broken samples instead of crashing the whole evaluation.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % self.availables
        return sample

    def get(self, index: int):
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]
        if self.random_dwpose:
            driven_dwpose_path = random.sample(self.dwpose_dict_paths, 1)[0]
        else:
            driven_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
        if not os.path.exists(driven_dwpose_path):
            raise ValueError("driven dwpose path does not exist")
        source_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
        if not os.path.exists(source_dwpose_path):
            raise ValueError("source dwpose path does not exist")
        driven_pose = np.load(driven_dwpose_path, allow_pickle=True)
        if driven_pose.shape[0] < self.num_frames:
            raise ValueError(f"driven pose too short. Total frames = {driven_pose.shape[0]}")
        driven_pose = driven_pose[:self.num_frames]
        source_pose = np.load(source_dwpose_path, allow_pickle=True)[0]
        ref_img = Image.open(ref_img_path)
        ref_img = self.transforms(ref_img)
        driven_poses = []
        for pose in driven_pose:
            # Align every driven frame to the source skeleton before rendering.
            driven_pose_frame = align_pose(source_pose, pose, self.h, self.w)
            driven_pose_frame = self.transforms(driven_pose_frame)
            driven_poses.append(driven_pose_frame)
        driven_poses = torch.stack(driven_poses, dim=0)
        source_pose = draw_facebody(
            np.zeros(shape=(self.h, self.w, 3), dtype=np.uint8),
            source_pose["faces"],
            source_pose["bodies"]
        )
        source_pose = self.transforms(source_pose)
        return ref_img, source_pose, driven_poses, driven_dwpose_path

    @staticmethod
    def collate(batch):
        ref_img = torch.stack([b[0] for b in batch])
        source_pose = torch.stack([b[1] for b in batch])
        driven_poses = torch.stack([b[2] for b in batch])
        driven_pose_paths = [b[3] for b in batch]
        return dict(
            ref_img=ref_img,
            source_pose=source_pose,
            driven_poses=driven_poses,
            driven_pose_paths=driven_pose_paths
        )


class RecEvalDataset(Dataset):
    def __init__(
        self,
        video_dir: str,
        num_frames: int,
        num_evals: Optional[int] = None,
    ):
        super().__init__()
        if video_dir.endswith(".pkl"):
            # A .pkl file holds a pre-computed list of video paths.
            self.video_paths = pickle.load(open(video_dir, "rb"))
        else:
            self.video_paths = lsdir(video_dir)
        self.num_frames = num_frames
        # No ToTensor here: read_video already returns tensors.
        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.length = min(num_evals, len(self.video_paths)) if num_evals is not None else len(self.video_paths)
        self.w = 256
        self.h = 256

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % len(self.video_paths)
        return sample

    def get(self, index: int):
        video_path = self.video_paths[index]
        filename = os.path.basename(video_path).split(".")[0]
        video, _, aux = read_video(video_path, pts_unit="sec", output_format="TCHW")
        fps = aux["video_fps"]
        # Sample num_frames + 1 frames: the first one becomes the reference image.
        sample_frames = self.num_frames + 1
        video_length = len(video)
        clip_length = min(video_length, sample_frames)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)
        random_idx = random.randint(0, video_length - 1)
        random_frame = video[random_idx] / 255.0
        video = video[batch_index] / 255.0
        video = self.transforms(video)
        random_frame = self.transforms(random_frame)
        ref_img = video[0]
        video = video[1:]
        return ref_img, video, filename, fps, random_frame

    @staticmethod
    def collate(batch):
        ref_img = torch.stack([b[0] for b in batch])
        video = torch.stack([b[1] for b in batch])
        filename = [b[2] for b in batch]
        fps = [b[3] for b in batch]
        random_frame = torch.stack([b[4] for b in batch])
        return dict(
            ref_img=ref_img,
            video=video,
            filename=filename,
            fps=fps,
            random_frame=random_frame
        )
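# Sampling sketch for RecEvalDataset.get: with num_frames=96 and a 300-frame
# video, sample_frames = 97, start_idx is drawn from [0, 203], and np.linspace
# yields 97 consecutive indices; for videos shorter than 97 frames the same
# indices repeat so the clip is padded up to sample_frames. The first sampled
# frame becomes ref_img and the remaining 96 form the target video.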
def read_frames(video_path):
    # Decode every frame of a video into a list of RGB PIL images.
    container = av.open(video_path)
    video_stream = next(s for s in container.streams if s.type == "video")
    frames = []
    for packet in container.demux(video_stream):
        for frame in packet.decode():
            image = Image.frombytes(
                "RGB",
                (frame.width, frame.height),
                frame.to_rgb().to_ndarray(),
            )
            frames.append(image)
    return frames


def save_videos_from_pil(pil_images, path, fps=8):
    save_fmt = Path(path).suffix
    os.makedirs(os.path.dirname(path), exist_ok=True)
    width, height = pil_images[0].size

    if save_fmt == ".mp4":
        codec = "libx264"
        container = av.open(path, "w")
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        for pil_image in pil_images:
            av_frame = av.VideoFrame.from_image(pil_image)
            container.mux(stream.encode(av_frame))
        # Flush frames still buffered inside the encoder.
        container.mux(stream.encode())
        container.close()
    elif save_fmt == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),
            loop=0,
        )
    else:
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")


def align_face(face_1: np.ndarray, face_2: np.ndarray):
    """
    Align face_1 to face_2 with a per-axis scale and translation.
    Input:
        face_1, face_2: np.ndarray [68, 2]; -1 marks a non-visible landmark.
    Return:
        face_1 after alignment, [68, 2]
    """
    face1_non_vis = face_1 == -1
    face2_vis = face_2 > 0
    # Fit only on landmarks visible in both faces.
    face_vis = (face_1 > 0) * face2_vis
    face_vis = face_vis[:, 0] * face_vis[:, 1]
    face_1_vis = face_1[face_vis]
    face_2_vis = face_2[face_vis]
    num_points = face_1_vis.shape[0]
    # Design matrices [coord, 1] for a per-axis least-squares fit.
    x_1, y_1 = face_1_vis.copy(), face_1_vis[:, 1:]
    x_1[:, 1] = 1
    y_1 = np.concatenate([y_1, np.ones(shape=(num_points, 1))], axis=1)
    x_2, y_2 = face_2_vis[:, 0], face_2_vis[:, 1]
    # Normal equations: [s, t] = (X^T X)^{-1} X^T y, solved per axis.
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2
    trans = np.array([
        s_x, 0, t_x,
        0, s_y, t_y,
        0, 0, 1
    ]).reshape(3, 3)
    # Apply the affine to homogeneous row vectors [x, y, 1].
    face_ret = np.concatenate([face_1, np.ones(shape=(68, 1))], axis=1)
    face_ret = (face_ret @ trans.T)[:, :2]
    face_ret[face1_non_vis] = -1
    return face_ret


def draw_facepose(canvas, lmks):
    # One white dot per visible landmark; coordinates are normalized to [0, 1],
    # so missing landmarks (negative values) scale below eps and are skipped.
    eps = 0.01
    H, W = canvas.shape[:2]
    for lmk in lmks:
        x, y = lmk
        x = int(x * W)
        y = int(y * H)
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
    return canvas


def align_body(body_1, body_2):
    # Fit body_1 to body_2 with one scale/translation pair per axis.
    cdd_1, cdd_2 = body_1["candidate"].copy(), body_2["candidate"].copy()
    sub_1, sub_2 = body_1["subset"][0].copy(), body_2["subset"][0].copy()
    sub1_vis = sub_1 >= 0
    sub2_vis = sub_2 >= 0
    vis = sub1_vis * sub2_vis
    num_points = np.sum(vis)
    if num_points < 3:
        # Too few shared joints for a stable fit; fall back to body_2.
        return body_2
    cdd1_vis, cdd2_vis = cdd_1[vis], cdd_2[vis]
    x_1, y_1 = cdd1_vis.copy(), cdd1_vis[:, 1:]
    x_1[:, 1] = 1
    y_1 = np.concatenate([y_1, np.ones(shape=(num_points, 1))], axis=1)
    x_2, y_2 = cdd2_vis[:, 0], cdd2_vis[:, 1]
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2
    cdd_1[:, 0] = cdd_1[:, 0] * s_x + t_x
    cdd_1[:, 1] = cdd_1[:, 1] * s_y + t_y
    # Joints not visible in both skeletons get a negative index so that
    # draw_bodypose skips them.
    subset_vis = sub_1 * (sub1_vis * 2 - 1) * (vis * 2 - 1)
    return dict(candidate=cdd_1, subset=subset_vis[np.newaxis, :])
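# The align_* helpers solve the per-axis fit via the normal equations
# (X^T X)^{-1} X^T y with X = [coords, 1]. A minimal equivalence check with
# np.linalg.lstsq, using made-up points (a sketch, not part of the pipeline):
#
#   x_src = np.array([0.2, 0.4, 0.6])
#   x_dst = np.array([0.3, 0.7, 1.1])  # = 2.0 * x_src - 0.1
#   X = np.stack([x_src, np.ones_like(x_src)], axis=1)
#   (s_x, t_x), *_ = np.linalg.lstsq(X, x_dst, rcond=None)
#   # s_x ~= 2.0, t_x ~= -0.1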
def draw_bodypose(canvas, candidate, subset):
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)
    stickwidth = 4

    # OpenPose 18-keypoint limb connections (1-indexed pairs).
    limbSeq = [
        [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
        [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
        [1, 16], [16, 18], [3, 17], [6, 18],
    ]
    colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85],
    ]
    # Draw limbs as filled ellipses between joint pairs.
    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly(
                (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
            )
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    # Dim the limbs, then draw the joints on top.
    canvas = (canvas * 0.6).astype(np.uint8)
    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (x, y), 4, colors[i], thickness=-1)
    return canvas


def draw_facebody(canvas, face, body):
    canvas = draw_bodypose(canvas, body["candidate"], body["subset"])
    canvas = draw_facepose(canvas, face)
    return canvas


def align_pose(source_pose_dict, target_pose_dict, height: int = 256, width: int = 256):
    # Align the target pose to the source pose with a per-axis scale and
    # translation fitted jointly on body and face keypoints, then render it.
    face_1 = target_pose_dict["faces"][0]
    face_2 = source_pose_dict["faces"][0]
    body_1 = target_pose_dict["bodies"]
    body_2 = source_pose_dict["bodies"]
    cdd_1, cdd_2 = body_1["candidate"].copy(), body_2["candidate"].copy()
    sub_1, sub_2 = body_1["subset"][0].copy(), body_2["subset"][0].copy()
    sub1_vis = sub_1 >= 0
    sub2_vis = sub_2 >= 0
    vis = sub1_vis * sub2_vis
    num_points = np.sum(vis)
    cdd1_vis, cdd2_vis = cdd_1[vis], cdd_2[vis]
    body_x_1, body_y_1 = cdd1_vis.copy(), cdd1_vis[:, 1:]
    body_x_1[:, 1] = 1
    body_y_1 = np.concatenate([body_y_1, np.ones(shape=(num_points, 1))], axis=1)
    body_x_2, body_y_2 = cdd2_vis[:, 0], cdd2_vis[:, 1]
    subset_vis = sub_1 * (sub1_vis * 2 - 1) * (vis * 2 - 1)

    face1_non_vis = face_1 == -1
    face2_vis = face_2 > 0
    face_vis = (face_1 > 0) * face2_vis
    face_vis = face_vis[:, 0] * face_vis[:, 1]
    face_1_vis = face_1[face_vis]
    face_2_vis = face_2[face_vis]
    num_face_points = face_1_vis.shape[0]
    face_x_1, face_y_1 = face_1_vis.copy(), face_1_vis[:, 1:]
    face_x_1[:, 1] = 1
    face_y_1 = np.concatenate([face_y_1, np.ones(shape=(num_face_points, 1))], axis=1)
    face_x_2, face_y_2 = face_2_vis[:, 0], face_2_vis[:, 1]

    # Stack body and face points into one least-squares system per axis.
    x_1 = np.concatenate([body_x_1, face_x_1], axis=0)
    y_1 = np.concatenate([body_y_1, face_y_1], axis=0)
    x_2 = np.concatenate([body_x_2, face_x_2], axis=0)
    y_2 = np.concatenate([body_y_2, face_y_2], axis=0)
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2

    cdd_1[:, 0] = cdd_1[:, 0] * s_x + t_x
    cdd_1[:, 1] = cdd_1[:, 1] * s_y + t_y
    face_ret = np.copy(face_1)
    face_ret[:, 0] = face_ret[:, 0] * s_x + t_x
    face_ret[:, 1] = face_ret[:, 1] * s_y + t_y
    face_ret[face1_non_vis] = -1
    body_ret = dict(candidate=cdd_1, subset=subset_vis[np.newaxis, :])
    canvas = draw_facebody(
        np.zeros(shape=(height, width, 3), dtype=np.uint8),
        face_ret,
        body_ret
    )
    return Image.fromarray(canvas)
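# Usage sketch for align_pose, assuming the DWPose dict layout used above
# ("faces" -> [1, 68, 2], "bodies" -> {"candidate", "subset"}); the .npy
# filename is hypothetical:
#
#   poses = np.load("clip0001.npy", allow_pickle=True)
#   source = poses[0]
#   frames = [align_pose(source, p, 256, 256) for p in poses]  # PIL images
#   save_videos_from_pil(frames, "aligned_pose.mp4", fps=25)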
class first_frame_extractor:
    def __init__(self, video_dir: str, output_dir: str):
        self.video_dir = video_dir
        self.output_dir = output_dir
        self.video_paths = lsdir(video_dir)

    def extract(self):
        # Grab the first frame of every video, resized to 256x256, via ffmpeg.
        for vp in tqdm(self.video_paths):
            out = os.path.join(self.output_dir, os.path.basename(vp).split('.')[0] + ".jpg")
            command = [
                'ffmpeg', '-hide_banner', '-y',
                '-i', vp,
                '-vf', 'scale=256:256',
                '-vframes', '1',
                '-ss', '00:00:00',
                out
            ]
            subprocess.run(command)


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    evalset = RecEvalDataset(
        video_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos",
        num_frames=96,
    )
    evalloader = DataLoader(
        evalset, 4, shuffle=False, drop_last=True, collate_fn=evalset.collate, num_workers=0
    )
    for data in evalloader:
        img, video, filename, fps = data["ref_img"], data["video"], data["filename"], data["fps"]
        print(img.shape)
        print(video.shape)
        print(filename)
        print(fps)
        break

    # from torch.utils.data import DataLoader
    # evalset = P2MEvalDataset(
    #     ref_img_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo",
    #     dwpose_dict_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/dwpose_dict",
    #     num_frames=96,
    #     random_dwpose=True
    # )
    # evalloader = DataLoader(
    #     evalset, 4, shuffle=False, drop_last=True, collate_fn=evalset.collate, num_workers=16
    # )
    # for data in evalloader:
    #     img, source_pose, driven_poses = data["ref_img"], data["source_pose"], data["driven_poses"]
    #     driven_pose_paths = data["driven_pose_paths"]
    #     print(img.shape)
    #     print(source_pose.shape)
    #     print(driven_poses.shape)
    #     print(driven_pose_paths)
    #     break

    # evalset = A2MEvalDataset(
    #     audio_emb_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/whisper_embs",
    #     dwpose_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromdwpose",
    #     ref_img_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo",
    #     num_frames=96,
    #     random_audio=True,
    #     random_dwpose=False,
    #     audio_dir="/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/audios",
    #     num_evals=4,
    #     audio_suffix="wav"
    # )
    # evalloader = DataLoader(
    #     evalset, 4, shuffle=False, drop_last=True, collate_fn=evalset.collate, num_workers=16
    # )
    # for data in evalloader:
    #     audio_emb, lmk, img = data["audio_emb"], data["dwpose"], data["ref_img"]
    #     audio_path = data["audio_path"]
    #     print(audio_emb.shape)
    #     print(lmk.shape)
    #     print(img.shape)
    #     print(audio_path)
    #     break

    # video_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos"
    # dwpose_video_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos_dwpose"
    # video_output_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo"
    # dwpose_video_output_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromdwpose"
    # dw_ex = first_frame_extractor(
    #     dwpose_video_dir,
    #     dwpose_video_output_dir
    # )
    # dw_ex.extract()
    # vex = first_frame_extractor(
    #     video_dir,
    #     video_output_dir
    # )
    # vex.extract()