| """run bash scripts/download_models.sh first to prepare the weights file""" |
| import os |
| import shutil |
| from argparse import Namespace |
| from src.utils.preprocess import CropAndExtract |
| from src.test_audio2coeff import Audio2Coeff |
| from src.facerender.animate import AnimateFromCoeff |
| from src.generate_batch import get_data |
| from src.generate_facerender_batch import get_facerender_data |
| from src.utils.init_path import init_path |
| from cog import BasePredictor, Input, Path |
|
|
# Directory holding the pretrained SadTalker weights
# (populated by `bash scripts/download_models.sh`, per the module docstring).
checkpoints = "checkpoints"
|
|
|
|
class Predictor(BasePredictor):
    def setup(self):
        """Load the SadTalker models into memory once so repeated predictions are fast."""
        device = "cuda"
        # Remember the device so predict() doesn't re-hardcode it.
        self.device = device

        sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))

        # Face cropper / 3DMM coefficient extractor, used on both the source
        # image and any reference videos.
        self.preprocess_model = CropAndExtract(sadtalker_paths, device)

        # Maps driven audio to motion coefficients.
        self.audio_to_coeff = Audio2Coeff(sadtalker_paths, device)

        # The "full" and "others" entries were previously two separate
        # AnimateFromCoeff instances built from identical arguments; sharing a
        # single instance halves model memory while keeping the dict-based
        # lookup in predict() unchanged.
        animator = AnimateFromCoeff(sadtalker_paths, device)
        self.animate_from_coeff = {"full": animator, "others": animator}

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body animation when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction: animate ``source_image`` to speak ``driven_audio``.

        Returns:
            Path to the enhanced output video at /tmp/out.mp4.

        Raises:
            RuntimeError: if 3DMM coefficients cannot be extracted from the
                source image, or if the renderer produces no enhanced video.
        """
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = self.device
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # Start from a clean results directory for every prediction.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            # Previously this returned None, which violates the declared
            # `-> Path` contract and surfaced as a confusing downstream error;
            # fail loudly instead.
            raise RuntimeError("Can't get the coeffs of the input image")

        # Optional reference video that drives eye blinking.
        if ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(
                os.path.split(ref_eyeblink)[-1]
            )[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None

        # Optional reference video that drives head pose; reuse the eye-blink
        # coefficients when both references are the same file.
        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None

        # Audio -> motion coefficients.
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # Coefficients -> rendered (and enhanced) video.
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

        # Copy the enhanced output to a stable path for cog to return.
        output = "/tmp/out.mp4"
        enhanced_videos = [f for f in os.listdir(results_dir) if "enhanced.mp4" in f]
        if not enhanced_videos:
            # Previously this indexed [0] directly, crashing with an opaque
            # IndexError when the enhancer produced no output file.
            raise RuntimeError("No enhanced video was produced by the face renderer")
        shutil.copy(os.path.join(results_dir, enhanced_videos[0]), output)

        return Path(output)
|
|
|
|
def load_default():
    """Build the default SadTalker option set.

    Returns a ``Namespace`` carrying the same fields the upstream CLI would
    produce, so ``predict`` can overwrite just the per-request values
    (``pic_path``, ``audio_path``, ``still``, ...) on top of these defaults.
    """
    defaults = {
        # Generation controls.
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        "face3dvis": False,
        # 3D face reconstruction network settings.
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./src/config/",
        "bfm_model": "BFM_model_front.mat",
        # Camera / renderer parameters.
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**defaults)
|
|