MimicTalk: Mimicking a personalized and expressive 3D talking face in minutes (NeurIPS 2024) \
+ Homepage ")
+
+ sources = None
+ with gr.Row():
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="source_image"):
+ with gr.TabItem('Upload Training Video'):
+ with gr.Row():
+ src_video_name = gr.Video(label="Source video (required for training)", sources=sources, value="data/raw/videos/German_20s.mp4")
+ # src_video_name = gr.Image(label="Source video (required for training)", sources=sources, type="filepath", value="data/raw/videos/German_20s.mp4")
+ with gr.Tabs(elem_id="driven_audio"):
+ with gr.TabItem('Upload Driving Audio'):
+ with gr.Column(variant='panel'):
+ drv_audio_name = gr.Audio(label="Input audio (required for inference)", sources=sources, type="filepath", value="data/raw/examples/80_vs_60_10s.wav")
+ with gr.Tabs(elem_id="driven_style"):
+ with gr.TabItem('Upload Style Prompt'):
+ with gr.Column(variant='panel'):
+ drv_style_name = gr.Video(label="Driven Style (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
+ with gr.Tabs(elem_id="driven_pose"):
+ with gr.TabItem('Upload Driving Pose'):
+ with gr.Column(variant='panel'):
+ drv_pose_name = gr.Video(label="Driven Pose (optional for inference)", sources=sources, value="data/raw/videos/German_20s.mp4")
+ with gr.Tabs(elem_id="bg_image"):
+ with gr.TabItem('Upload Background Image'):
+ with gr.Row():
+ bg_image_name = gr.Image(label="Background image (optional for inference)", sources=sources, type="filepath", value=None)
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('General Settings'):
+ with gr.Column(variant='panel'):
+
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically")
+ min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum proportion of the output frame occupied by the face; prevents bad cases caused by a face that is too small.',)
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio-to-SECC sampling temperature',)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=3.0, step=0.025, label="talking style cfg_scale", value=1.5, info='higher -> the generated motion follows the talking style more closely',)
+ out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only the final output; concat_debug: the final output concatenated with internal features")
+ low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running on a long (minutes-long) audio.", value=False)
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
+
+ train_submit = gr.Button('Train', elem_id="train", variant='primary')
+ infer_submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Tabs(elem_id="genearted_video"):
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('Checkpoints'):
+ with gr.Column(variant='panel'):
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
+ # head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
+ torso_model_dir = gr.FileExplorer(glob="checkpoints_mimictalk/**/*.ckpt", value=torso_model_dir, file_count='single', label='mimictalk model ckpt path or directory')
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
+
+
+ fn = infer_obj.infer_once_args
+ if warpfn:
+ fn = warpfn(fn)
+ infer_submit.click(
+ fn=fn,
+ inputs=[
+ drv_audio_name,
+ drv_pose_name,
+ drv_style_name,
+ bg_image_name,
+ blink_mode,
+ temperature,
+ cfg_scale,
+ out_mode,
+ map_to_init_pose,
+ low_memory_usage,
+ hold_eye_opened,
+ audio2secc_dir,
+ # head_model_dir,
+ torso_model_dir,
+ min_face_area_percent,
+ ],
+ outputs=[
+ gen_video,
+ info_box,
+ ],
+ )
+
+ fn_train = train_obj.train_once_args
+
+ train_submit.click(
+ fn=fn_train,
+ inputs=[
+ src_video_name,
+
+ ],
+ outputs=[
+ # gen_video,
+ info_box,
+ torso_model_dir,
+ ],
+ )
+
+ print(sep_line)
+ print("Gradio page is constructed.")
+ print(sep_line)
+
+ return real3dportrait_interface
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s/model_ckpt_steps_10000.ckpt')
+ parser.add_argument("--port", type=int, default=None)
+ parser.add_argument("--server", type=str, default='127.0.0.1')
+ parser.add_argument("--share", action='store_true', dest='share', help='share srever to Internet')
+
+ args = parser.parse_args()
+ demo = mimictalk_demo(
+ audio2secc_dir=args.a2m_ckpt,
+ head_model_dir=args.head_ckpt,
+ torso_model_dir=args.torso_ckpt,
+ device='cuda:0',
+ warpfn=None,
+ )
+ demo.queue()
+ demo.launch(share=args.share, server_name=args.server, server_port=args.port)
diff --git a/inference/app_real3dportrait.py b/inference/app_real3dportrait.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d231a538dd1117f82ecc49bc5df92b7d6f2db81
--- /dev/null
+++ b/inference/app_real3dportrait.py
@@ -0,0 +1,247 @@
+import os, sys
+sys.path.append('./')
+import argparse
+import gradio as gr
+from inference.real3d_infer import GeneFace2Infer
+from utils.commons.hparams import hparams
+
+class Inferer(GeneFace2Infer):
+ def infer_once_args(self, *args, **kargs):
+ assert len(kargs) == 0
+ keys = [
+ 'src_image_name',
+ 'drv_audio_name',
+ 'drv_pose_name',
+ 'bg_image_name',
+ 'blink_mode',
+ 'temperature',
+ 'mouth_amp',
+ 'out_mode',
+ 'map_to_init_pose',
+ 'low_memory_usage',
+ 'hold_eye_opened',
+ 'a2m_ckpt',
+ 'head_ckpt',
+ 'torso_ckpt',
+ 'min_face_area_percent',
+ ]
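+ # NOTE: the order of this list must match the order of the `inputs=[...]` list passed to
+ # `submit.click(...)` below, since Gradio forwards the component values positionally.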
+ inp = {}
+ out_name = None
+ info = ""
+
+ try: # try to catch errors and jump to return
+ for key_index in range(len(keys)):
+ key = keys[key_index]
+ inp[key] = args[key_index]
+ if '_name' in key:
+ inp[key] = inp[key] if inp[key] is not None else ''
+
+ if inp['src_image_name'] == '':
+ info = "Input Error: Source image is REQUIRED!"
+ raise ValueError
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
+ raise ValueError
+
+
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
+ inp['drv_audio_name'] = inp['drv_pose_name']
+ print("No audio input, we use driving pose video for video driving")
+
+ if inp['drv_pose_name'] == '':
+ inp['drv_pose_name'] = 'static'
+
+ reload_flag = False
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
+ print("Changes of a2m_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['head_ckpt'] != self.head_model_dir:
+ print("Changes of head_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['torso_ckpt'] != self.torso_model_dir:
+ print("Changes of torso_ckpt detected, reloading model")
+ reload_flag = True
+
+ inp['out_name'] = ''
+ inp['seed'] = 42
+
+ print(f"infer inputs : {inp}")
+
+ try:
+ if reload_flag:
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Reload ERROR: {content}"
+ raise ValueError
+ try:
+ out_name = self.infer_once(inp)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Inference ERROR: {content}"
+ raise ValueError
+ except Exception as e:
+ if info == "": # unexpected errors
+ content = f"{e}"
+ info = f"WebUI ERROR: {content}"
+
+ # output part
+ if len(info) > 0: # there are errors
+ print(info)
+ info_gr = gr.update(visible=True, value=info)
+ else: # no errors
+ info_gr = gr.update(visible=False, value=info)
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
+ print(f"Successfully generated at {out_name}")
+ video_gr = gr.update(visible=True, value=out_name)
+ else:
+ print("Failed to generate")
+ video_gr = gr.update(visible=True, value=out_name)
+
+ return video_gr, info_gr
+
+def toggle_audio_file(choice):
+ if choice == False:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+ if path_of_ref_video is not None:
+ return gr.update(value=True)
+ else:
+ return gr.update(value=False)
+
+def real3dportrait_demo(
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ device = 'cuda',
+ warpfn = None,
+ ):
+
+ sep_line = "-" * 40
+
+ infer_obj = Inferer(
+ audio2secc_dir=audio2secc_dir,
+ head_model_dir=head_model_dir,
+ torso_model_dir=torso_model_dir,
+ device=device,
+ )
+
+ print(sep_line)
+ print("Model loading is finished.")
+ print(sep_line)
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
+ gr.Markdown("\
+
Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight)
\
+
Arxiv \
+
Homepage \
+
Github ")
+
+ sources = None
+ with gr.Row():
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="source_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
+ with gr.Tabs(elem_id="driven_audio"):
+ with gr.TabItem('Upload audio'):
+ with gr.Column(variant='panel'):
+ drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
+ with gr.Tabs(elem_id="driven_pose"):
+ with gr.TabItem('Upload video'):
+ with gr.Column(variant='panel'):
+ drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
+ with gr.Tabs(elem_id="bg_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('General Settings'):
+ with gr.Column(variant='panel'):
+
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically")
+ min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum proportion of the output frame occupied by the face; prevents bad cases caused by a face that is too small.',)
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio-to-SECC sampling temperature',)
+ mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> the mouth opens wider; default is 0.45',)
+ out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only the final output; concat_debug: the final output concatenated with internal features")
+ low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running on a long (minutes-long) audio.", value=False)
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
+
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Tabs(elem_id="genearted_video"):
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('Checkpoints'):
+ with gr.Column(variant='panel'):
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
+ head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
+ torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
+
+
+ fn = infer_obj.infer_once_args
+ if warpfn:
+ fn = warpfn(fn)
+ submit.click(
+ fn=fn,
+ inputs=[
+ src_image_name,
+ drv_audio_name,
+ drv_pose_name,
+ bg_image_name,
+ blink_mode,
+ temperature,
+ mouth_amp,
+ out_mode,
+ map_to_init_pose,
+ low_memory_usage,
+ hold_eye_opened,
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ min_face_area_percent,
+ ],
+ outputs=[
+ gen_video,
+ info_box,
+ ],
+ )
+
+ print(sep_line)
+ print("Gradio page is constructed.")
+ print(sep_line)
+
+ return real3dportrait_interface
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
+ parser.add_argument("--head_ckpt", type=str, default='')
+ parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
+ parser.add_argument("--port", type=int, default=None)
+ parser.add_argument("--server", type=str, default='127.0.0.1')
+ parser.add_argument("--share", action='store_true', dest='share', help='share srever to Internet')
+
+ args = parser.parse_args()
+ demo = real3dportrait_demo(
+ audio2secc_dir=args.a2m_ckpt,
+ head_model_dir=args.head_ckpt,
+ torso_model_dir=args.torso_ckpt,
+ device='cuda:0',
+ warpfn=None,
+ )
+ demo.queue()
+ demo.launch(share=args.share, server_name=args.server, server_port=args.port)
diff --git a/inference/edit_secc.py b/inference/edit_secc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1e602b389665c2710eb76e5ab1244030096db2
--- /dev/null
+++ b/inference/edit_secc.py
@@ -0,0 +1,147 @@
+import cv2
+import torch
+from utils.commons.image_utils import dilate, erode
+from sklearn.neighbors import NearestNeighbors
+import copy
+import numpy as np
+from utils.commons.meters import Timer
+
+def hold_eye_opened_for_secc(img):
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint8)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+ face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels
+ h,w = face_mask.shape
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_eye,2] coordinates of coarse eye pixels
+
+ opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+ opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [h,w] bool mask of the reference opened-eye region
+ coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)).transpose(1, 0) # [N_opened_eye,2] coordinates of opened-eye pixels
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [N_opened_eye, 1] distance to the nearest coarse-eye pixel
+ # print(dists.max())
+ non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # opened-eye pixels farther than this distance will be closed
+ non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
+ opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
+ opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink the mask by a few pixels at the face-eye boundary for smoothness
+
+ img[opened_eye_mask] = 0
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+# def hold_eye_opened_for_secc(img):
+# img = copy.copy(img)
+# eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+# eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
+# img[eye_mask] = -1
+# return img
+
+def blink_eye_for_secc(img, close_eye_percent=0.5):
+ """
+ secc_img: [3,h,w], tensor, -1~1
+ """
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint8)
+ assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
+ if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+ img = copy.deepcopy(img)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+ h,w = face_mask.shape
+
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
+ coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_eye,2] coordinates of coarse eye pixels
+ min_h = coarse_eye_xys[:, 0].min()
+ max_h = coarse_eye_xys[:, 0].max()
+ coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_left_eye,2] coordinates of left-eye pixels
+ left_min_w = coarse_left_eye_xys[:, 1].min()
+ left_max_w = coarse_left_eye_xys[:, 1].max()
+ coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_right_eye,2] coordinates of right-eye pixels
+ right_min_w = coarse_right_eye_xys[:, 1].min()
+ right_max_w = coarse_right_eye_xys[:, 1].max()
+
+ # Keep only the face pixels that actually need to be considered, to reduce the cost of the KNN search
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ more_room = 4 # too small a margin causes artifacts
+ left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+
+ around_eye_face_mask = face_mask & eye_prior_reigon
+ face_mask = around_eye_face_mask
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels around the eyes
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(face_xys) # [N_face, 1] distance to the nearest coarse-eye pixel
+ face_pixs = dists > 5 # only pixels farther than 5 from the nearest eye pixel are kept as face; a smaller threshold causes artifacts
+ face_pixs = face_pixs.reshape([-1])
+ face_xys_to_erode = face_xys[~face_pixs]
+ face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink the face mask by a few pixels at the face-eye boundary for smoothness
+ eye_mask = (~ face_mask) & eye_prior_reigon
+
+ h_grid = np.mgrid[0:h, 0:w][0]
+ eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
+ eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 0
+ eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
+
+ eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
+ eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
+ eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
+
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels around the eyes
+ eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_blink,2] coordinates of eye pixels to be closed
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
+ distances, indices = nbrs.kneighbors(eye_blink_xys)
+ bg_fg_xys = face_xys[indices[:, 0]]
+ img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+if __name__ == '__main__':
+ import imageio
+ import tqdm
+ img = cv2.imread("assets/cano_secc.png")
+ img = img / 127.5 - 1
+ img = torch.FloatTensor(img).permute(2, 0, 1)
+ fps = 25
+ writer = imageio.get_writer('demo_blink.mp4', fps=fps)
+
+ for i in tqdm.trange(33):
+ blink_percent = 0.03 * i
+ with Timer("Blink", True):
+ out_img = blink_eye_for_secc(img, blink_percent)
+ out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
+ writer.append_data(out_img)
+ writer.close()
\ No newline at end of file
diff --git a/inference/infer_utils.py b/inference/infer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d06c08d280dea47b3b622309f624fbde5224867
--- /dev/null
+++ b/inference/infer_utils.py
@@ -0,0 +1,154 @@
+import os
+import torch
+import torch.nn.functional as F
+import librosa
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+from scipy.spatial.transform import Rotation
+
+
+def load_img_to_512_hwc_array(img_name):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = cv2.resize(img, (512, 512))
+ return img
+
+def load_img_to_normalized_512_bchw_tensor(img_name):
+ img = load_img_to_512_hwc_array(img_name)
+ img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
+ return img
+
+def mirror_index(index, len_seq):
+ """
+ get the mirror index when indexing a sequence and the index is larger than len_seq
+ args:
+ index: int
+ len_seq: int
+ return:
+ mirror_index: int
+ """
+ turn = index // len_seq
+ res = index % len_seq
+ if turn % 2 == 0:
+ return res # forward indexing
+ else:
+ return len_seq - res - 1 # reverse indexing
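+ # Illustrative example: for a 3-frame sequence (len_seq=3), successive indices ping-pong
+ # through the sequence, e.g. [mirror_index(i, 3) for i in range(8)] -> [0, 1, 2, 2, 1, 0, 0, 1],
+ # so a driving sequence shorter than the audio is traversed forward, then backward, and so on.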
+
+def smooth_camera_sequence(camera, kernel_size=7):
+ """
+ smooth the camera trajectory (i.e., rotation & translation)...
+ args:
+ camera: [N, 25] or [N, 16]. np.ndarray
+ kernel_size: int
+ return:
+ smoothed_camera: [N, 25] or [N, 16]. np.ndarray
+ """
+ # poses: [N, 25], numpy array
+ N = camera.shape[0]
+ K = kernel_size // 2
+ poses = camera[:, :16].reshape([-1, 4, 4]).copy()
+ trans = poses[:, :3, 3].copy() # [N, 3]
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
+
+ for i in range(N):
+ start = max(0, i - K)
+ end = min(N, i + K + 1)
+ poses[i, :3, 3] = trans[start:end].mean(0)
+ try:
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
+ except:
+ if i == 0:
+ poses[i, :3, :3] = rots[i]
+ else:
+ poses[i, :3, :3] = poses[i-1, :3, :3]
+ poses = poses.reshape([-1, 16])
+ camera[:, :16] = poses
+ return camera
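+ # Usage sketch (mirrors how the inference scripts call it): smooth a [T, 25] camera array
+ # before rendering; with kernel_size=7 each pose is averaged with its 3 neighbours on both sides
+ # (translations via the arithmetic mean, rotations via scipy's Rotation.mean), damping per-frame
+ # jitter in the estimated head pose.
+ # camera = np.concatenate([c2w.reshape([-1, 16]), intrinsics.reshape([-1, 9])], axis=-1)
+ # camera = smooth_camera_sequence(camera, kernel_size=7)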
+
+def smooth_features_xd(in_tensor, kernel_size=7):
+ """
+ smooth the feature maps
+ args:
+ in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ kernel_size: int
+ return:
+ out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ """
+ t = in_tensor.shape[0]
+ ndim = in_tensor.ndim
+ pad = (kernel_size- 1)//2
+ in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
+ if ndim == 2: # tc
+ _,c = in_tensor.shape
+ in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 4: # tchw
+ _,c,h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 5: # tcchw, like deformation
+ _,c1,c2, h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ else: raise NotImplementedError()
+ avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
+ out_tensor = F.conv1d(in_tensor, avg_kernel)
+ if ndim == 2: # tc
+ return out_tensor.reshape([c,t]).permute(1,0)
+ elif ndim == 4: # tchw
+ return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
+ elif ndim == 5: # tcchw, like deformation
+ return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
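+ # Usage sketch (mirrors how the inference scripts call it): the same moving-average trick works
+ # for any per-frame feature by folding the non-time dims into the channel axis and running a 1D
+ # average pool over time, e.g. smoothing [T, 68, 2] torso keypoints with a 7-frame window:
+ # drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=7).reshape([-1, 68, 2])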
+
+
+def extract_audio_motion_from_ref_video(video_name):
+ def save_wav16k(audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
+ wav16k_name = audio_name[:-4] + '_16k.wav'
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name}"
+ os.system(extract_wav_cmd)
+ print(f"Extracted 16 kHz wav file from {audio_name} to {wav16k_name}.")
+ return wav16k_name
+
+ def get_f0( wav16k_name):
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
+ wav, mel = extract_mel_from_fname(wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ f0 = torch.tensor(f0)
+ return f0
+
+ def get_hubert(wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ hubert = torch.tensor(hubert)
+ return hubert
+
+ def get_exp(video_name):
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
+ exp = torch.tensor(drv_motion_coeff_dict['exp'])
+ return exp
+
+ wav16k_name = save_wav16k(video_name)
+ f0 = get_f0(wav16k_name)
+ hubert = get_hubert(wav16k_name)
+ os.system(f"rm {wav16k_name}")
+ exp = get_exp(video_name)
+ target_length = min(len(exp), len(hubert)//2, len(f0)//2)
+ exp = exp[:target_length]
+ f0 = f0[:target_length*2]
+ hubert = hubert[:target_length*2]
+ return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
+
+
+if __name__ == '__main__':
+ extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
\ No newline at end of file
diff --git a/inference/mimictalk_infer.py b/inference/mimictalk_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5203b5952b23d5d6b1d166a12a1bf8b68e69f234
--- /dev/null
+++ b/inference/mimictalk_infer.py
@@ -0,0 +1,357 @@
+"""
+Inference of the person-specific model obtained from inference/train_mimictalk_on_a_video.py
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# import librosa
+import random
+import time
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+
+# common utils
+from utils.commons.hparams import hparams, set_hparams
+from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
+from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
+# 3DMM-related utils
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from data_util.face3d_helper import Face3DHelper
+from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
+from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+from deep_3drecon.secc_renderer import SECC_Renderer
+from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
+# Face Parsing
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
+# other inference utils
+from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
+from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
+from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
+from inference.real3d_infer import GeneFace2Infer
+
+
+class AdaptGeneFace2Infer(GeneFace2Infer):
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, **kwargs):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+ self.audio2secc_dir = audio2secc_dir
+ self.head_model_dir = head_model_dir
+ self.torso_model_dir = torso_model_dir
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir)
+ self.audio2secc_model.to(device).eval()
+ self.secc2video_model.to(device).eval()
+ self.seg_model = MediapipeSegmenter()
+ self.secc_renderer = SECC_Renderer(512)
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
+ # self.camera_selector = KNearestCameraSelector()
+
+ def load_secc2video(self, head_model_dir, torso_model_dir):
+ if torso_model_dir != '':
+ config_dir = torso_model_dir if os.path.isdir(torso_model_dir) else os.path.dirname(torso_model_dir)
+ set_hparams(f"{config_dir}/config.yaml", print_hparams=False)
+ hparams['htbsr_head_threshold'] = 1.0
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ ckpt = get_last_checkpoint(torso_model_dir)[0]
+ lora_args = ckpt.get("lora_args", None)
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
+ model = OSAvatarSECC_Img2plane_Torso(self.secc2video_hparams, lora_args=lora_args)
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=True)
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
+ load_ckpt(self.learnable_triplane, f"{torso_model_dir}", model_name='learnable_triplane', strict=True)
+ model._last_cano_planes = self.learnable_triplane
+ if head_model_dir != '':
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
+ else:
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
+ ckpt = get_last_checkpoint(head_model_dir)[0]
+ lora_args = ckpt.get("lora_args", None)
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ model = OSAvatarSECC_Img2plane(self.secc2video_hparams, lora_args=lora_args)
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=True)
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
+ model._last_cano_planes = self.learnable_triplane
+ load_ckpt(model._last_cano_planes, f"{head_model_dir}", model_name='learnable_triplane', strict=True)
+ self.person_ds = ckpt['person_ds']
+ return model
+
+ def prepare_batch_from_inp(self, inp):
+ """
+ :param inp: dict of inference options, e.g. {'drv_audio_name': (str), 'drv_pose_name': (str), ...}
+ :return: a dict that contains the condition feature of NeRF
+ """
+ sample = {}
+ # Process Driving Motion
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ self.save_wav16k(inp['drv_audio_name'])
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ hubert = self.get_hubert(self.wav16k_name)
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ hubert = self.get_mfcc(self.wav16k_name) / 100
+
+ f0 = self.get_f0(self.wav16k_name)
+ if f0.shape[0] > len(hubert):
+ f0 = f0[:len(hubert)]
+ else:
+ num_to_pad = len(hubert) - len(f0)
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
+ t_x = hubert.shape[0]
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
+ sample.update({
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
+ 'x_mask': x_mask.cuda(),
+ 'y_mask': y_mask.cuda(),
+ })
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
+ sample['audio'] = sample['hubert']
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+
+ # Face Parsing
+ sample['ref_gt_img'] = self.person_ds['gt_img'].cuda()
+ img = self.person_ds['gt_img'].reshape([3, 512, 512]).permute(1, 2, 0)
+ img = (img + 1) * 127.5
+ img = np.ascontiguousarray(img.int().numpy()).astype(np.uint8)
+ segmap = self.seg_model._cal_seg_map(img)
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ if inp['bg_image_name'] == '':
+ bg_img = extract_background([img], [segmap], 'knn')
+ else:
+ bg_img = cv2.imread(inp['bg_image_name'])
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
+ bg_img = cv2.resize(bg_img, (512,512))
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ # 3DMM, get identity code and camera pose
+ image_name = f"data/raw/val_imgs/{self.person_ds['video_id']}_img.png"
+ os.makedirs(os.path.dirname(image_name), exist_ok=True)
+ cv2.imwrite(image_name, img[:,:,::-1])
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
+ coeff_dict['id'] = self.person_ds['id'].reshape([1,80]).numpy()
+
+ assert coeff_dict is not None
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
+ sample['id'] = src_id.repeat([t_x//2,1])
+
+ # get the src_kp for torso model
+ sample['src_kp'] = self.person_ds['src_kp'].cuda().reshape([1, 68, 3]).repeat([t_x//2,1,1])[..., :2] # [B, 68, 2]
+
+ # get camera pose file
+ random.seed(time.time())
+ if inp['drv_pose_name'] in ['nearest', 'topk']:
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
+ coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
+ coeff_names = coeff_names[0] # squeeze
+ if inp['drv_pose_name'] == 'nearest':
+ inp['drv_pose_name'] = coeff_names[0]
+ else:
+ inp['drv_pose_name'] = random.choice(coeff_names)
+ # inp['drv_pose_name'] = coeff_names[0]
+ elif inp['drv_pose_name'] == 'random':
+ inp['drv_pose_name'] = self.camera_selector.random_select()
+ else:
+ inp['drv_pose_name'] = inp['drv_pose_name']
+
+ print(f"| To extract pose from {inp['drv_pose_name']}")
+
+ # extract camera pose
+ if inp['drv_pose_name'] == 'static':
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
+ else: # from file
+ if inp['drv_pose_name'].endswith('.mp4'):
+ # extract coeff from video
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
+ else:
+ # load from npy
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
+ len_pose = len(eulers)
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
+ sample['euler'] = eulers[index_lst]
+ sample['trans'] = trans[index_lst]
+
+ # fix the z axis
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
+
+ # mapping to the init pose
+ if inp.get("map_to_init_pose", 'False') == 'True':
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
+ sample['euler'] = sample['euler'] + diff_euler
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
+ sample['trans'] = sample['trans'] + diff_trans
+
+ # prepare camera
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ # smooth camera
+ camera_smo_ksize = 7
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
+ camera = torch.tensor(camera).cuda().float()
+ sample['camera'] = camera
+
+ return sample
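+ # Summary of the batch assembled above (contents depend on the driving input type):
+ # - audio-driven motion: 'hubert'/'audio', 'f0', 'x_mask', 'y_mask', 'blink', 'eye_amp'
+ # - reference appearance: 'ref_gt_img', 'ref_head_img', 'ref_torso_img', 'segmap', 'bg_img'
+ # - identity / torso keypoints: 'id', 'src_kp'
+ # - head pose: 'euler', 'trans', and the smoothed 25-dim 'camera' (16 c2w + 9 intrinsics)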
+
+ @torch.no_grad()
+ def forward_secc2video(self, batch, inp=None):
+ num_frames = len(batch['drv_secc'])
+ camera = batch['camera']
+ src_kps = batch['src_kp']
+ drv_kps = batch['drv_kp']
+ cano_secc_color = batch['cano_secc']
+ src_secc_color = batch['src_secc']
+ drv_secc_colors = batch['drv_secc']
+ ref_img_gt = batch['ref_gt_img']
+ ref_img_head = batch['ref_head_img']
+ ref_torso_img = batch['ref_torso_img']
+ bg_img = batch['bg_img']
+ segmap = batch['segmap']
+
+ # smooth torso drv_kp
+ torso_smo_ksize = 7
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
+
+ # forward renderer
+ img_raw_lst = []
+ img_lst = []
+ depth_img_lst = []
+ with torch.no_grad():
+ for i in tqdm.trange(num_frames, desc="MimicTalk is rendering frames"):
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+
+ ########################################################################################################
+ ### Only the line below differs from real3d_infer: the cano_triplane comes from the cached learnable_triplane instead of the plane predicted from the image ###
+ ########################################################################################################
+ gen_output = self.secc2video_model.forward(img=None, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+
+ img_lst.append(gen_output['image'])
+ img_raw_lst.append(gen_output['image_raw'])
+ depth_img_lst.append(gen_output['image_depth'])
+
+ # save demo video
+ depth_imgs = torch.cat(depth_img_lst)
+ imgs = torch.cat(img_lst)
+ imgs_raw = torch.cat(img_raw_lst)
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
+
+ if inp['out_mode'] == 'concat_debug':
+ secc_img = secc_img.cpu()
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
+
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
+ depth_img = depth_img.repeat([1,3,1,1])
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
+ depth_img = depth_img * 2 - 1
+ depth_img = depth_img.clamp(-1,1)
+
+ secc_img = secc_img / 127.5 - 1
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
+ elif inp['out_mode'] == 'final':
+ imgs = imgs.cpu()
+ elif inp['out_mode'] == 'debug':
+ raise NotImplementedError("to do: save separate videos")
+ imgs = imgs.clamp(-1,1)
+
+ import imageio
+ import uuid
+ debug_name = f'{uuid.uuid1()}.mp4'
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
+ writer.append_data(out_imgs[i])
+ writer.close()
+
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
+ try:
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
+ except: pass
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ # os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -y -v quiet -shortest {out_fname}")
+ cmd = f"/usr/bin/ffmpeg -i {debug_name} -i {self.wav16k_name} -y -r 25 -ar 16000 -c:v copy -c:a libmp3lame -pix_fmt yuv420p -b:v 2000k -strict experimental -shortest {out_fname}"
+ os.system(cmd)
+ os.system(f"rm {debug_name}")
+ else:
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
+ if ret != 0: # failed to extract audio from drv_audio_name; output the video without an audio track
+ os.system(f"mv {debug_name} {out_fname}")
+ print(f"Saved at {out_fname}")
+ return out_fname
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240112_icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ parser.add_argument("--torso_ckpt", default='checkpoints_mimictalk/German_20s')
+ parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
+ parser.add_argument("--drv_aud", default='data/raw/examples/80_vs_60_10s.wav')
+ parser.add_argument("--drv_pose", default='data/raw/examples/German_20s.mp4') # nearest | topk | random | static | vid_name
+ parser.add_argument("--drv_style", default='data/raw/examples/angry.mp4') # nearest | topk | random | static | vid_name
+ parser.add_argument("--blink_mode", default='period') # none | period
+ parser.add_argument("--temperature", default=0.3, type=float) # nearest | random
+ parser.add_argument("--denoising_steps", default=20, type=int) # nearest | random
+ parser.add_argument("--cfg_scale", default=1.5, type=float) # nearest | random
+ parser.add_argument("--out_name", default='') # nearest | random
+ parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
+ parser.add_argument("--hold_eye_opened", default='False') # concat_debug | debug | final
+ parser.add_argument("--map_to_init_pose", default='True') # concat_debug | debug | final
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
+
+ args = parser.parse_args()
+
+ inp = {
+ 'a2m_ckpt': args.a2m_ckpt,
+ 'head_ckpt': args.head_ckpt,
+ 'torso_ckpt': args.torso_ckpt,
+ 'bg_image_name': args.bg_img,
+ 'drv_audio_name': args.drv_aud,
+ 'drv_pose_name': args.drv_pose,
+ 'drv_talking_style_name': args.drv_style,
+ 'blink_mode': args.blink_mode,
+ 'temperature': args.temperature,
+ 'denoising_steps': args.denoising_steps,
+ 'cfg_scale': args.cfg_scale,
+ 'out_name': args.out_name,
+ 'out_mode': args.out_mode,
+ 'map_to_init_pose': args.map_to_init_pose,
+ 'hold_eye_opened': args.hold_eye_opened,
+ 'seed': args.seed,
+ }
+ AdaptGeneFace2Infer.example_run(inp)
\ No newline at end of file
diff --git a/inference/real3d_infer.py b/inference/real3d_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..27f25c51616de30eda430973cabd131d974b7f75
--- /dev/null
+++ b/inference/real3d_infer.py
@@ -0,0 +1,667 @@
+import os
+import torch
+import torch.nn.functional as F
+import torchshow as ts
+import librosa
+import random
+import time
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+import math
+
+# common utils
+from utils.commons.hparams import hparams, set_hparams
+from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
+from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
+# 3DMM-related utils
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from data_util.face3d_helper import Face3DHelper
+from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
+from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+from data_gen.utils.process_image.extract_lm2d import extract_lms_mediapipe_job
+from data_gen.utils.process_image.fit_3dmm_landmark import index_lm68_from_lm468
+from deep_3drecon.secc_renderer import SECC_Renderer
+from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
+# Face Parsing
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
+# other inference utils
+from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
+from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
+from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
+
+
+def read_first_frame_from_a_video(vid_name):
+ frames = []
+ cap = cv2.VideoCapture(vid_name)
+ ret, frame_bgr = cap.read()
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ return frame_rgb
+
+def analyze_weights_img(gen_output):
+ img_raw = gen_output['image_raw']
+ mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
+ mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
+ mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
+ mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
+ mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
+
+ img_raw_005_to_03 = img_raw.clone()
+ img_raw_005_to_03[~mask_005_to_03] = -1
+ img_raw_005_to_05 = img_raw.clone()
+ img_raw_005_to_05[~mask_005_to_05] = -1
+ img_raw_005_to_07 = img_raw.clone()
+ img_raw_005_to_07[~mask_005_to_07] = -1
+ img_raw_005_to_09 = img_raw.clone()
+ img_raw_005_to_09[~mask_005_to_09] = -1
+ img_raw_005_to_10 = img_raw.clone()
+ img_raw_005_to_10[~mask_005_to_10] = -1
+ ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
+
+
+def cal_face_area_percent(img_name):
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
+ lm478 = extract_lms_mediapipe_job(img) / 512
+ min_x = lm478[:,0].min()
+ max_x = lm478[:,0].max()
+ min_y = lm478[:,1].min()
+ max_y = lm478[:,1].max()
+ area = (max_x - min_x) * (max_y - min_y)
+ return area
+
+def crop_img_on_face_area_percent(img_name, out_name='temp/cropped_src_img.png', min_face_area_percent=0.2):
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except: pass
+ face_area_percent = cal_face_area_percent(img_name)
+ if face_area_percent >= min_face_area_percent:
+ print(f"face area percent {face_area_percent} larger than threshold {min_face_area_percent}, directly use the input image...")
+ cmd = f"cp {img_name} {out_name}"
+ os.system(cmd)
+ return out_name
+ else:
+ print(f"face area percent {face_area_percent} smaller than threshold {min_face_area_percent}, crop the input image...")
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
+ lm478 = extract_lms_mediapipe_job(img).astype(int)
+ min_x = lm478[:,0].min()
+ max_x = lm478[:,0].max()
+ min_y = lm478[:,1].min()
+ max_y = lm478[:,1].max()
+ face_area = (max_x - min_x) * (max_y - min_y)
+ target_total_area = face_area / min_face_area_percent
+ target_hw = int(target_total_area**0.5)
+ center_x, center_y = (min_x+max_x)/2, (min_y+max_y)/2
+ shrink_pixels = 2 * max(-(center_x - target_hw/2), center_x + target_hw/2 - 512, -(center_y - target_hw/2), center_y + target_hw/2-512)
+ shrink_pixels = max(0, shrink_pixels)
+ hw = math.floor(target_hw - shrink_pixels)
+ new_min_x = int(center_x - hw/2)
+ new_max_x = int(center_x + hw/2)
+ new_min_y = int(center_y - hw/2)
+ new_max_y = int(center_y + hw/2)
+
+ img = img[new_min_y:new_max_y, new_min_x:new_max_x]
+ img = cv2.resize(img, (512, 512))
+ cv2.imwrite(out_name, img[:,:,::-1])
+ return out_name
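+ # Worked example with illustrative numbers: if the detected face box covers 5% of the 512x512
+ # frame (~13107 px^2) and min_face_area_percent=0.2, the target crop area is 13107 / 0.2 ~ 65536
+ # px^2, i.e. a ~256x256 crop centred on the face (assuming it fits inside the frame), which is
+ # then resized back to 512x512 so the face fills roughly 20% of the output.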
+
+
+class GeneFace2Infer:
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
+ self.audio2secc_model.to(device).eval()
+ self.secc2video_model.to(device).eval()
+ self.seg_model = MediapipeSegmenter()
+ self.secc_renderer = SECC_Renderer(512)
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
+ self.camera_selector = KNearestCameraSelector()
+
+ def load_audio2secc(self, audio2secc_dir):
+ config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
+ set_hparams(f"{config_name}", print_hparams=False)
+ self.audio2secc_dir = audio2secc_dir
+ self.audio2secc_hparams = copy.deepcopy(hparams)
+ from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
+ from modules.audio2motion.cfm.icl_audio2motion_model import InContextAudio2MotionModel
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ audio_in_dim = 1024
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ audio_in_dim = 13
+
+ if 'icl' in hparams['task_cls']:
+ self.use_icl_audio2motion = True
+ model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
+ else:
+ self.use_icl_audio2motion = False
+ if hparams.get("use_pitch", False) is True:
+ model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
+ else:
+ model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
+ load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
+ return model
+
+ def load_secc2video(self, head_model_dir, torso_model_dir, inp):
+ if inp is None:
+ inp = {}
+ self.head_model_dir = head_model_dir
+ self.torso_model_dir = torso_model_dir
+ if torso_model_dir != '':
+ if torso_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
+ model = OSAvatarSECC_Img2plane_Torso()
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
+ if head_model_dir != '':
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
+ else:
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
+ if head_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ model = OSAvatarSECC_Img2plane()
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
+ return model
+
+ def infer_once(self, inp):
+ self.inp = inp
+ samples = self.prepare_batch_from_inp(inp)
+ seed = inp['seed'] if inp['seed'] is not None else int(time.time())
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ out_name = self.forward_system(samples, inp)
+ return out_name
+
+ def prepare_batch_from_inp(self, inp):
+ """
+ :param inp: dict of inference options, e.g. {'src_image_name': (str), 'drv_audio_name': (str), ...}
+ :return: a dict that contains the condition feature of NeRF
+ """
+ cropped_name = 'temp/cropped_src_img_512.png'
+ crop_img_on_face_area_percent(inp['src_image_name'], cropped_name, min_face_area_percent=inp['min_face_area_percent'])
+ inp['src_image_name'] = cropped_name
+
+ sample = {}
+ # Process Driving Motion
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ self.save_wav16k(inp['drv_audio_name'])
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ hubert = self.get_hubert(self.wav16k_name)
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ hubert = self.get_mfcc(self.wav16k_name) / 100
+
+ f0 = self.get_f0(self.wav16k_name)
+ if f0.shape[0] > len(hubert):
+ f0 = f0[:len(hubert)]
+ else:
+ num_to_pad = len(hubert) - len(f0)
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
+ t_x = hubert.shape[0]
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
+ sample.update({
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
+ 'x_mask': x_mask.cuda(),
+ 'y_mask': y_mask.cuda(),
+ })
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
+ sample['audio'] = sample['hubert']
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
+ sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+ else:
+            raise ValueError(f"Unsupported driving input: {inp['drv_audio_name']} (expected .wav/.mp3/.mp4/.npy)")
+
+ # Face Parsing
+ image_name = inp['src_image_name']
+ if image_name.endswith(".mp4"):
+ img = read_first_frame_from_a_video(image_name)
+ image_name = inp['src_image_name'] = image_name[:-4] + '.png'
+ cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+ sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
+ img = load_img_to_512_hwc_array(image_name)
+ segmap = self.seg_model._cal_seg_map(img)
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ if inp['bg_image_name'] == '':
+ bg_img = extract_background([img], [segmap], 'lama')
+ else:
+ bg_img = cv2.imread(inp['bg_image_name'])
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
+ bg_img = cv2.resize(bg_img, (512,512))
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ # 3DMM, get identity code and camera pose
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
+ assert coeff_dict is not None
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
+ sample['id'] = src_id.repeat([t_x//2,1])
+
+ # get the src_kp for torso model
+ src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
+ src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
+ sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
+
+ # get camera pose file
+ # random.seed(time.time())
+ if inp['drv_pose_name'] in ['nearest', 'topk']:
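+            # 'nearest'/'topk': query the camera selector for training poses whose cameras are closest to the source image's estimated pose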
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([1,3])})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ camera = np.concatenate([c2w.reshape([1,16]), intrinsics.reshape([1,9])], axis=-1)
+ coeff_names, distance_matrix = self.camera_selector.find_k_nearest(camera, k=100)
+ coeff_names = coeff_names[0] # squeeze
+ if inp['drv_pose_name'] == 'nearest':
+ inp['drv_pose_name'] = coeff_names[0]
+ else:
+ inp['drv_pose_name'] = random.choice(coeff_names)
+ # inp['drv_pose_name'] = coeff_names[0]
+ elif inp['drv_pose_name'] == 'random':
+ inp['drv_pose_name'] = self.camera_selector.random_select()
+ else:
+ inp['drv_pose_name'] = inp['drv_pose_name']
+
+        print(f"| Extracting pose from {inp['drv_pose_name']}")
+
+ # extract camera pose
+ if inp['drv_pose_name'] == 'static':
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
+ else: # from file
+ if inp['drv_pose_name'].endswith('.mp4'):
+ # extract coeff from video
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
+ else:
+ # load from npy
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
+ len_pose = len(eulers)
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
+ sample['euler'] = eulers[index_lst]
+ sample['trans'] = trans[index_lst]
+
+ # fix the z axis
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
+
+ # mapping to the init pose
+ if inp.get("map_to_init_pose", 'False') == 'True':
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
+ sample['euler'] = sample['euler'] + diff_euler
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
+ sample['trans'] = sample['trans'] + diff_trans
+
+ # prepare camera
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ # smooth camera
+ camera_smo_ksize = 7
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
+ camera = torch.tensor(camera).cuda().float()
+ sample['camera'] = camera
+
+ return sample
+
+ @torch.no_grad()
+ def get_hubert(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
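+        # pad the HuBERT features so that the number of frames is a multiple of 8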
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ return hubert
+
+ def get_mfcc(self, wav16k_name):
+ from utils.audio import librosa_wav2mfcc
+ hparams['fft_size'] = 1200
+ hparams['win_size'] = 1200
+ hparams['hop_size'] = 480
+ hparams['audio_num_mel_bins'] = 80
+ hparams['fmin'] = 80
+ hparams['fmax'] = 12000
+ hparams['audio_sample_rate'] = 24000
+ mfcc = librosa_wav2mfcc(wav16k_name,
+ fft_size=hparams['fft_size'],
+ hop_size=hparams['hop_size'],
+ win_length=hparams['win_size'],
+ num_mels=hparams['audio_num_mel_bins'],
+ fmin=hparams['fmin'],
+ fmax=hparams['fmax'],
+ sample_rate=hparams['audio_sample_rate'],
+ center=True)
+ mfcc = np.array(mfcc).reshape([-1, 13])
+ len_mel = mfcc.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
+ return mfcc
+
+ @torch.no_grad()
+ def forward_audio2secc(self, batch, inp=None):
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ from inference.infer_utils import extract_audio_motion_from_ref_video
+ if self.use_icl_audio2motion:
+ self.audio2secc_model.empty_context() # make this function reloadable
+ if self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith(".mp4"):
+ ref_exp, ref_hubert, ref_f0 = extract_audio_motion_from_ref_video(inp['drv_talking_style_name'])
+ self.audio2secc_model.add_sample_to_context(ref_exp, ref_hubert, ref_f0)
+ elif self.use_icl_audio2motion and inp['drv_talking_style_name'].endswith((".png",'.jpg')):
+ style_coeff_dict = fit_3dmm_for_a_image(inp['drv_talking_style_name'])
+ ref_exp = torch.tensor(style_coeff_dict['exp']).reshape([1,1,64]).cuda()
+ self.audio2secc_model.add_sample_to_context(ref_exp.repeat([1, 100, 1]), hubert=None, f0=None)
+ else:
+                print("| WARNING: no reference talking style assigned, skipping style conditioning...")
+ # audio-to-exp
+ ret = {}
+ # pred = self.audio2secc_model.forward(batch, ret=ret,train=False, ,)
+ pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'], denoising_steps=inp['denoising_steps'], cond_scale=inp['cfg_scale'])
+
+ print("| audio-to-motion finished")
+ if pred.shape[-1] == 144:
+ id = ret['pred'][0][:,:80]
+ exp = ret['pred'][0][:,80:]
+ else:
+ id = batch['id']
+ exp = ret['pred'][0]
+ if len(id) < len(exp): # happens when use ICL
+ id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
+ batch['id'] = id
+ batch['exp'] = exp
+ else:
+ drv_motion_coeff_dict = self.drv_motion_coeff_dict
+ batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
+
+ batch['id'] = batch['id'][:-4]
+ batch['exp'] = batch['exp'][:-4]
+ batch['euler'] = batch['euler'][:-4]
+ batch['trans'] = batch['trans'][:-4]
+ batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
+ if self.use_icl_audio2motion:
+ self.audio2secc_model.empty_context()
+ return batch
+
+ @torch.no_grad()
+ def get_driving_motion(self, id, exp, euler, trans, batch, inp):
+ zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
+ zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
+ # render the secc given the id,exp
+ with torch.no_grad():
+ chunk_size = 50
+ drv_secc_color_lst = []
+ num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
+ for i in tqdm.trange(num_iters, desc="rendering drv secc"):
+ torch.cuda.empty_cache()
+ face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
+ drv_secc_color_lst.append(drv_secc_color.cpu())
+ drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
+ _, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
+ _, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
+ batch['drv_secc'] = drv_secc_colors.cuda()
+ batch['src_secc'] = src_secc_color.cuda()
+ batch['cano_secc'] = cano_secc_color.cuda()
+
+ # blinking secc
+ if inp['blink_mode'] == 'period':
+ period = 5 # second
+
+ if inp['hold_eye_opened'] == 'True':
+ for i in tqdm.trange(len(drv_secc_colors),desc="opening eye for secc"):
+ batch['drv_secc'][i] = hold_eye_opened_for_secc(batch['drv_secc'][i])
+
+ for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
+ if i % (25*period) == 0:
+ blink_dur_frames = random.randint(8, 12)
+ for offset in range(blink_dur_frames):
+ j = offset + i
+ if j >= len(drv_secc_colors)-1: break
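+                        # blink_percent_fn is a downward parabola over the blink duration: 0 at offset=0 and offset=T, peaking at 1.0 (eye fully closed) at the midpoint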
+ def blink_percent_fn(t, T):
+ return -4/T**2 * t**2 + 4/T * t
+ blink_percent = blink_percent_fn(offset, blink_dur_frames)
+ secc = batch['drv_secc'][j]
+ out_secc = blink_eye_for_secc(secc, blink_percent)
+ out_secc = out_secc.cuda()
+ batch['drv_secc'][j] = out_secc
+
+ # get the drv_kp for torso model, using the transformed trajectory
+ drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
+
+ drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
+ batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
+ return batch
+
+ @torch.no_grad()
+ def forward_secc2video(self, batch, inp=None):
+ num_frames = len(batch['drv_secc'])
+ camera = batch['camera']
+ src_kps = batch['src_kp']
+ drv_kps = batch['drv_kp']
+ cano_secc_color = batch['cano_secc']
+ src_secc_color = batch['src_secc']
+ drv_secc_colors = batch['drv_secc']
+ ref_img_gt = batch['ref_gt_img']
+ ref_img_head = batch['ref_head_img']
+ ref_torso_img = batch['ref_torso_img']
+ bg_img = batch['bg_img']
+ segmap = batch['segmap']
+
+ # smooth torso drv_kp
+ torso_smo_ksize = 7
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
+
+ # forward renderer
+ img_raw_lst = []
+ img_lst = []
+ depth_img_lst = []
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(inp['fp16']):
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
+ 'kp_s': kp_src, 'kp_d': kp_drv,
+ 'ref_cameras': camera[i:i+1],
+ }
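+                    # cache the img2plane backbone on the first frame and reuse it for all later frames, since the reference image stays fixed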
+ if i == 0:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
+ else:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ img_lst.append(gen_output['image'])
+ img_raw_lst.append(gen_output['image_raw'])
+ depth_img_lst.append(gen_output['image_depth'])
+
+ # save demo video
+ depth_imgs = torch.cat(depth_img_lst)
+ imgs = torch.cat(img_lst)
+ imgs_raw = torch.cat(img_raw_lst)
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
+
+ if inp['out_mode'] == 'concat_debug':
+ secc_img = secc_img.cpu()
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
+
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
+ depth_img = depth_img.repeat([1,3,1,1])
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
+ depth_img = depth_img * 2 - 1
+ depth_img = depth_img.clamp(-1,1)
+
+ secc_img = secc_img / 127.5 - 1
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
+ elif inp['out_mode'] == 'final':
+ imgs = imgs.cpu()
+ elif inp['out_mode'] == 'debug':
+ raise NotImplementedError("to do: save separate videos")
+ imgs = imgs.clamp(-1,1)
+
+ import imageio
+ debug_name = 'demo.mp4'
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
+
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
+ writer.append_data(out_imgs[i])
+ writer.close()
+
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
+ try:
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
+ except: pass
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ # cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -shortest {out_fname}"
+ cmd = f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}"
+ print(cmd)
+ os.system(cmd)
+ os.system(f"rm {debug_name}")
+ os.system(f"rm {self.wav16k_name}")
+ else:
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
+        if ret != 0: # failed to extract an audio track from drv_audio_name, so output the video without an audio track
+ os.system(f"mv {debug_name} {out_fname}")
+ print(f"Saved at {out_fname}")
+ return out_fname
+
+ @torch.no_grad()
+ def forward_system(self, batch, inp):
+ self.forward_audio2secc(batch, inp)
+ out_fname = self.forward_secc2video(batch, inp)
+ return out_fname
+
+ @classmethod
+ def example_run(cls, inp=None):
+ inp_tmp = {
+ 'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
+ 'src_image_name': 'data/raw/val_imgs/Macron.png'
+ }
+ if inp is not None:
+ inp_tmp.update(inp)
+ inp = inp_tmp
+
+ infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
+ infer_instance.infer_once(inp)
+
+ ##############
+ # IO-related
+ ##############
+ def save_wav16k(self, audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+        assert audio_name.endswith(supported_types), f"Currently only {','.join(supported_types)} are supported as the audio source!"
+ import uuid
+ wav16k_name = audio_name[:-4] + f'{uuid.uuid1()}_16k.wav'
+ self.wav16k_name = wav16k_name
+        extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name}"
+ # extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -y {wav16k_name} -y"
+ print(extract_wav_cmd)
+ os.system(extract_wav_cmd)
+        print(f"Extracted 16kHz wav file from {audio_name} to {wav16k_name}.")
+
+ def get_f0(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
+        wav, mel = extract_mel_from_fname(wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ return f0
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--a2m_ckpt", default='checkpoints/240112_audio2secc/icl_audio2secc_vox2_cmlr') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae') # checkpoints/0727_audio2secc/audio2secc_withlm2d100_randomframe
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ # parser.add_argument("--head_ckpt", default='checkpoints/240210_os_secc2plane/secc2plane_trigridv2_blink0.3_pertubeNone') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ # parser.add_argument("--torso_ckpt", default='')
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV1_MulMaskFalse')
+ parser.add_argument("--torso_ckpt", default='checkpoints/240211_robust_secc2plane_torso/secc2plane_torso_orig_fuseV3_MulMaskTrue')
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240209_robust_secc2plane_torso/secc2plane_torso_orig_fuseV2_MulMaskTrue')
+ # parser.add_argument("--src_img", default='data/raw/val_imgs/Macron_img.png')
+ # parser.add_argument("--src_img", default='gf2_iclr_test_data/cross_imgs/Trump.png')
+ parser.add_argument("--src_img", default='data/raw/val_imgs/mercy.png')
+ parser.add_argument("--bg_img", default='') # data/raw/val_imgs/bg3.png
+ parser.add_argument("--drv_aud", default='data/raw/val_wavs/yiwise.wav')
+ parser.add_argument("--drv_pose", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
+ # parser.add_argument("--drv_pose", default='nearest') # nearest | topk | random | static | vid_name
+    parser.add_argument("--drv_style", default='') # style reference: a talking video (.mp4) or an image (.png/.jpg); empty to disable
+ # parser.add_argument("--drv_style", default='infer_out/240319_tta/trump.mp4') # nearest | topk | random | static | vid_name
+ parser.add_argument("--blink_mode", default='period') # none | period
+    parser.add_argument("--temperature", default=0.2, type=float) # sampling temperature of the audio-to-motion model
+    parser.add_argument("--denoising_steps", default=20, type=int) # denoising steps of the audio-to-motion model
+    parser.add_argument("--cfg_scale", default=2.5, type=float) # classifier-free guidance scale towards the talking style
+ parser.add_argument("--mouth_amp", default=0.4, type=float) # scale of predicted mouth, enabled in audio-driven
+    parser.add_argument("--min_face_area_percent", default=0.2, type=float) # minimum face area percentage in the cropped source image
+    parser.add_argument("--head_torso_threshold", default=0.5, type=float, help="0.1~1.0; if the hair looks semi-transparent, decrease this value so that low-weight hair is clamped to weight=1.0; if a fluorescent ghost appears outside the head, also decrease it. The best value is case-by-case for NeRFs with different hyperparameters.")
+    # parser.add_argument("--head_torso_threshold", default=None, type=float, help="0.1~1.0; if the hair looks semi-transparent, decrease this value so that low-weight hair is clamped to weight=1.0; if a fluorescent ghost appears outside the head, also decrease it. The best value is case-by-case for NeRFs with different hyperparameters.")
+    parser.add_argument("--out_name", default='') # output video path; auto-named under infer_out/tmp/ if empty
+ parser.add_argument("--out_mode", default='concat_debug') # concat_debug | debug | final
+    parser.add_argument("--hold_eye_opened", default='False') # 'True' | 'False'
+    parser.add_argument("--map_to_init_pose", default='True') # 'True' | 'False'
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
+ parser.add_argument("--fp16", action='store_true')
+
+ args = parser.parse_args()
+
+ inp = {
+ 'a2m_ckpt': args.a2m_ckpt,
+ 'head_ckpt': args.head_ckpt,
+ 'torso_ckpt': args.torso_ckpt,
+ 'src_image_name': args.src_img,
+ 'bg_image_name': args.bg_img,
+ 'drv_audio_name': args.drv_aud,
+ 'drv_pose_name': args.drv_pose,
+ 'drv_talking_style_name': args.drv_style,
+ 'blink_mode': args.blink_mode,
+ 'temperature': args.temperature,
+ 'mouth_amp': args.mouth_amp,
+ 'out_name': args.out_name,
+ 'out_mode': args.out_mode,
+ 'map_to_init_pose': args.map_to_init_pose,
+ 'hold_eye_opened': args.hold_eye_opened,
+ 'head_torso_threshold': args.head_torso_threshold,
+ 'min_face_area_percent': args.min_face_area_percent,
+ 'denoising_steps': args.denoising_steps,
+ 'cfg_scale': args.cfg_scale,
+ 'seed': args.seed,
+        'fp16': args.fp16, # with the current ckpt, fp16 causes NaN; we traced it to a single NaN produced by a LayerNorm in the i2p model. Training with fp16 as well might fix this.
+ }
+
+ GeneFace2Infer.example_run(inp)
\ No newline at end of file
diff --git a/inference/real3dportrait_demo.ipynb b/inference/real3dportrait_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..03a4b31c363b71c90efc16b1abde1cc13d59263e
--- /dev/null
+++ b/inference/real3dportrait_demo.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+        "Open In Colab"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QS04K9oO21AW"
+ },
+ "source": [
+ "Check GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1ESQRDb-yVUG"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "y-ctmIvu3Ei8"
+ },
+ "source": [
+ "Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gXu76wdDgaxo"
+ },
+ "outputs": [],
+ "source": [
+ "# install pytorch3d, about 15s\n",
+ "import os\n",
+ "import sys\n",
+ "import torch\n",
+ "need_pytorch3d=False\n",
+ "try:\n",
+ " import pytorch3d\n",
+ "except ModuleNotFoundError:\n",
+ " need_pytorch3d=True\n",
+ "if need_pytorch3d:\n",
+ " if torch.__version__.startswith(\"2.1.\") and sys.platform.startswith(\"linux\"):\n",
+ " # We try to install PyTorch3D via a released wheel.\n",
+ " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
+ " version_str=\"\".join([\n",
+ " f\"py3{sys.version_info.minor}_cu\",\n",
+ " torch.version.cuda.replace(\".\",\"\"),\n",
+ " f\"_pyt{pyt_version_str}\"\n",
+ " ])\n",
+ " !pip install fvcore iopath\n",
+ " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
+ " else:\n",
+ " # We try to install PyTorch3D from source.\n",
+ " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DuUynxmotG_-"
+ },
+ "outputs": [],
+ "source": [
+ "# install dependencies, about 5~10 min\n",
+ "!pip install tensorboard==2.13.0 tensorboardX==2.6.1\n",
+ "!pip install pyspy==0.1.1\n",
+ "!pip install protobuf==3.20.3\n",
+ "!pip install scipy==1.9.1\n",
+ "!pip install kornia==0.5.0\n",
+ "!pip install trimesh==3.22.0\n",
+ "!pip install einops==0.6.1 torchshow==0.5.1\n",
+ "!pip install imageio==2.31.1 imageio-ffmpeg==0.4.8\n",
+ "!pip install scikit-learn==1.3.0 scikit-image==0.21.0\n",
+ "!pip install av==10.0.0 lpips==0.1.4\n",
+ "!pip install timm==0.9.2 librosa==0.9.2\n",
+ "!pip install openmim==0.3.9\n",
+ "!mim install mmcv==2.1.0 # use mim to speed up installation for mmcv\n",
+ "!pip install transformers==4.33.2\n",
+ "!pip install pretrainedmodels==0.7.4\n",
+ "!pip install ninja==1.11.1\n",
+ "!pip install faiss-cpu==1.7.4\n",
+ "!pip install praat-parselmouth==0.4.3 moviepy==1.0.3\n",
+ "!pip install mediapipe==0.10.7\n",
+ "!pip install --upgrade attr\n",
+ "!pip install beartype==0.16.4 gateloop_transformer==0.4.0\n",
+ "!pip install torchode==0.2.0 torchdiffeq==0.2.3\n",
+ "!pip install hydra-core==1.3.2 pandas==2.1.3\n",
+ "!pip install pytorch_lightning==2.1.2\n",
+ "!pip install httpx==0.23.3\n",
+ "!pip install gradio==4.16.0\n",
+ "!pip install gdown\n",
+ "!pip install pyloudnorm webrtcvad pyworld==0.2.1rc0 pypinyin==0.42.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0GLEV0HVu8rj"
+ },
+ "outputs": [],
+ "source": [
+ "# RESTART kernel to make sure runtime is correct if you meet runtime errors\n",
+ "# import os\n",
+ "# os.kill(os.getpid(), 9)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5UfKHKrH6kcq"
+ },
+ "source": [
+ "Clone code and download checkpoints"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-gfRsd9DwIgl"
+ },
+ "outputs": [],
+ "source": [
+ "# clone Real3DPortrait repo from github\n",
+ "!git clone https://github.com/yerfor/Real3DPortrait\n",
+ "%cd Real3DPortrait"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yju8dQY7x5OS"
+ },
+ "outputs": [],
+ "source": [
+ "# download pretrained ckpts & third-party ckpts from google drive, about 1 min\n",
+ "!pip install --upgrade --no-cache-dir gdown\n",
+ "%cd deep_3drecon/BFM\n",
+ "!gdown https://drive.google.com/uc?id=1SPM3IHsyNAaVMwqZZGV6QVaV7I2Hly0v\n",
+ "!gdown https://drive.google.com/uc?id=1MSldX9UChKEb3AXLVTPzZQcsbGD4VmGF\n",
+ "!gdown https://drive.google.com/uc?id=180ciTvm16peWrcpl4DOekT9eUQ-lJfMU\n",
+ "!gdown https://drive.google.com/uc?id=1KX9MyGueFB3M-X0Ss152x_johyTXHTfU\n",
+ "!gdown https://drive.google.com/uc?id=19-NyZn_I0_mkF-F5GPyFMwQJ_-WecZIL\n",
+ "!gdown https://drive.google.com/uc?id=11ouQ7Wr2I-JKStp2Fd1afedmWeuifhof\n",
+ "!gdown https://drive.google.com/uc?id=18ICIvQoKX-7feYWP61RbpppzDuYTptCq\n",
+ "!gdown https://drive.google.com/uc?id=1VktuY46m0v_n_d4nvOupauJkK4LF6mHE\n",
+ "%cd ../..\n",
+ "\n",
+ "%cd checkpoints\n",
+ "!gdown https://drive.google.com/uc?id=1gz8A6xestHp__GbZT5qozb43YaybRJhZ\n",
+ "!gdown https://drive.google.com/uc?id=1gSUIw2AkkKnlLJnNfS2FCqtaVw9tw3QF\n",
+ "!unzip 240210_real3dportrait_orig.zip\n",
+ "!unzip pretrained_ckpts.zip\n",
+ "!ls\n",
+ "%cd ..\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LHzLro206pnA"
+ },
+ "source": [
+ "Inference sample"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!python inference/real3d_infer.py -h"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2aCDwxNivQoS"
+ },
+ "outputs": [],
+ "source": [
+ "# sample inference, about 3 min\n",
+ "!python inference/real3d_infer.py \\\n",
+ "--src_img data/raw/examples/Macron.png \\\n",
+ "--drv_aud data/raw/examples/Obama_5s.wav \\\n",
+ "--drv_pose data/raw/examples/May_5s.mp4 \\\n",
+ "--bg_img data/raw/examples/bg.png \\\n",
+ "--out_name output.mp4 \\\n",
+ "--out_mode concat_debug \\\n",
+ "--low_memory_usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XL0c54l19mBG"
+ },
+ "source": [
+ "Display output video"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6olmWwZP9Icj"
+ },
+ "outputs": [],
+ "source": [
+ "# borrow code from makeittalk\n",
+ "from IPython.display import HTML\n",
+ "from base64 import b64encode\n",
+ "import os, sys\n",
+ "import glob\n",
+ "\n",
+ "mp4_name = './output.mp4'\n",
+ "\n",
+ "mp4 = open('{}'.format(mp4_name),'rb').read()\n",
+ "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ "\n",
+ "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
+ "display(HTML(\"\"\"\n",
+    "  <video controls>\n",
+    "    <source src=\"%s\" type=\"video/mp4\">\n",
+    "  </video>\n",
+    "  \"\"\" % data_url))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "WebUI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "!python inference/app_real3dportrait.py --share"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "authorship_tag": "ABX9TyPu++zOlOS4yKF4xn4FHGtZ",
+ "gpuType": "T4",
+ "include_colab_link": true,
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/inference/train_mimictalk_on_a_video.py b/inference/train_mimictalk_on_a_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..98096580e968b61fa72c72df0919832d1b333fc4
--- /dev/null
+++ b/inference/train_mimictalk_on_a_video.py
@@ -0,0 +1,608 @@
+"""
+Overfit the one-shot talking-head model (os_secc2plane or os_secc2plane_torso) to a single speaker (one photo or one video), achieving an effect similar to GeneFace++.
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import librosa
+import random
+import time
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+import glob
+import imageio
+# common utils
+from utils.commons.hparams import hparams, set_hparams
+from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
+from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
+# 3DMM-related utils
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from data_util.face3d_helper import Face3DHelper
+from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
+from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image
+from deep_3drecon.secc_renderer import SECC_Renderer
+from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
+from data_gen.runs.binarizer_nerf import get_lip_rect
+# Face Parsing
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
+# other inference utils
+from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
+from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
+from inference.edit_secc import blink_eye_for_secc, hold_eye_opened_for_secc
+from modules.commons.loralib.utils import mark_only_lora_as_trainable
+from utils.nn.model_utils import num_params
+import lpips
+from utils.commons.meters import AvgrageMeter
+meter = AvgrageMeter()
+from torch.utils.tensorboard import SummaryWriter
+class LoRATrainer(nn.Module):
+ def __init__(self, inp):
+ super().__init__()
+ self.inp = inp
+ self.lora_args = {'lora_mode': inp['lora_mode'], 'lora_r': inp['lora_r']}
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ head_model_dir = inp['head_ckpt']
+ torso_model_dir = inp['torso_ckpt']
+ model_dir = torso_model_dir if torso_model_dir != '' else head_model_dir
+ cmd = f"cp {os.path.join(model_dir, 'config.yaml')} {self.inp['work_dir']}"
+ print(cmd)
+ os.system(cmd)
+ with open(os.path.join(self.inp['work_dir'], 'config.yaml'), "a") as f:
+ f.write(f"\nlora_r: {inp['lora_r']}")
+ f.write(f"\nlora_mode: {inp['lora_mode']}")
+ f.write(f"\n")
+ self.secc2video_model = self.load_secc2video(model_dir)
+ self.secc2video_model.to(device).eval()
+ self.seg_model = MediapipeSegmenter()
+ self.secc_renderer = SECC_Renderer(512)
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
+ # self.camera_selector = KNearestCameraSelector()
+ self.load_training_data(inp)
+ def load_secc2video(self, model_dir):
+ inp = self.inp
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane, OSAvatarSECC_Img2plane_Torso
+ hp = set_hparams(f"{model_dir}/config.yaml", print_hparams=False, global_hparams=True)
+ hp['htbsr_head_threshold'] = 1.0
+ self.neural_rendering_resolution = hp['neural_rendering_resolution']
+ if 'torso' in hp['task_cls'].lower():
+ self.torso_mode = True
+ model = OSAvatarSECC_Img2plane_Torso(hp=hp, lora_args=self.lora_args)
+ else:
+ self.torso_mode = False
+ model = OSAvatarSECC_Img2plane(hp=hp, lora_args=self.lora_args)
+ mark_only_lora_as_trainable(model, bias='none')
+ lora_ckpt_path = os.path.join(inp['work_dir'], 'checkpoint.ckpt')
+ if os.path.exists(lora_ckpt_path):
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, model.triplane_hid_dim*model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
+ model._last_cano_planes = self.learnable_triplane
+ load_ckpt(model, lora_ckpt_path, model_name='model', strict=False)
+ else:
+ load_ckpt(model, f"{model_dir}", model_name='model', strict=False)
+
+ num_params(model)
+ self.model = model
+ return model
+ def load_training_data(self, inp):
+ video_id = inp['video_id']
+ if video_id.endswith((".mp4", ".png", ".jpg", ".jpeg")):
+ # If input video is not GeneFace training videos, convert it into GeneFace convention
+ video_id_ = video_id
+ video_id = os.path.basename(video_id)[:-4]
+ inp['video_id'] = video_id
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
+ if not os.path.exists(target_video_path):
+ print(f"| Copying video to {target_video_path}")
+ os.makedirs(os.path.dirname(target_video_path), exist_ok=True)
+ cmd = f"ffmpeg -i {video_id_} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y {target_video_path}"
+ print(f"| {cmd}")
+ os.system(cmd)
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
+            print(f"| Copying source video into work dir: {self.inp['work_dir']}")
+ os.system(f"cp {target_video_path} {self.inp['work_dir']}")
+ # check head_img path
+ head_img_pattern = f'data/processed/videos/{video_id}/head_imgs/*.png'
+ head_img_names = sorted(glob.glob(head_img_pattern))
+ if len(head_img_names) == 0:
+ # extract head_imgs
+ head_img_dir = os.path.dirname(head_img_pattern)
+ print(f"| Pre-extracted head_imgs not found, try to extract and save to {head_img_dir}, this may take a while...")
+ gt_img_dir = f"data/processed/videos/{video_id}/gt_imgs"
+ os.makedirs(gt_img_dir, exist_ok=True)
+ target_video_path = f'data/raw/videos/{video_id}.mp4'
+ cmd = f"ffmpeg -i {target_video_path} -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -y {gt_img_dir}/%08d.jpg"
+ print(f"| {cmd}")
+ os.system(cmd)
+ # extract image, segmap, and background
+ cmd = f"python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir={target_video_path}"
+ print(f"| {cmd}")
+ os.system(cmd)
+ print("| Head images Extracted!")
+ num_samples = len(head_img_names)
+ npy_name = f"data/processed/videos/{video_id}/coeff_fit_mp_for_lora.npy"
+ if os.path.exists(npy_name):
+ coeff_dict = np.load(npy_name, allow_pickle=True).tolist()
+ else:
+ print(f"| Pre-extracted 3DMM coefficient not found, try to extract and save to {npy_name}, this may take a while...")
+ coeff_dict = fit_3dmm_for_a_video(f'data/raw/videos/{video_id}.mp4', save=False)
+ os.makedirs(os.path.dirname(npy_name), exist_ok=True)
+ np.save(npy_name, coeff_dict)
+ ids = convert_to_tensor(coeff_dict['id']).reshape([-1,80]).cuda()
+ exps = convert_to_tensor(coeff_dict['exp']).reshape([-1,64]).cuda()
+ eulers = convert_to_tensor(coeff_dict['euler']).reshape([-1,3]).cuda()
+ trans = convert_to_tensor(coeff_dict['trans']).reshape([-1,3]).cuda()
+ WH = 512 # now we only support 512x512
+ lm2ds = WH * self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cpu().numpy()
+ lip_rects = [get_lip_rect(lm2ds[i], WH, WH) for i in range(len(lm2ds))]
+ kps = self.face3d_helper.reconstruct_lm2d(ids, exps, eulers, trans).cuda()
+ kps = (kps-0.5) / 0.5 # rescale to -1~1
+ kps = torch.cat([kps, torch.zeros([*kps.shape[:-1], 1]).cuda()], dim=-1)
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': torch.tensor(coeff_dict['euler']).reshape([-1,3]), 'trans': torch.tensor(coeff_dict['trans']).reshape([-1,3])})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ cameras = torch.tensor(np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)).cuda()
+ camera_smo_ksize = 7
+ cameras = smooth_camera_sequence(cameras.cpu().numpy(), kernel_size=camera_smo_ksize) # [T, 25]
+ cameras = torch.tensor(cameras).cuda()
+ zero_eulers = eulers * 0
+ zero_trans = trans * 0
+ _, cano_secc_color = self.secc_renderer(ids[0:1], exps[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
+ src_idx = 0
+ _, src_secc_color = self.secc_renderer(ids[0:1], exps[src_idx:src_idx+1], zero_eulers[0:1], zero_trans[0:1])
+ drv_secc_colors = [None for _ in range(len(exps))]
+ drv_head_imgs = [None for _ in range(len(exps))]
+ drv_torso_imgs = [None for _ in range(len(exps))]
+ drv_com_imgs = [None for _ in range(len(exps))]
+ segmaps = [None for _ in range(len(exps))]
+ img_name = f'data/processed/videos/{video_id}/bg.jpg'
+ bg_img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ ds = {
+ 'id': ids.cuda().float(),
+ 'exps': exps.cuda().float(),
+ 'eulers': eulers.cuda().float(),
+ 'trans': trans.cuda().float(),
+ 'cano_secc_color': cano_secc_color.cuda().float(),
+ 'src_secc_color': src_secc_color.cuda().float(),
+ 'cameras': cameras.float(),
+ 'video_id': video_id,
+ 'lip_rects': lip_rects,
+ 'head_imgs': drv_head_imgs,
+ 'torso_imgs': drv_torso_imgs,
+ 'com_imgs': drv_com_imgs,
+ 'bg_img': bg_img,
+ 'segmaps': segmaps,
+ 'kps': kps,
+ }
+ self.ds = ds
+ return ds
+
+ def training_loop(self, inp):
+ trainer = self
+ video_id = self.ds['video_id']
+ lora_params = [p for k, p in self.secc2video_model.named_parameters() if 'lora_' in k]
+ self.criterion_lpips = lpips.LPIPS(net='alex',lpips=True).cuda()
+ self.logger = SummaryWriter(log_dir=inp['work_dir'])
+ if not hasattr(self, 'learnable_triplane'):
+ src_idx = 0 # init triplane from the first frame's prediction
+ self.learnable_triplane = nn.Parameter(torch.zeros([1, 3, self.secc2video_model.triplane_hid_dim*self.secc2video_model.triplane_depth, 256, 256]).float().cuda(), requires_grad=True)
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(src_idx, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float().cuda().float() # [3, H, W]
+ cano_plane = self.secc2video_model.cal_cano_plane(img.unsqueeze(0)) # [1, 3, CD, h, w]
+ self.learnable_triplane.data = cano_plane.data
+ self.secc2video_model._last_cano_planes = self.learnable_triplane
+ if len(lora_params) == 0:
+ self.optimizer = torch.optim.AdamW([self.learnable_triplane], lr=inp['lr_triplane'], weight_decay=0.01, betas=(0.9,0.98))
+ else:
+ self.optimizer = torch.optim.Adam(lora_params, lr=inp['lr'], betas=(0.9,0.98))
+ self.optimizer.add_param_group({
+ 'params': [self.learnable_triplane],
+ 'lr': inp['lr_triplane'],
+ 'betas': (0.9, 0.98)
+ })
+
+ ids = self.ds['id']
+ exps = self.ds['exps']
+ zero_eulers = self.ds['eulers']*0
+ zero_trans = self.ds['trans']*0
+ num_updates = inp['max_updates']
+ batch_size = inp['batch_size'] # 1 for lower gpu mem usage
+ num_samples = len(self.ds['cameras'])
+ init_plane = self.learnable_triplane.detach().clone()
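+        # with fewer training frames, regularize the learnable triplane more strongly towards its initial one-shot prediction (see triplane_reg_loss below)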
+ if num_samples <= 5:
+ lambda_reg_triplane = 1.0
+ elif num_samples <= 250:
+ lambda_reg_triplane = 0.1
+ else:
+ lambda_reg_triplane = 0.
+ for i_step in tqdm.trange(num_updates+1,desc="training lora..."):
+ milestone_steps = []
+ # milestone_steps = [100, 200, 500]
+ if i_step % 2000 == 0 or i_step in milestone_steps:
+ trainer.test_loop(inp, step=i_step)
+ if i_step != 0:
+ filepath = os.path.join(inp['work_dir'], f"model_ckpt_steps_{i_step}.ckpt")
+ checkpoint = self.dump_checkpoint(inp)
+ tmp_path = str(filepath) + ".part"
+ torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
+ os.replace(tmp_path, filepath)
+
+ drv_idx = [random.randint(0, num_samples-1) for _ in range(batch_size)]
+ drv_secc_colors = []
+ gt_imgs = []
+ head_imgs = []
+ segmaps_0 = []
+ segmaps = []
+ torso_imgs = []
+ drv_lip_rects = []
+ kp_src = []
+ kp_drv = []
+ for di in drv_idx:
+                # load the target image
+ if self.torso_mode:
+ if self.ds['com_imgs'][di] is None:
+ # img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(di, "08d")}.jpg'
+ img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['com_imgs'][di] = img
+ gt_imgs.append(self.ds['com_imgs'][di])
+ else:
+ if self.ds['head_imgs'][di] is None:
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['head_imgs'][di] = img
+ gt_imgs.append(self.ds['head_imgs'][di])
+ if self.ds['head_imgs'][di] is None:
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['head_imgs'][di] = img
+ head_imgs.append(self.ds['head_imgs'][di])
+                # use the first frame's torso as the input to face vid2vid
+ if self.ds['torso_imgs'][0] is None:
+ img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['torso_imgs'][0] = img
+ torso_imgs.append(self.ds['torso_imgs'][0])
+                # accordingly, the segmap also comes from the first frame
+ if self.ds['segmaps'][0] is None:
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
+ self.ds['segmaps'][0] = segmap
+ segmaps_0.append(self.ds['segmaps'][0])
+ if self.ds['segmaps'][di] is None:
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(di, "08d")}.png'
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
+ self.ds['segmaps'][di] = segmap
+ segmaps.append(self.ds['segmaps'][di])
+ _, secc_color = self.secc_renderer(ids[0:1], exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
+ drv_secc_colors.append(secc_color)
+ drv_lip_rects.append(self.ds['lip_rects'][di])
+ kp_src.append(self.ds['kps'][0])
+ kp_drv.append(self.ds['kps'][di])
+ bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
+ ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
+ kp_src = torch.stack(kp_src).float().cuda()
+ kp_drv = torch.stack(kp_drv).float().cuda()
+ segmaps = torch.stack(segmaps).float().cuda()
+ segmaps_0 = torch.stack(segmaps_0).float().cuda()
+ tgt_imgs = torch.stack(gt_imgs).float().cuda()
+ head_imgs = torch.stack(head_imgs).float().cuda()
+ drv_secc_color = torch.cat(drv_secc_colors)
+ cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
+ src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
+ cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
+ 'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img,
+                    'segmap': segmaps_0, # vid2vid warps the first frame's torso as the source image
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+ camera = self.ds['cameras'][drv_idx]
+ gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ pred_imgs = gen_output['image']
+ pred_imgs_raw = gen_output['image_raw']
+
+ losses = {}
+ loss_weights = {
+ 'v2v_occlusion_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
+ 'v2v_occlusion_2_reg_l1_loss': 0.001, # loss for face_vid2vid-based torso
+ 'v2v_occlusion_2_weights_entropy_loss': hparams['lam_occlusion_weights_entropy'], # loss for face_vid2vid-based torso
+ 'density_weight_l2_loss': 0.01, # supervised density
+ 'density_weight_entropy_loss': 0.001, # keep the density change sharp
+ 'mse_loss': 1.,
+ 'head_mse_loss': 0.2, # loss on neural rendering low-reso pred_img
+ 'lip_mse_loss': 1.0,
+ 'lpips_loss': 0.5,
+ 'head_lpips_loss': 0.1,
+            'lip_lpips_loss': 1.0, # make the teeth clearer
+            'blink_reg_loss': 0.003, # increase it if the head shakes while blinking; decrease it if the eyes cannot fully close.
+            'triplane_reg_loss': lambda_reg_triplane,
+            'secc_reg_loss': 0.01, # used to reduce flickering
+ }
+
+ occlusion_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_reg_l1', 0.)
+ occlusion_2_reg_l1 = gen_output.get("losses", {}).get('facev2v/occlusion_2_reg_l1', 0.)
+ occlusion_2_weights_entropy = gen_output.get("losses", {}).get('facev2v/occlusion_2_weights_entropy', 0.)
+ losses['v2v_occlusion_reg_l1_loss'] = occlusion_reg_l1
+ losses['v2v_occlusion_2_reg_l1_loss'] = occlusion_2_reg_l1
+ losses['v2v_occlusion_2_weights_entropy_loss'] = occlusion_2_weights_entropy
+
+ # Weights Reg loss in torso
+ neural_rendering_reso = self.neural_rendering_resolution
+ alphas = gen_output['weights_img'].clamp(1e-5, 1 - 1e-5)
+ loss_weights_entropy = torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas))
+ mv_head_masks = segmaps[:, [1,3,5]].sum(dim=1)
+ mv_head_masks_raw = F.interpolate(mv_head_masks.unsqueeze(1), size=(neural_rendering_reso,neural_rendering_reso)).squeeze(1)
+ face_mask = mv_head_masks_raw.bool().unsqueeze(1)
+ nonface_mask = ~ face_mask
+ loss_weights_l2_loss = (alphas[nonface_mask]-0).pow(2).mean() + (alphas[face_mask]-1).pow(2).mean()
+ losses['density_weight_l2_loss'] = loss_weights_l2_loss
+ losses['density_weight_entropy_loss'] = loss_weights_entropy
+
+ mse_loss = (pred_imgs - tgt_imgs).abs().mean()
+ head_mse_loss = (pred_imgs_raw - F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).abs().mean()
+ lpips_loss = self.criterion_lpips(pred_imgs, tgt_imgs).mean()
+ head_lpips_loss = self.criterion_lpips(pred_imgs_raw, F.interpolate(head_imgs, size=(neural_rendering_reso,neural_rendering_reso), mode='bilinear', antialias=True)).mean()
+ lip_mse_loss = 0
+ lip_lpips_loss = 0
+ for i in range(len(drv_idx)):
+ xmin, xmax, ymin, ymax = drv_lip_rects[i]
+ lip_tgt_imgs = tgt_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
+ lip_pred_imgs = pred_imgs[i:i+1,:, ymin:ymax,xmin:xmax].contiguous()
+ try:
+ lip_mse_loss = lip_mse_loss + (lip_pred_imgs - lip_tgt_imgs).abs().mean()
+ lip_lpips_loss = lip_lpips_loss + self.criterion_lpips(lip_pred_imgs, lip_tgt_imgs).mean()
+ except: pass
+ losses['mse_loss'] = mse_loss
+ losses['head_mse_loss'] = head_mse_loss
+ losses['lpips_loss'] = lpips_loss
+ losses['head_lpips_loss'] = head_lpips_loss
+ losses['lip_mse_loss'] = lip_mse_loss
+ losses['lip_lpips_loss'] = lip_lpips_loss
+
+ # eye blink reg loss
+ if i_step % 4 == 0:
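+                # blink regularization: render three blink states per SECC and require the plane of the mid blink to match the average of the two endpoint planes, so the secc-to-plane mapping varies smoothly with blink progress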
+ blink_secc_lst1 = []
+ blink_secc_lst2 = []
+ blink_secc_lst3 = []
+ for i in range(len(drv_secc_color)):
+ secc = drv_secc_color[i]
+ blink_percent1 = random.random() * 0.5 # 0~0.5
+ blink_percent3 = 0.5 + random.random() * 0.5 # 0.5~1.0
+ blink_percent2 = (blink_percent1 + blink_percent3)/2
+ try:
+ out_secc1 = blink_eye_for_secc(secc, blink_percent1).to(secc.device)
+ out_secc2 = blink_eye_for_secc(secc, blink_percent2).to(secc.device)
+ out_secc3 = blink_eye_for_secc(secc, blink_percent3).to(secc.device)
+ except:
+ print("blink eye for secc failed, use original secc")
+ out_secc1 = copy.deepcopy(secc)
+ out_secc2 = copy.deepcopy(secc)
+ out_secc3 = copy.deepcopy(secc)
+ blink_secc_lst1.append(out_secc1)
+ blink_secc_lst2.append(out_secc2)
+ blink_secc_lst3.append(out_secc3)
+ src_secc_color1 = torch.stack(blink_secc_lst1)
+ src_secc_color2 = torch.stack(blink_secc_lst2)
+ src_secc_color3 = torch.stack(blink_secc_lst3)
+ blink_cond1 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color1}
+ blink_cond2 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color2}
+ blink_cond3 = {'cond_cano': cano_secc_color, 'cond_src': src_secc_color, 'cond_tgt': src_secc_color3}
+ blink_secc_plane1 = self.model.cal_secc_plane(blink_cond1)
+ blink_secc_plane2 = self.model.cal_secc_plane(blink_cond2)
+ blink_secc_plane3 = self.model.cal_secc_plane(blink_cond3)
+ interpolate_blink_secc_plane = (blink_secc_plane1 + blink_secc_plane3) / 2
+ blink_reg_loss = torch.nn.functional.l1_loss(blink_secc_plane2, interpolate_blink_secc_plane)
+ losses['blink_reg_loss'] = blink_reg_loss
+
+ # Triplane Reg loss
+ triplane_reg_loss = (self.learnable_triplane - init_plane).abs().mean()
+ losses['triplane_reg_loss'] = triplane_reg_loss
+
+
+ ref_id = self.ds['id'][0:1]
+ secc_pertube_randn_scale = hparams['secc_pertube_randn_scale']
+ perturbed_id = ref_id + torch.randn_like(ref_id) * secc_pertube_randn_scale
+ drv_exp = self.ds['exps'][drv_idx]
+ perturbed_exp = drv_exp + torch.randn_like(drv_exp) * secc_pertube_randn_scale
+ zero_euler = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
+ zero_trans = torch.zeros([len(drv_idx), 3], device=ref_id.device, dtype=ref_id.dtype)
+ perturbed_secc = self.secc_renderer(perturbed_id, perturbed_exp, zero_euler, zero_trans)[1]
+ secc_reg_loss = torch.nn.functional.l1_loss(drv_secc_color, perturbed_secc)
+ losses['secc_reg_loss'] = secc_reg_loss
+
+
+ total_loss = sum([loss_weights[k] * v for k, v in losses.items() if isinstance(v, torch.Tensor) and v.requires_grad])
+ # Update weights
+ self.optimizer.zero_grad()
+ total_loss.backward()
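+            # amplify the triplane gradient by its number of elements; with mean-reduced losses the per-element gradient of this very large tensor would otherwise be tiny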
+ self.learnable_triplane.grad.data = self.learnable_triplane.grad.data * self.learnable_triplane.numel()
+ self.optimizer.step()
+ meter.update(total_loss.item())
+ if i_step % 10 == 0:
+ log_line = f"Iter {i_step+1}: total_loss={meter.avg} "
+ for k, v in losses.items():
+ log_line = log_line + f" {k}={v.item()}, "
+ self.logger.add_scalar(f"train/{k}", v.item(), i_step)
+ print(log_line)
+ meter.reset()
+ @torch.no_grad()
+ def test_loop(self, inp, step=''):
+ self.model.eval()
+ # coeff_dict = np.load('data/processed/videos/Lieu/coeff_fit_mp_for_lora.npy', allow_pickle=True).tolist()
+ # drv_exps = torch.tensor(coeff_dict['exp']).cuda().float()
+ drv_exps = self.ds['exps']
+ zero_eulers = self.ds['eulers']*0
+ zero_trans = self.ds['trans']*0
+ batch_size = 1
+ num_samples = len(self.ds['cameras'])
+ video_writer = imageio.get_writer(os.path.join(inp['work_dir'], f'val_step{step}.mp4'), fps=25)
+ total_iters = min(num_samples, 250)
+ video_id = inp['video_id']
+ for i in tqdm.trange(total_iters,desc="testing lora..."):
+ drv_idx = [i]
+ drv_secc_colors = []
+ gt_imgs = []
+ segmaps = []
+ torso_imgs = []
+ drv_lip_rects = []
+ kp_src = []
+ kp_drv = []
+ for di in drv_idx:
+                # load the target image
+ if self.torso_mode:
+ if self.ds['com_imgs'][di] is None:
+ img_name = f'data/processed/videos/{video_id}/com_imgs/{format(di, "08d")}.jpg'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['com_imgs'][di] = img
+ gt_imgs.append(self.ds['com_imgs'][di])
+ else:
+ if self.ds['head_imgs'][di] is None:
+ img_name = f'data/processed/videos/{video_id}/head_imgs/{format(di, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['head_imgs'][di] = img
+ gt_imgs.append(self.ds['head_imgs'][di])
+                # use the first frame's torso as the input to face vid2vid
+ if self.ds['torso_imgs'][0] is None:
+ img_name = f'data/processed/videos/{video_id}/inpaint_torso_imgs/{format(0, "08d")}.png'
+ img = torch.tensor(cv2.imread(img_name)[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ self.ds['torso_imgs'][0] = img
+ torso_imgs.append(self.ds['torso_imgs'][0])
+                # accordingly, the segmap also comes from the first frame
+ if self.ds['segmaps'][0] is None:
+ img_name = f'data/processed/videos/{video_id}/segmaps/{format(0, "08d")}.png'
+ seg_img = cv2.imread(img_name)[:,:, ::-1]
+ segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W]
+ self.ds['segmaps'][0] = segmap
+ segmaps.append(self.ds['segmaps'][0])
+ drv_lip_rects.append(self.ds['lip_rects'][di])
+ kp_src.append(self.ds['kps'][0])
+ kp_drv.append(self.ds['kps'][di])
+ bg_img = self.ds['bg_img'].unsqueeze(0).repeat([batch_size, 1, 1, 1]).cuda()
+ ref_torso_imgs = torch.stack(torso_imgs).float().cuda()
+ kp_src = torch.stack(kp_src).float().cuda()
+ kp_drv = torch.stack(kp_drv).float().cuda()
+ segmaps = torch.stack(segmaps).float().cuda()
+ tgt_imgs = torch.stack(gt_imgs).float().cuda()
+ for di in drv_idx:
+ _, secc_color = self.secc_renderer(self.ds['id'][0:1], drv_exps[di:di+1], zero_eulers[0:1], zero_trans[0:1])
+ drv_secc_colors.append(secc_color)
+ drv_secc_color = torch.cat(drv_secc_colors)
+ cano_secc_color = self.ds['cano_secc_color'].repeat([batch_size, 1, 1, 1])
+ src_secc_color = self.ds['src_secc_color'].repeat([batch_size, 1, 1, 1])
+ cond = {'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_color,
+ 'ref_torso_img': ref_torso_imgs, 'bg_img': bg_img, 'segmap': segmaps,
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+ camera = self.ds['cameras'][drv_idx]
+ gen_output = self.secc2video_model.forward(img=None, camera=camera, cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ pred_img = gen_output['image']
+ pred_img = ((pred_img.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
+ video_writer.append_data(pred_img[0])
+ video_writer.close()
+ self.model.train()
+
+ def masked_error_loss(self, img_pred, img_gt, mask, unmasked_weight=0.1, mode='l1'):
+        # For the raw (low-res) image, deformation keeps the background from being fully black, which inflates the MSE there; mask it out and compute the loss mainly on the face region
+ masked_weight = 1.0
+ weight_mask = mask.float() * masked_weight + (~mask).float() * unmasked_weight
+ if mode == 'l1':
+ error = (img_pred - img_gt).abs().sum(dim=1) * weight_mask
+ else:
+ error = (img_pred - img_gt).pow(2).sum(dim=1) * weight_mask
+        error.clamp_(0, max(0.5, error.quantile(0.8).item())) # clamp high-loss pixels so that outliers from misaligned poses do not dominate training
+ loss = error.mean()
+ return loss
+
+ def dilate(self, bin_img, ksize=5, mode='max_pool'):
+ """
+ mode: max_pool or avg_pool
+ """
+ # bin_img, [1, h, w]
+ pad = (ksize-1)//2
+ bin_img = F.pad(bin_img, pad=[pad,pad,pad,pad], mode='reflect')
+ if mode == 'max_pool':
+ out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
+ else:
+ out = F.avg_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
+ return out
+
+ def dilate_mask(self, mask, ksize=21):
+ mask = self.dilate(mask, ksize=ksize, mode='max_pool')
+ return mask
+
+ def set_unmasked_to_black(self, img, mask):
+ out_img = img * mask.float() - (~mask).float() # -1 denotes black
+ return out_img
+
+ def dump_checkpoint(self, inp):
+ checkpoint = {}
+ # save optimizers
+ optimizer_states = []
+ self.optimizers = [self.optimizer]
+ for i, optimizer in enumerate(self.optimizers):
+ if optimizer is not None:
+ state_dict = optimizer.state_dict()
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
+ optimizer_states.append(state_dict)
+ checkpoint['optimizer_states'] = optimizer_states
+ state_dict = {
+ 'model': self.model.state_dict(),
+ 'learnable_triplane': self.model.state_dict()['_last_cano_planes'],
+ }
+ del state_dict['model']['_last_cano_planes']
+ checkpoint['state_dict'] = state_dict
+ checkpoint['lora_args'] = self.lora_args
+ person_ds = {}
+ video_id = inp['video_id']
+ img_name = f'data/processed/videos/{video_id}/gt_imgs/{format(0, "08d")}.jpg'
+ gt_img = torch.tensor(cv2.resize(cv2.imread(img_name), (512, 512))[..., ::-1] / 127.5 - 1).permute(2,0,1).float() # [3, H, W]
+ person_ds['gt_img'] = gt_img.reshape([1, 3, 512, 512])
+ person_ds['id'] = self.ds['id'].cpu().reshape([1, 80])
+ person_ds['src_kp'] = self.ds['kps'][0].cpu()
+ person_ds['video_id'] = inp['video_id']
+ checkpoint['person_ds'] = person_ds
+ return checkpoint
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--head_ckpt", default='') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ # parser.add_argument("--torso_ckpt", default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+ parser.add_argument("--torso_ckpt", default='checkpoints/mimictalk_orig/os_secc2plane_torso') # checkpoints/0729_th1kh/secc_img2plane checkpoints/0720_img2planes/secc_img2plane_two_stage
+    parser.add_argument("--video_id", default='data/raw/examples/GER.mp4', help="identity source; we support (1) an already-processed video_id of GeneFace, (2) a video path, (3) an image path")
+ parser.add_argument("--work_dir", default=None)
+ parser.add_argument("--max_updates", default=10000, type=int, help="for video, 2000 is good; for an image, 3~10 is good")
+ parser.add_argument("--test", action='store_true')
+ parser.add_argument("--batch_size", default=1, type=int, help="batch size during training, 1 needs 8GB, 2 needs 15GB")
+ parser.add_argument("--lr", default=0.001)
+ parser.add_argument("--lr_triplane", default=0.005, help="for video, 0.1; for an image, 0.001; for ablation with_triplane, 0.")
+ parser.add_argument("--lora_r", default=2, type=int, help="width of lora unit")
+ parser.add_argument("--lora_mode", default='secc2plane_sr', help='for video, full; for an image, none')
+
+ args = parser.parse_args()
+ inp = {
+ 'head_ckpt': args.head_ckpt,
+ 'torso_ckpt': args.torso_ckpt,
+ 'video_id': args.video_id,
+ 'work_dir': args.work_dir,
+ 'max_updates': args.max_updates,
+ 'batch_size': args.batch_size,
+ 'test': args.test,
+ 'lr': float(args.lr),
+ 'lr_triplane': float(args.lr_triplane),
+ 'lora_mode': args.lora_mode,
+ 'lora_r': args.lora_r,
+ }
+    if inp['work_dir'] is None:
+ video_id = os.path.basename(inp['video_id'])[:-4] if inp['video_id'].endswith((".mp4", ".png", ".jpg", ".jpeg")) else inp['video_id']
+ inp['work_dir'] = f'checkpoints_mimictalk/{video_id}'
+ os.makedirs(inp['work_dir'], exist_ok=True)
+ trainer = LoRATrainer(inp)
+ if inp['test']:
+ trainer.test_loop(inp)
+ else:
+ trainer.training_loop(inp)
+ trainer.test_loop(inp)
+ print(" ")
\ No newline at end of file