|
|
import sys, os, subprocess, tempfile, uuid, torch, gradio as gr |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
sys.path.insert(0, '/workspace/SadTalker') |
|
|
sys.path.insert(0, '/workspace/SadTalker/src') |
|
|
|
|
|
|
|
|
from src.utils.preprocess import CropAndExtract |
|
|
from src.test_audio2coeff import Audio2Coeff |
|
|
from src.facerender.animate import AnimateFromCoeff |
|
|
from src.generate_batch import get_data |
|
|
from src.generate_facerender_batch import get_facerender_data |
|
|
from src.utils.init_path import init_path |
|
|
|
|
|
from TTS.api import TTS |
|
|
|
|
|
# Run every model (XTTS and SadTalker) on the GPU when one is visible.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Load Coqui XTTS v2 once at import time so per-request synthesis is fast.
# NOTE(review): this downloads/initializes the model as a module-level side
# effect — confirm that is acceptable at app startup.
tts = TTS(model_name="tts_models/en/xtts_v2", progress_bar=False).to(DEVICE)
|
|
|
|
|
class SadTalkerInference:
    """Thin wrapper around the SadTalker talking-head pipeline.

    Loads the three model stages once (face crop + 3DMM extraction,
    audio-to-coefficient, coefficient-to-video rendering) and exposes a
    single :meth:`generate` call that runs them in sequence.
    """

    def __init__(self, checkpoint_dir='./checkpoints', config_dir='./src/config', device='cuda'):
        """Load all SadTalker model stages onto *device*.

        Args:
            checkpoint_dir: Directory holding the SadTalker weights.
            config_dir: Directory holding the model configs.
                NOTE(review): stored but never read below — confirm it is needed.
            device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.
        """
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.config_dir = config_dir

        # NOTE(review): upstream SadTalker builds a path dict via init_path()
        # (imported at the top of this file but unused) and passes that to
        # these constructors; confirm this fork accepts a bare checkpoint dir.
        self.preprocess_model = CropAndExtract(self.checkpoint_dir, self.device)
        self.audio2coeff = Audio2Coeff(self.checkpoint_dir, self.device)
        self.animate_from_coeff = AnimateFromCoeff(self.checkpoint_dir, self.device)

    def generate(self, source_image, driven_audio, preprocess='crop',
                 still_mode=False, use_enhancer=None, batch_size=1,
                 size=512, pose_style=0, exp_scale=1.0,
                 use_ref_video=False, ref_video=None, ref_info=None,
                 use_idle_mode=False, length_of_audio=0, use_blink=True,
                 result_dir='./results/'):
        """Render a talking-head video for *source_image* driven by *driven_audio*.

        Args:
            source_image: Path string to an image, or a PIL image object.
            driven_audio: Path to the driving audio file.
            preprocess: SadTalker preprocess mode (e.g. ``'crop'``, ``'full'``).
            still_mode: Reduce head motion when True.
            use_enhancer: Optional face-enhancer name passed to the renderer.
            batch_size: Face-render batch size.
            size: Working image size in pixels.
            pose_style: Pose style index for the audio2coeff stage.
            exp_scale: Expression amplification factor.
            use_ref_video / ref_video / ref_info: Accepted for interface
                compatibility; currently unused by this implementation.
            use_idle_mode / length_of_audio / use_blink: Forwarded to get_data.
            result_dir: Root directory for per-call workspaces.

        Returns:
            Path of the rendered video, as returned by AnimateFromCoeff.

        Raises:
            ValueError: If no 3DMM coefficients could be extracted.
        """
        # Unique per-call workspace so concurrent requests never collide.
        save_dir = os.path.join(result_dir, 'temp_' + str(uuid.uuid4())[:8])
        os.makedirs(save_dir, exist_ok=True)

        input_dir = os.path.join(save_dir, 'input')
        os.makedirs(input_dir, exist_ok=True)

        # Normalize the source image to a file path.  BUGFIX: the original
        # code left ``pic_path`` pointing at a never-written file when a path
        # string was passed in (it executed the no-op
        # ``source_image = source_image``), yet ``pic_path`` is used below as
        # the paste-back image for the renderer.  Alias it to the real file.
        pic_path = os.path.join(input_dir, 'pic.jpg')
        if isinstance(source_image, str):
            pic_path = source_image
        else:
            source_image.save(pic_path)
            source_image = pic_path

        # Stage 1: crop the face and extract first-frame 3DMM coefficients.
        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
        os.makedirs(first_frame_dir, exist_ok=True)

        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            source_image, first_frame_dir, preprocess, True, size)

        if first_coeff_path is None:
            raise ValueError("Can't get the coeffs of the input")

        # Stage 2: map the driving audio to motion coefficients.
        batch = get_data(first_coeff_path, driven_audio, self.device, ref_eyeblink=None, still=still_mode,
                         idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
        coeff_path = self.audio2coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path=None, audio_path=driven_audio, crop_info=crop_info)

        # Stage 3: render the coefficient sequence into video frames.
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, driven_audio,
                                   batch_size, still_mode, None, None, expression_scale=exp_scale,
                                   input_yaw_list=None, input_pitch_list=None, input_roll_list=None)

        result = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer=use_enhancer,
                                                  background_enhancer=None, preprocess=preprocess, img_size=size)

        return result
