# semo/pipeline/utils.py
import math
import os
import pickle
import random
import subprocess
import traceback
from pathlib import Path
from typing import Optional

import av
import cv2
import einops
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.io import read_video
from tqdm import tqdm
def vae_encode(vae, latents):
    """Encode images or videos into VAE latent space.

    Accepts either an image batch (N, C, H, W) or a video batch (N, T, C, H, W);
    video batches are flattened to frames before encoding and reshaped back.
    """
    latents_type = None
    if len(latents.shape) == 5:
        N, T, C, H, W = latents.shape
        latents_type = 'video'
        latents = einops.rearrange(latents, 'n t c h w -> (n t) c h w')
    else:
        N, C, H, W = latents.shape
        latents_type = 'image'
    with torch.no_grad():
        latents = vae.encode(latents).latent_dist
        latents = latents.sample()
        latents = latents * 0.18215  # Stable Diffusion VAE scaling factor
    if latents_type == 'video':
        latents = einops.rearrange(latents, '(n t) c h w -> n t c h w', n=N, t=T)
    return latents
def vae_decode(vae, latents):
    """Decode VAE latents back to pixel space; handles image and video batches."""
    latents_type = None
    if len(latents.shape) == 5:
        N, T, C, H, W = latents.shape
        latents_type = 'video'
        latents = einops.rearrange(latents, 'n t c h w -> (n t) c h w')
    else:
        N, C, H, W = latents.shape
        latents_type = 'image'
    latents = 1 / 0.18215 * latents  # undo the Stable Diffusion VAE scaling factor
    with torch.no_grad():
        latents = vae.decode(latents).sample  # (n t) c h w
    if latents_type == 'video':
        latents = einops.rearrange(latents, '(n t) c h w -> n t c h w', n=N, t=T)
    return latents
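# Minimal round-trip sketch (commented out, not executed on import). It assumes
# a diffusers AutoencoderKL; the checkpoint name is an assumption, not part of
# this file:
# from diffusers import AutoencoderKL
# vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()
# video = torch.randn(1, 8, 3, 256, 256)   # N, T, C, H, W in roughly [-1, 1]
# latents = vae_encode(vae, video)          # -> (1, 8, 4, 32, 32)
# recon = vae_decode(vae, latents)          # -> (1, 8, 3, 256, 256)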
def lsdir(dir):
    """Return the full paths of all entries directly under `dir`."""
    return [os.path.join(dir, name) for name in os.listdir(dir)]
class A2MEvalDataset(Dataset):
def __init__(
self,
audio_emb_dir:str,
dwpose_dir:str,
ref_img_dir:str,
num_frames:int,
random_audio:bool,
random_dwpose:bool,
audio_dir:Optional[str]=None,
num_evals:Optional[int] = None,
audio_suffix:str = 'wav'
):
super().__init__()
self.audio_emb_dir = audio_emb_dir
self.audio_emb_paths = lsdir(audio_emb_dir)
ref_img_paths = lsdir(ref_img_dir)
self.ref_img_paths = sorted(ref_img_paths)
self.dwpose_dir = dwpose_dir
self.dwpose_paths = lsdir(dwpose_dir)
        self.audio_dir = audio_dir
self.num_evals = num_evals
self.availables = min(len(self.ref_img_paths), len(self.audio_emb_paths))
self.num_frames = num_frames
self.random_audio = random_audio
self.random_dwpose = random_dwpose
self.audio_suffix = audio_suffix
self.transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables
    def __getitem__(self, index):
        # Skip unreadable samples by advancing to the next index until one loads.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % self.availables
        return sample
    def get(self, index: int):
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]
        if self.random_audio:
            audio_emb_path = random.choice(self.audio_emb_paths)
        else:
            audio_emb_path = os.path.join(self.audio_emb_dir, filename + '.pt')
        if not os.path.exists(audio_emb_path):
            raise ValueError(f"audio emb path does not exist: {audio_emb_path}")
        audio_filename = os.path.basename(audio_emb_path).split(".")[0]
        if self.random_dwpose:
            dwpose_path = random.choice(self.dwpose_paths)
        else:
            dwpose_path = os.path.join(self.dwpose_dir, filename + '.jpg')
        if not os.path.exists(dwpose_path):
            raise ValueError(f"dwpose path does not exist: {dwpose_path}")
        if self.audio_dir is not None:
            audio_path = os.path.join(self.audio_dir, audio_filename + '.' + self.audio_suffix)
            if not os.path.exists(audio_path):
                audio_path = None
        else:
            audio_path = None
        audio_emb = torch.load(audio_emb_path)
        if audio_emb.shape[0] < self.num_frames:
            raise ValueError(f"audio too short, {audio_emb.shape}")
        audio_emb = audio_emb[:self.num_frames]
        # Force 3 channels so the 3-channel Normalize always applies cleanly.
        ref_img = Image.open(ref_img_path).convert("RGB")
        ref_img = self.transforms(ref_img)
        dwpose = Image.open(dwpose_path).convert("RGB")
        dwpose = self.transforms(dwpose)
        return audio_emb, ref_img, dwpose, audio_path, ref_img_path
@staticmethod
def collate(batch):
audio_emb = torch.stack([b[0] for b in batch])
ref_img = torch.stack([b[1] for b in batch])
dwpose = torch.stack([b[2] for b in batch])
        audio_path = [b[3] for b in batch]
        refimg_path = [b[4] for b in batch]
return dict(
audio_emb = audio_emb,
ref_img = ref_img,
dwpose = dwpose,
audio_path = audio_path,
refimg_path = refimg_path
)
class P2MEvalDataset(Dataset):
def __init__(
self,
ref_img_dir:str,
dwpose_dict_dir:str,
num_frames:int,
random_dwpose:bool,
num_evals:Optional[int] = None,
):
super().__init__()
ref_img_paths = lsdir(ref_img_dir)
self.ref_img_paths = sorted(ref_img_paths)
self.dwpose_dict_dir = dwpose_dict_dir
self.dwpose_dict_paths = lsdir(dwpose_dict_dir)
self.num_evals = num_evals
self.availables = min(len(self.ref_img_paths), len(self.dwpose_dict_paths))
self.num_frames = num_frames
self.random_dwpose = random_dwpose
self.transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
self.w = 256
self.h = 256
    def __len__(self):
        return self.num_evals if self.num_evals is not None else self.availables
    def __getitem__(self, index):
        # Skip unreadable samples by advancing to the next index until one loads.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % self.availables
        return sample
    def get(self, index: int):
        ref_img_path = self.ref_img_paths[index]
        filename = os.path.basename(ref_img_path).split(".")[0]
        if self.random_dwpose:
            driven_dwpose_path = random.choice(self.dwpose_dict_paths)
        else:
            driven_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
        if not os.path.exists(driven_dwpose_path):
            raise ValueError(f"driven dwpose path does not exist: {driven_dwpose_path}")
        source_dwpose_path = os.path.join(self.dwpose_dict_dir, filename + '.npy')
        if not os.path.exists(source_dwpose_path):
            raise ValueError(f"source dwpose path does not exist: {source_dwpose_path}")
        driven_pose = np.load(driven_dwpose_path, allow_pickle=True)
        if driven_pose.shape[0] < self.num_frames:
            raise ValueError(f"driven pose too short. Total frames = {driven_pose.shape[0]}")
        driven_pose = driven_pose[:self.num_frames]
        source_pose = np.load(source_dwpose_path, allow_pickle=True)[0]
        ref_img = Image.open(ref_img_path).convert("RGB")
        ref_img = self.transforms(ref_img)
        driven_poses = []
        for pose in driven_pose:
            driven_pose_frame = align_pose(source_pose, pose, self.h, self.w)
            driven_pose_frame = self.transforms(driven_pose_frame)
            driven_poses.append(driven_pose_frame)
        driven_poses = torch.stack(driven_poses, dim=0)
        source_pose = draw_facebody(
            # np.zeros takes a single shape tuple; faces[0] selects the 68-point
            # face, matching how align_pose reads this dict.
            np.zeros((self.h, self.w, 3), dtype=np.uint8),
            source_pose["faces"][0],
            source_pose["bodies"]
        )
        source_pose = self.transforms(source_pose)
        return ref_img, source_pose, driven_poses, driven_dwpose_path
@staticmethod
def collate(batch):
ref_img = torch.stack([b[0] for b in batch])
source_pose = torch.stack([b[1] for b in batch])
driven_poses = torch.stack([b[2] for b in batch])
        driven_pose_paths = [b[3] for b in batch]
return dict(
ref_img = ref_img,
source_pose = source_pose,
driven_poses = driven_poses,
driven_pose_paths = driven_pose_paths
)
class RecEvalDataset(Dataset):
def __init__(
self,
video_dir:str,
num_frames:int,
num_evals:Optional[int] = None,
):
super().__init__()
        if video_dir.endswith(".pkl"):
            with open(video_dir, "rb") as f:
                self.video_paths = pickle.load(f)
else:
self.video_paths = lsdir(video_dir)
self.num_frames = num_frames
        self.transforms = transforms.Compose([  # no ToTensor: read_video already yields tensors
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
        self.length = min(num_evals, len(self.video_paths)) if num_evals is not None else len(self.video_paths)
self.w = 256
self.h = 256
def __len__(self):
return self.length
    def __getitem__(self, index):
        # Skip unreadable videos by advancing to the next index until one loads.
        while True:
            try:
                sample = self.get(index)
                break
            except Exception:
                traceback.print_exc()
                index = (index + 1) % len(self.video_paths)
        return sample
def get(self,index:int):
video_path = self.video_paths[index]
filename = os.path.basename(video_path).split(".")[0]
video, _, aux = read_video(video_path, pts_unit="sec", output_format="TCHW")
fps = aux["video_fps"]
        # Sample num_frames + 1 frames: the first one serves as the reference image.
        sample_frames = self.num_frames + 1
video_length = len(video)
clip_length = min(video_length, sample_frames)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)
random_idx = random.randint(0, video_length - 1)
random_frame = video[random_idx] / 255.0
video = video[batch_index] / 255.0
video = self.transforms(video)
random_frame = self.transforms(random_frame)
ref_img = video[0]
video = video[1:]
return ref_img, video, filename, fps, random_frame
@staticmethod
def collate(batch):
ref_img = torch.stack([b[0] for b in batch])
video = torch.stack([b[1] for b in batch])
        filename = [b[2] for b in batch]
        fps = [b[3] for b in batch]
random_frame = torch.stack([b[4] for b in batch])
return dict(
ref_img = ref_img,
video = video,
filename = filename,
fps = fps,
random_frame = random_frame
)
def read_frames(video_path):
    """Decode every frame of a video into a list of RGB PIL images."""
container = av.open(video_path)
video_stream = next(s for s in container.streams if s.type == "video")
frames = []
for packet in container.demux(video_stream):
for frame in packet.decode():
image = Image.frombytes(
"RGB",
(frame.width, frame.height),
frame.to_rgb().to_ndarray(),
)
frames.append(image)
return frames
def save_videos_from_pil(pil_images, path, fps=8):
    save_fmt = Path(path).suffix
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    width, height = pil_images[0].size
    if save_fmt == ".mp4":
        codec = "libx264"
        container = av.open(path, "w")
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        for pil_image in pil_images:
            av_frame = av.VideoFrame.from_image(pil_image)
            container.mux(stream.encode(av_frame))
        container.mux(stream.encode())  # flush any frames buffered in the encoder
        container.close()
    elif save_fmt == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),
            loop=0,
        )
    else:
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
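# Illustrative pairing of the two helpers above (paths are hypothetical):
# frames = read_frames("input.mp4")                 # list of PIL.Image frames
# save_videos_from_pil(frames, "output.gif", fps=8)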
def align_face(face_1: np.ndarray, face_2: np.ndarray):
    """
    Align face_1 to face_2.
    Input:
        face: np.ndarray [68,2], -1 is non visible
    Return:
        face_1 after align, [68,2]
    Fits a per-axis affine map x2 ≈ s_x * x1 + t_x (and likewise for y) over the
    jointly visible landmarks via the least-squares normal equations.
    """
    face1_non_vis = face_1 == -1
    face2_vis = face_2 > 0
    face_vis = (face_1 > 0) * face2_vis
    face_vis = face_vis[:, 0] * face_vis[:, 1]
    face_1_vis = face_1[face_vis]
    face_2_vis = face_2[face_vis]
    num_points = face_1_vis.shape[0]  # may be fewer than 68 visible landmarks
    # Design matrices [coord, 1] so the solve returns (scale, translation).
    x_1 = face_1_vis.copy()
    x_1[:, 1] = 1
    y_1 = np.concatenate([face_1_vis[:, 1:], np.ones(shape=(num_points, 1))], axis=1)
    x_2, y_2 = face_2_vis[:, 0], face_2_vis[:, 1]
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2
    trans = np.array([
        s_x, 0, t_x,
        0, s_y, t_y,
        0, 0, 1
    ]).reshape(3, 3)
    face_ret = np.concatenate([face_1, np.ones(shape=(68, 1))], axis=1)
    face_ret = (face_ret @ trans.T)[:, :2]  # row vectors, so multiply by trans.T
    face_ret[face1_non_vis] = -1
    return face_ret
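# Quick sanity sketch for align_face (hypothetical numbers): when face_2 is an
# exact per-axis affine image of face_1 and all landmarks are visible, the
# least-squares solve recovers it exactly:
# f1 = np.random.rand(68, 2)
# f2 = f1 * 2 + 0.1
# assert np.allclose(align_face(f1, f2), f2)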
def draw_facepose(canvas, lmks):
    """Draw face landmarks (normalized xy in [0, 1]) as white dots on canvas."""
eps = 0.01
H, W = canvas.shape[:2]
for lmk in lmks:
x, y = lmk
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
return canvas
def align_body(body_1, body_2):
    """Align body_1 to body_2 with a per-axis least-squares scale + translation."""
    cdd_1, cdd_2 = body_1["candidate"].copy(), body_2["candidate"].copy()
    sub_1, sub_2 = body_1["subset"][0].copy(), body_2["subset"][0].copy()
    sub1_vis = sub_1 >= 0
    sub2_vis = sub_2 >= 0
    vis = sub1_vis * sub2_vis
    num_points = np.sum(vis)
    if num_points < 3:
        # Too few shared keypoints to fit the transform reliably.
        return body_2
    cdd1_vis, cdd2_vis = cdd_1[vis], cdd_2[vis]
    # Design matrices [coord, 1] so the solve returns (scale, translation).
    x_1 = cdd1_vis.copy()
    x_1[:, 1] = 1
    y_1 = np.concatenate([cdd1_vis[:, 1:], np.ones(shape=(num_points, 1))], axis=1)
    x_2, y_2 = cdd2_vis[:, 0], cdd2_vis[:, 1]
    s_x, t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
    s_y, t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2
    cdd_1[:, 0] = cdd_1[:, 0] * s_x + t_x
    cdd_1[:, 1] = cdd_1[:, 1] * s_y + t_y
    # Flip the sign of subset entries not visible in both poses, marking them invisible.
    subset_vis = sub_1 * (sub1_vis * 2 - 1) * (vis * 2 - 1)
    return dict(candidate=cdd_1, subset=subset_vis[np.newaxis, :])
def draw_bodypose(canvas, candidate, subset):
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for i in range(17):
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
Y = candidate[index.astype(int), 0] * float(W)
X = candidate[index.astype(int), 1] * float(H)
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly(
(int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
)
cv2.fillConvexPoly(canvas, polygon, colors[i])
canvas = (canvas * 0.6).astype(np.uint8)
for i in range(18):
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
x = int(x * W)
y = int(y * H)
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
return canvas
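# Hypothetical usage sketch for draw_bodypose: candidate holds normalized (x, y)
# keypoints and subset holds per-person indices into candidate, with -1 marking
# missing joints (OpenPose-style layout inferred from this file):
# canvas = np.zeros((256, 256, 3), dtype=np.uint8)
# candidate = np.array([[0.5, 0.3], [0.5, 0.5]])        # e.g. nose, neck
# subset = np.array([[0, 1] + [-1] * 16], dtype=float)  # only the first two joints
# canvas = draw_bodypose(canvas, candidate, subset)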
def draw_facebody(canvas, face, body):
    """Render body limbs/joints and face landmarks onto the same canvas."""
canvas = draw_bodypose(canvas,body["candidate"], body["subset"])
canvas = draw_facepose(canvas,face)
return canvas
def align_pose(source_pose_dict, target_pose_dict, height: int = 256, width: int = 256):
    """Align the target pose onto the source pose using body and face landmarks
    jointly (per-axis least-squares scale + translation), then render it as a
    (height, width) PIL image."""
face_1 = target_pose_dict["faces"][0]
face_2 = source_pose_dict["faces"][0]
body_1 = target_pose_dict["bodies"]
body_2 = source_pose_dict["bodies"]
cdd_1,cdd_2 = body_1["candidate"].copy(), body_2["candidate"].copy()
sub_1,sub_2 = body_1["subset"][0].copy(), body_2["subset"][0].copy()
sub1_vis = sub_1 >= 0
sub2_vis = sub_2 >= 0
vis = sub1_vis * sub2_vis
num_points = np.sum(vis)
cdd1_vis, cdd2_vis = cdd_1[vis], cdd_2[vis]
body_x_1, body_y_1 = cdd1_vis.copy(), cdd1_vis[:,1:]
body_x_1[:,1] = 1
body_y_1 = np.concatenate([body_y_1, np.ones(shape=(num_points,1))], axis=1)
body_x_2, body_y_2 = cdd2_vis[:,0], cdd2_vis[:,1]
subset_vis = sub_1 * (sub1_vis * 2 - 1) * (vis * 2 - 1)
face1_non_vis = face_1 == -1
face2_vis = face_2 > 0
face_vis = (face_1 > 0) * face2_vis
face_vis = face_vis[:,0] * face_vis[:,1]
face_1_vis = face_1[face_vis]
face_2_vis = face_2[face_vis]
    face_x_1 = face_1_vis.copy()
    face_x_1[:, 1] = 1
    # Use the number of visible landmarks, which may be fewer than 68.
    face_y_1 = np.concatenate([face_1_vis[:, 1:], np.ones(shape=(face_1_vis.shape[0], 1))], axis=1)
    face_x_2, face_y_2 = face_2_vis[:, 0], face_2_vis[:, 1]
x_1 = np.concatenate([body_x_1, face_x_1], axis=0)
y_1 = np.concatenate([body_y_1, face_y_1], axis=0)
x_2 = np.concatenate([body_x_2, face_x_2], axis=0)
y_2 = np.concatenate([body_y_2, face_y_2], axis=0)
s_x,t_x = np.linalg.inv(x_1.T @ x_1) @ x_1.T @ x_2
s_y,t_y = np.linalg.inv(y_1.T @ y_1) @ y_1.T @ y_2
cdd_1[:,0] = cdd_1[:,0] * s_x + t_x
cdd_1[:,1] = cdd_1[:,1] * s_y + t_y
face_ret = np.copy(face_1)
face_ret[:,0] = face_ret[:,0] * s_x + t_x
face_ret[:,1] = face_ret[:,1] * s_y + t_y
face_ret[face1_non_vis] = -1
body_ret = dict(candidate=cdd_1,subset=subset_vis[np.newaxis,:])
pil = draw_facebody(
np.zeros(shape=(height, width, 3), dtype=np.uint8),
face_ret, body_ret
)
pil = Image.fromarray(pil)
return pil
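# Illustrative sketch of the pose-dict layout align_pose expects (inferred from
# this file; the numbers are hypothetical):
# cand = np.stack([np.linspace(0.2, 0.8, 18), np.linspace(0.1, 0.9, 18)], axis=1)
# pose = {
#     "faces": np.random.rand(1, 68, 2) * 0.3 + 0.4,   # one 68-point face, normalized xy
#     "bodies": {
#         "candidate": cand,                            # 18 keypoints, normalized xy
#         "subset": np.arange(18, dtype=float)[None],   # indices into candidate, -1 = missing
#     },
# }
# aligned = align_pose(pose, pose)  # -> 256x256 PIL.Image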
class first_frame_extractor:
    """Dump the first frame of every video in video_dir to output_dir as a 256x256 JPEG via ffmpeg."""
    def __init__(self, video_dir: str, output_dir: str):
self.video_dir = video_dir
self.output_dir = output_dir
self.video_paths = lsdir(video_dir)
def extract(self):
for vp in tqdm(self.video_paths):
out = os.path.join(self.output_dir, os.path.basename(vp).split('.')[0] + ".jpg")
command = [
'ffmpeg',
'-hide_banner',
'-y',
'-i', vp,
'-vf', 'scale=256:256',
'-vframes', '1',
'-ss', '00:00:00',
out
]
subprocess.run(command)
if __name__ == "__main__":
from torch.utils.data import DataLoader
evalset = RecEvalDataset(
video_dir= "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos",
num_frames=96,
)
    evalloader = DataLoader(
        evalset, 4, shuffle=False, drop_last=True, collate_fn=evalset.collate, num_workers=0
    )
for data in evalloader:
img, video, filename, fps = data["ref_img"], data["video"], data["filename"], data["fps"]
print(img.shape)
print(video.shape)
print(filename)
print(fps)
break
# from torch.utils.data import DataLoader
# evalset = P2MEvalDataset(
# ref_img_dir= "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo",
# dwpose_dict_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/dwpose_dict",
# num_frames=96,
# random_dwpose=True
# )
# evalloader = DataLoader(
# evalset, 4, shuffle=False,drop_last=True,collate_fn=evalset.collate,num_workers=16
# )
# for data in evalloader:
# img, source_pose, driven_poses = data["ref_img"], data["source_pose"], data["driven_poses"]
# driven_pose_paths = data["driven_pose_paths"]
# print(img.shape)
# print(source_pose.shape)
# print(driven_poses.shape)
# print(driven_pose_paths)
# break
    # evalset = A2MEvalDataset(
# audio_emb_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/whisper_embs",
# dwpose_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromdwpose",
# ref_img_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo",
# num_frames = 96,
# random_audio=True,
# random_dwpose=False,
# audio_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/audios",
# num_evals=4,
# audio_suffix="wav"
# )
# evalloader = DataLoader(
# evalset, 4, shuffle=False,drop_last=True,collate_fn=evalset.collate,num_workers=16
# )
# for data in evalloader:
# audio_emb, lmk, img = data["audio_emb"], data["dwpose"], data["ref_img"]
# audio_path = data["audio_path"]
# print(audio_emb.shape)
# print(lmk.shape)
# print(img.shape)
# print(audio_path)
# break
# video_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos"
# dwpose_video_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/videos_dwpose"
# video_output_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromvideo"
# dwpose_video_output_dir = "/mnt/pfs-gv8sxa/tts/dhg/zqy/data/FaceVid_240h/firstframes/fromdwpose"
# dw_ex = first_frame_extractor(
# dwpose_video_dir,
# dwpose_video_output_dir
# )
# dw_ex.extract()
# vex = first_frame_extractor(
# video_dir,
# video_output_dir
# )
# vex.extract()