Spaces:
Runtime error
Runtime error
| import argparse | |
| import platform | |
| import subprocess | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, Iterator, List, Literal, Optional, Union | |
| import cv2 | |
| import numpy as np | |
| from config import hparams as hp | |
| from nota_wav2lip.inference import Wav2LipInferenceImpl | |
| from nota_wav2lip.util import FFMPEG_LOGGING_MODE | |
| from nota_wav2lip.video import AudioSlicer, VideoSlicer | |
class Wav2LipModelComparisonDemo:
    """Run one or more Wav2Lip inference models on registered audio/video inputs.

    The instance keeps three name-keyed registries:

    * ``model_zoo``  -- ``Wav2LipInferenceImpl`` objects, one per requested model.
    * ``audio_dict`` -- ``AudioSlicer`` objects added through ``update_audio``.
    * ``video_dict`` -- ``VideoSlicer`` objects added through ``update_video``.

    ``save_as_video`` combines one entry of each registry into an ``.mp4`` file
    written below ``result_dir``.
    """

    def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]] = None):
        """Instantiate the requested models and prepare the output directory.

        Args:
            device: Forwarded unchanged to ``Wav2LipInferenceImpl``.
            result_dir: Directory that receives the generated videos. It is
                created (including missing parents) when it does not exist.
            model_list: A model name, a list of model names, or ``None`` to
                load both ``'wav2lip'`` and ``'nota_wav2lip'``. Every name must
                be a key of ``hp.inference.model``.

        Raises:
            AssertionError: If a requested name is not in ``hp.inference.model``.
        """
        # Normalise the argument into a list without re-annotating the parameter.
        if model_list is None:
            model_names: List[str] = ['wav2lip', 'nota_wav2lip']
        elif isinstance(model_list, str) and len(model_list) != 0:
            model_names = [model_list]
        else:
            # Lists pass through as-is; an empty string yields no models,
            # exactly as iterating over it would.
            model_names = model_list

        super().__init__()
        self.video_dict: Dict[str, VideoSlicer] = {}
        self.audio_dict: Dict[str, AudioSlicer] = {}

        self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
        for model_name in model_names:
            assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}"
            self.model_zoo[model_name] = Wav2LipInferenceImpl(
                model_name, hp_inference_model=hp.inference.model[model_name], device=device
            )

        # Snapshot of each model's ``params`` attribute, keyed by model name.
        self._params_zoo: Dict[str, str] = {
            model_name: model.params for model_name, model in self.model_zoo.items()
        }

        self.result_dir: Path = Path(result_dir)
        # parents=True so a nested result_dir (e.g. './out/run1') does not fail.
        self.result_dir.mkdir(parents=True, exist_ok=True)

    def params(self):
        """Return the mapping of model name to that model's ``params`` value."""
        return self._params_zoo

    def _infer(
        self,
        audio_name: str,
        video_name: str,
        model_type: Literal['wav2lip', 'nota_wav2lip']
    ) -> Iterator[np.ndarray]:
        """Look up the named audio, video and model and start inference.

        Args:
            audio_name: Key previously registered through ``update_audio``.
            video_name: Key previously registered through ``update_video``.
            model_type: Key of the model in ``model_zoo``.

        Returns:
            Whatever ``inference_with_iterator`` of the selected model returns;
            ``save_as_video`` consumes it as an iterator of frames.

        Raises:
            KeyError: If any of the three names is not registered.
        """
        audio_iterable: AudioSlicer = self.audio_dict[audio_name]
        video_iterable: VideoSlicer = self.video_dict[video_name]
        target_model = self.model_zoo[model_type]
        return target_model.inference_with_iterator(audio_iterable, video_iterable)

    def update_audio(self, audio_path, name=None):
        """Register (or replace) an audio input.

        Args:
            audio_path: Path handed to ``AudioSlicer``.
            name: Registry key; defaults to the stem of ``audio_path``.
        """
        _name = name if name is not None else Path(audio_path).stem
        self.audio_dict[_name] = AudioSlicer(audio_path)

    def update_video(self, frame_dir_path, bbox_path, name=None):
        """Register (or replace) a video input.

        Args:
            frame_dir_path: Frame directory handed to ``VideoSlicer``.
            bbox_path: Bounding-box path handed to ``VideoSlicer``.
            name: Registry key; defaults to the stem of ``frame_dir_path``.
        """
        _name = name if name is not None else Path(frame_dir_path).stem
        self.video_dict[_name] = VideoSlicer(frame_dir_path, bbox_path)

    def save_as_video(self, audio_name, video_name, model_type):
        """Generate frames, write them to a video file and mux in the audio.

        The frames are first written to ``generated.mp4`` inside ``result_dir``;
        ffmpeg then combines that file with the registered audio into
        ``generated_with_audio.mp4``. Both files are overwritten on every call.

        Args:
            audio_name: Key of the audio input in ``audio_dict``.
            video_name: Key of the video input in ``video_dict``.
            model_type: Key of the model in ``model_zoo``.

        Returns:
            A tuple ``(output_video_path, inference_time, inference_fps)``:
            the path of the muxed video, the seconds spent producing and
            writing the frames, and ``len(audio) / inference_time``
            (``0.0`` when the measured time is zero).

        Raises:
            KeyError: If any of the three names is not registered.
        """
        import shlex  # local import: only needed to tokenise the ffmpeg log flags

        output_video_path = self.result_dir / 'generated_with_audio.mp4'
        frame_only_video_path = self.result_dir / 'generated.mp4'
        audio_path = self.audio_dict[audio_name].audio_path

        out = cv2.VideoWriter(str(frame_only_video_path),
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              hp.face.video_fps,
                              (hp.inference.frame.w, hp.inference.frame.h))
        # perf_counter is monotonic, so the elapsed time can never be negative.
        start = time.perf_counter()
        try:
            for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
                out.write(frame)
            inference_time = time.perf_counter() - start
        finally:
            # Always release the writer, even when inference fails midway.
            out.release()

        # Pass ffmpeg an argument list instead of a shell string so that paths
        # containing spaces or shell metacharacters are handled verbatim.
        # NOTE(review): FFMPEG_LOGGING_MODE['ERROR'] is assumed to be a string of
        # space-separated ffmpeg flags, as the original f-string usage implies.
        command = [
            'ffmpeg',
            *shlex.split(str(FFMPEG_LOGGING_MODE['ERROR'])),
            '-y',
            '-i', str(audio_path),
            '-i', str(frame_only_video_path),
            '-strict', '-2',
            '-q:v', '1',
            str(output_video_path),
        ]
        # The exit status is deliberately ignored, as in the original code.
        subprocess.call(command)

        # The number of frames of generated video
        video_frames_num = len(self.audio_dict[audio_name])
        # Guard against a zero elapsed time (e.g. nothing to infer).
        inference_fps = video_frames_num / inference_time if inference_time > 0 else 0.0

        return output_video_path, inference_time, inference_fps