|
|
|
|
|
|
|
|
sadtalker = SadTalkerInference(device=DEVICE) |
|
|
|
|
|
def _resolve_audio(tmp: Path, audio, text: str) -> Path:
    """Return a wav path: the caller-supplied file if given, else XTTS synthesis.

    Raises ``gr.Error`` when neither audio nor non-blank text is provided.
    """
    if audio is not None:
        return Path(audio)
    # BUGFIX: guard against text=None (Gradio may deliver None for a cleared
    # textbox); the original called text.strip() unconditionally.
    if text and text.strip():
        wav_path = tmp / "speech.wav"
        tts.tts_to_file(text=text, file_path=str(wav_path))
        return wav_path
    raise gr.Error("Provide either audio or text!")


def _upscale(src: Path, dst: Path) -> Path:
    """Best-effort 4x upscale via Real-ESRGAN; return *src* on any failure."""
    try:
        # NOTE(review): the official inference_realesrgan_video.py treats -o
        # as an output *folder*, not a file path — confirm this local copy
        # accepts a file path, otherwise the fallback below always triggers.
        subprocess.run([
            "python", "/workspace/Real-ESRGAN/inference_realesrgan_video.py",
            "-n", "RealESRGAN_x4plus",
            "-i", str(src),
            "-o", str(dst),
            "--fp32"
        ], check=True)
    except Exception:
        # Deliberate best-effort: deliver the un-upscaled video rather than fail.
        return src
    return dst


def _to_vertical(src: Path, dst: Path) -> None:
    """Letterbox *src* into a 576x1024 (9:16) H.264 MP4 at 30 fps."""
    subprocess.run([
        "ffmpeg", "-y", "-i", str(src),
        # BUGFIX: scale=576:-2 (was -1) forces an even height; libx264 with
        # yuv420p rejects odd frame dimensions. Padding then centers the
        # frame vertically on a 576x1024 black canvas.
        "-vf", "scale=576:-2,pad=576:1024:0:(1024-ih)/2:black,fps=30",
        "-c:v", "libx264", "-crf", "18", "-pix_fmt", "yuv420p", str(dst)
    ], check=True)


def generate(image: Image.Image, audio=None, text: str = ""):
    """Gradio handler: portrait + (audio | text) -> 576x1024 talking-head MP4.

    Args:
        image: Input portrait (PIL image from gr.Image).
        audio: Optional filepath to a driving voice recording; wins over text.
        text: Optional text to synthesize with XTTS when no audio is given.

    Returns:
        Path (str) of the final vertical MP4.

    Raises:
        gr.Error: When neither input is provided or SadTalker fails.
    """
    # Per-request scratch directory; holds every intermediate artifact.
    tmp = Path(tempfile.mkdtemp())
    img_path = tmp / "input.png"
    image.convert("RGB").save(img_path)

    wav_path = _resolve_audio(tmp, audio, text or "")

    try:
        result_path = sadtalker.generate(
            source_image=str(img_path),
            driven_audio=str(wav_path),
            preprocess='crop',
            still_mode=False,
            use_enhancer=None,
            size=512,
            result_dir=str(tmp)
        )
    except Exception as e:
        # Surface the failure to the UI instead of a bare stack trace.
        raise gr.Error(f"SadTalker generation failed: {str(e)}")

    upscaled_path = _upscale(Path(result_path), tmp / "up.mp4")

    out_vid = tmp / "vertical.mp4"
    _to_vertical(upscaled_path, out_vid)
    return str(out_vid)
|
|
|
|
|
# Wire the handler into a simple three-input Gradio UI.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil", label="Portrait 512×512"),
        # BUGFIX: the `optional=` kwarg was removed from component
        # constructors in Gradio 3.x (inputs are optional by default);
        # passing it raised a TypeError at startup.
        gr.Audio(type="filepath", label="Voice (wav/mp3)"),
        gr.Textbox(lines=2, placeholder="…or paste text", label="Text")
    ],
    outputs=gr.Video(label="576×1024 MP4"),
    title="ZeroGPU SadTalker 9:16"
)
|
|
|
|
|
if __name__ == "__main__":
    # Queue guards GPU memory: at most 2 jobs run concurrently, 8 may wait.
    # NOTE(review): queue(concurrency_count=...) is Gradio 3.x API and was
    # removed in Gradio 4 — confirm the pinned gradio version.
    demo.queue(concurrency_count=2, max_size=8).launch()