import os

os.environ['HYDRA_FULL_ERROR'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ChaolongYang/KDTalker",
    local_dir="./"
)
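# The snapshot above is assumed to provide the checkpoints and helpers referenced later in this
# script (ckpts/KDTalker.pth, ckpts/wav2lip.pth, ckpts/shape_predictor_68_face_landmarks.dat,
# dataset_process/norm.npz) relative to the working directory.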
|
|
import argparse
import shutil
import uuid
import numpy as np
from tqdm import tqdm
import cv2
from rich.progress import track
import tyro

import gradio as gr
from PIL import Image
import time
import torch
import torch.nn.functional as F
from torch import nn
import imageio
from pydub import AudioSegment
from pykalman import KalmanFilter

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
from src.utils.camera import get_rotation_matrix
from dataset_process import audio
from dataset_process.croper import Croper

import spaces
|
|
def parse_audio_length(audio_length, sr, fps):
    bit_per_frames = sr / fps

    num_frames = int(audio_length / bit_per_frames)
    audio_length = int(num_frames * bit_per_frames)

    return audio_length, num_frames


def crop_pad_audio(wav, audio_length):
    if len(wav) > audio_length:
        wav = wav[:audio_length]
    elif len(wav) < audio_length:
        wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
    return wav
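# For example, at sr=16000 and fps=25 each video frame corresponds to 16000 / 25 = 640 audio
# samples, so the waveform is trimmed or zero-padded to a whole number of frames.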
|
|
class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x

        if self.use_act:
            return self.act(out)
        else:
            return out
|
|
class AudioEncoder(nn.Module):
    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        # Copy only the audio_encoder weights out of the Wav2Lip checkpoint.
        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k, v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)
|
|
    def forward(self, audio_sequences):
        B = audio_sequences.size(0)

        # Fold the per-frame mel windows into the batch dimension before encoding.
        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1)
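    # Shape sketch: extract_wav2lip_from_audio below feeds a (1, T, 1, 80, 16) tensor of
    # per-frame mel windows; the encoder reduces each 80x16 window to a 512-d vector, so the
    # forward pass returns a (1, T, 512) audio feature sequence.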
|
|
def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


def dct2device(dct: dict, device):
    for key in dct:
        dct[key] = torch.tensor(dct[key]).to(device)
    return dct


def save_video_with_watermark(video, audio, save_path):
    temp_file = str(uuid.uuid4()) + '.mp4'
    cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
    os.system(cmd)
    shutil.move(temp_file, save_path)
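# Note: despite its name, this helper only muxes the audio track onto the silent video with a
# stream-copied video codec; it requires an ffmpeg binary on PATH and applies no watermark.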
|
|
class Inferencer(object):
    def __init__(self):
        st = time.time()
        print('#' * 25 + 'Start initialization' + '#' * 25)
        self.device = 'cuda'

        from model import get_model
        self.point_diffusion = get_model()
        ckpt = torch.load('ckpts/KDTalker.pth')

        self.point_diffusion.load_state_dict(ckpt['model'])
        self.point_diffusion.eval()
        self.point_diffusion.to(self.device)

        lm_croper_checkpoint = 'ckpts/shape_predictor_68_face_landmarks.dat'
        self.croper = Croper(lm_croper_checkpoint)

        self.norm_info = dict(np.load('dataset_process/norm.npz'))

        wav2lip_checkpoint = 'ckpts/wav2lip.pth'
        self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
        self.wav2lip_model.cuda()
        self.wav2lip_model.eval()

        tyro.extras.set_accent_color("bright_cyan")
        args = tyro.cli(ArgumentConfig)

        self.inf_cfg = partial_fields(InferenceConfig, args.__dict__)
        self.crop_cfg = partial_fields(CropConfig, args.__dict__)

        self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg)
|
|
    def _norm(self, data_dict):
        for k in data_dict.keys():
            if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp']:
                v = data_dict[k]
                data_dict[k] = (v - self.norm_info[k + '_mean']) / self.norm_info[k + '_std']
        return data_dict

    def _denorm(self, data_dict):
        for k in data_dict.keys():
            if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp']:
                v = data_dict[k]
                data_dict[k] = v * self.norm_info[k + '_std'] + self.norm_info[k + '_mean']
        return data_dict

    def output_to_dict(self, data):
        output = {}
        output['scale'] = data[:, 0]
        output['yaw'] = data[:, 1, None]
        output['pitch'] = data[:, 2, None]
        output['roll'] = data[:, 3, None]
        output['t'] = data[:, 4:7]
        output['exp'] = data[:, 7:]
        return output
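    # Layout of the 70-D motion vector produced by the diffusion model: index 0 is scale,
    # 1-3 are yaw/pitch/roll, 4-6 are translation, and the remaining 63 entries are the
    # expression offsets of the 21 implicit keypoints (21 x 3).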
|
|
    def extract_mel_from_audio(self, audio_file_path):
        syncnet_mel_step_size = 16
        fps = 25
        wav = audio.load_wav(audio_file_path, 16000)
        wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
        wav = crop_pad_audio(wav, wav_length)
        orig_mel = audio.melspectrogram(wav).T
        spec = orig_mel.copy()
        indiv_mels = []

        for i in tqdm(range(num_frames), 'mel:'):
            start_frame_num = i - 2
            start_idx = int(80. * (start_frame_num / float(fps)))
            end_idx = start_idx + syncnet_mel_step_size
            seq = list(range(start_idx, end_idx))
            seq = [min(max(item, 0), orig_mel.shape[0] - 1) for item in seq]
            m = spec[seq, :]
            indiv_mels.append(m.T)
        indiv_mels = np.asarray(indiv_mels)
        return indiv_mels
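    # Each video frame gets a 16-step mel window starting two frames earlier; the 80. factor
    # assumes the melspectrogram helper yields roughly 80 mel frames per second (the Wav2Lip
    # preprocessing convention), and out-of-range indices are clamped at the clip boundaries.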
|
|
    def extract_wav2lip_from_audio(self, audio_file_path):
        asd_mel = self.extract_mel_from_audio(audio_file_path)
        asd_mel = torch.FloatTensor(asd_mel).cuda().unsqueeze(0).unsqueeze(2)
        with torch.no_grad():
            hidden = self.wav2lip_model(asd_mel)
        return hidden[0].cpu().detach().numpy()
|
|
    def headpose_pred_to_degree(self, pred):
        device = pred.device
        idx_tensor = [idx for idx in range(66)]
        idx_tensor = torch.FloatTensor(idx_tensor).to(device)
        pred = F.softmax(pred, dim=1)
        degree = torch.sum(pred * idx_tensor, 1) * 3 - 99
        return degree
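    # Converts a 66-bin head-pose classification into degrees (3-degree bins offset by -99,
    # i.e. roughly the [-99, 96] range). This helper is not referenced elsewhere in this script.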
|
|
    @torch.no_grad()
    def generate_with_audio_img(self, image_path, audio_path, save_path):
        image = np.array(Image.open(image_path).convert('RGB'))
        cropped_image, crop, quad = self.croper.crop([image], still=False, xsize=512)
        input_image = cv2.resize(cropped_image[0], (256, 256))

        I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255

        x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp'].reshape(1, 21, -1)
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info)

        kp_info = {}
        for k in x_s_info.keys():
            kp_info[k] = x_s_info[k].cpu().numpy()

        kp_info = self._norm(kp_info)

        ori_kp = torch.cat([torch.zeros([1, 7]), torch.Tensor(kp_info['kp'])], -1).cuda()

        input_x = np.concatenate([kp_info[k] for k in ['scale', 'yaw', 'pitch', 'roll', 't', 'exp']], 1)
        input_x = np.expand_dims(input_x, -1)
        input_x = np.expand_dims(input_x, 0)
        input_x = np.concatenate([input_x, input_x, input_x], -1)
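        # The reference 70-D vector is replicated three times along the last axis; the diffusion
        # model predicts all three copies and the averaging step below (np.mean(..., -1))
        # collapses them back to a single motion sequence.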
|
|
        aud_feat = self.extract_wav2lip_from_audio(audio_path)

        # Pad the audio features (by wrapping from the start) up to a multiple of 64 frames,
        # the window length the diffusion model is sampled with.
        sample_frame = 64
        padding_size = (sample_frame - aud_feat.shape[0] % sample_frame) % sample_frame

        if padding_size > 0:
            aud_feat = np.concatenate((aud_feat, aud_feat[:padding_size, :]), axis=0)

        outputs = [input_x]

        for i in range(0, aud_feat.shape[0] - 1, sample_frame):
            input_mel = torch.Tensor(aud_feat[i: i + sample_frame]).unsqueeze(0).cuda()
            kp0 = torch.Tensor(outputs[-1])[:, -1].cuda()
            pred_kp = self.point_diffusion.forward_sample(70, ref_kps=kp0, ori_kps=ori_kp, aud_feat=input_mel,
                                                          scheduler='ddim', num_inference_steps=50)
            outputs.append(pred_kp.cpu().numpy())

        # Drop the reference frame and the wrap-around padding, then average the three replicas.
        outputs = np.mean(np.concatenate(outputs, 1)[0, 1:aud_feat.shape[0] - padding_size + 1], -1)
        output_dict = self.output_to_dict(outputs)
        output_dict = self._denorm(output_dict)
|
|
        num_frame = output_dict['yaw'].shape[0]
        x_d_info = {}
        for key in output_dict:
            x_d_info[key] = torch.tensor(output_dict[key]).cuda()

        # Temporal smoothing of the predicted motion with a Kalman smoother.
        def smooth(sequence, n_dim_state=1):
            kf = KalmanFilter(initial_state_mean=sequence[0],
                              transition_covariance=0.05 * np.eye(n_dim_state),
                              observation_covariance=0.001 * np.eye(n_dim_state))
            state_means, _ = kf.smooth(sequence)
            return state_means
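        # With the observation covariance (0.001) much smaller than the transition covariance
        # (0.05), the smoother stays close to the raw predictions and only removes small jitter.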
|
|
        yaw_data = x_d_info['yaw'].cpu().numpy()
        pitch_data = x_d_info['pitch'].cpu().numpy()
        roll_data = x_d_info['roll'].cpu().numpy()
        t_data = x_d_info['t'].cpu().numpy()
        exp_data = x_d_info['exp'].cpu().numpy()

        smoothed_pitch = smooth(pitch_data, n_dim_state=1)
        smoothed_yaw = smooth(yaw_data, n_dim_state=1)
        smoothed_roll = smooth(roll_data, n_dim_state=1)
        smoothed_t = smooth(t_data, n_dim_state=3)
        smoothed_exp = smooth(exp_data, n_dim_state=63)

        x_d_info['pitch'] = torch.Tensor(smoothed_pitch).cuda()
        x_d_info['yaw'] = torch.Tensor(smoothed_yaw).cuda()
        x_d_info['roll'] = torch.Tensor(smoothed_roll).cuda()
        x_d_info['t'] = torch.Tensor(smoothed_t).cuda()
        x_d_info['exp'] = torch.Tensor(smoothed_exp).cuda()
|
|
        template_dct = {'motion': [], 'c_d_eyes_lst': [], 'c_d_lip_lst': []}
        for i in track(range(num_frame), description='Making motion templates...', total=num_frame):
            x_d_i_info = x_d_info
            R_d_i = get_rotation_matrix(x_d_i_info['pitch'][i], x_d_i_info['yaw'][i], x_d_i_info['roll'][i])

            item_dct = {
                'scale': x_d_i_info['scale'][i].cpu().numpy().astype(np.float32),
                'R_d': R_d_i.cpu().numpy().astype(np.float32),
                'exp': x_d_i_info['exp'][i].reshape(1, 21, -1).cpu().numpy().astype(np.float32),
                't': x_d_i_info['t'][i].cpu().numpy().astype(np.float32),
            }

            template_dct['motion'].append(item_dct)
|
|
        I_p_lst = []
        R_d_0, x_d_0_info = None, None

        for i in track(range(num_frame), description='🚀Animating...', total=num_frame):
            x_d_i_info = template_dct['motion'][i]
            for key in x_d_i_info:
                x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
            R_d_i = x_d_i_info['R_d']

            if i == 0:
                R_d_0 = R_d_i
                x_d_0_info = x_d_i_info

            if self.inf_cfg.flag_relative_motion:
                # Apply the driving motion relative to the first driving frame, on top of the source pose.
                R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
                scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                R_new = R_d_i
                delta_new = x_d_i_info['exp']
                scale_new = x_s_info['scale']
                t_new = x_d_i_info['t']

            t_new[..., 2].fill_(0)
            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

            out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
            I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
            I_p_lst.append(I_p_i)
|
|
        video_name = save_path.split('/')[-1]
        video_save_dir = os.path.dirname(save_path)
        path = os.path.join(video_save_dir, 'temp_' + video_name)

        imageio.mimsave(path, I_p_lst, fps=float(25))

        audio_name = audio_path.split('/')[-1]
        new_audio_path = os.path.join(video_save_dir, audio_name)
        start_time = 0
        sound = AudioSegment.from_file(audio_path)
        end_time = start_time + num_frame * 1 / 25 * 1000
        word1 = sound.set_frame_rate(16000)
        word = word1[start_time:end_time]
        word.export(new_audio_path, format="wav")
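        # The exported clip is the driving audio resampled to 16 kHz and trimmed to the video
        # duration (num_frame / 25 seconds), so the muxed track matches the rendered frames.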
|
|
        save_video_with_watermark(path, new_audio_path, save_path)
        print(f'The generated video is named {video_save_dir}/{video_name}')

        os.remove(path)
        os.remove(new_audio_path)
|
|
@spaces.GPU()
def gradio_infer(source_image, driven_audio):
    """
    Generate a talking-head video from a static source image and an audio file.

    This function uses pre-trained motion and lip-sync models to animate a face image so that
    it appears to speak in sync with a given audio clip. The resulting video is saved and
    returned.

    Args:
        source_image: Path to an input image (or uploaded image) of a person's face
            that will be animated.
        driven_audio: Path to an audio file (or uploaded audio) that will drive the
            lip-sync and head movement of the animation.

    Returns:
        A file path to the generated .mp4 video, which shows the input face animated
        to speak and move in sync with the audio.

    Workflow:
        1. Load and initialize the animation pipeline (Inferencer).
        2. Process the image and audio.
        3. Generate a talking-head animation using lip-sync and motion synthesis models.
        4. Combine the generated video frames with the original audio.
        5. Return the video path to be displayed or downloaded.
    """
    import tempfile
    temp_dir = tempfile.mkdtemp()
    output_path = f"{temp_dir}/output.mp4"

    Infer = Inferencer()
    Infer.generate_with_audio_img(source_image, driven_audio, output_path)

    return output_path
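# Offline usage sketch (not part of the original app; assumes the example assets shipped with
# the repository are present):
#
#     inferencer = Inferencer()
#     inferencer.generate_with_audio_img(
#         "example/source_image/WDA_BenCardin1_000.png",
#         "example/audio_driven/WDA_BenCardin1_000.wav",
#         "./output.mp4",
#     )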
|
|
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# KDTalker")
        gr.Markdown("Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/chaolongy/KDTalker">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://arxiv.org/abs/2503.12963">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/KDTalker?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Source Image", type="filepath")
                driven_audio = gr.Audio(label="Driven Audio", type="filepath")
                submit_btn = gr.Button("Submit")

                gr.Examples(
                    examples=[
                        ["example/source_image/WDA_BenCardin1_000.png", "example/audio_driven/WDA_BenCardin1_000.wav"],
                    ],
                    inputs=[source_image, driven_audio],
                    cache_examples=False
                )

            with gr.Column():
                output_video = gr.Video(label="Output Video")

        submit_btn.click(
            fn=gradio_infer,
            inputs=[source_image, driven_audio],
            outputs=[output_video]
        )

demo.launch(mcp_server=True)