IVTS / app.py
import sys, os, subprocess, tempfile, uuid, torch, gradio as gr
from pathlib import Path
from PIL import Image
# Add SadTalker paths
sys.path.insert(0, '/workspace/SadTalker')
sys.path.insert(0, '/workspace/SadTalker/src')
# Import SadTalker modules directly
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from TTS.api import TTS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# XTTS v2 lives under the multilingual/multi-dataset namespace in Coqui TTS
tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2", progress_bar=False).to(DEVICE)
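# Voice-cloning sketch (hedged; "ref.wav" is a placeholder reference clip,
# not a file shipped with this Space):
#   tts.tts_to_file(text="Hello there", speaker_wav="ref.wav",
#                   language="en", file_path="cloned.wav")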
class SadTalkerInference:
    def __init__(self, checkpoint_dir='./checkpoints', config_dir='./src/config',
                 device='cuda', size=512, preprocess='crop'):
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.config_dir = config_dir
        # Stock SadTalker builds its models from the path dict returned by
        # init_path (which needs size/preprocess up front), not from a bare
        # checkpoint directory string
        sadtalker_paths = init_path(checkpoint_dir, config_dir, size, False, preprocess)
        self.preprocess_model = CropAndExtract(sadtalker_paths, self.device)
        self.audio2coeff = Audio2Coeff(sadtalker_paths, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_paths, self.device)
def generate(self, source_image, driven_audio, preprocess='crop',
still_mode=False, use_enhancer=None, batch_size=1,
size=512, pose_style=0, exp_scale=1.0,
use_ref_video=False, ref_video=None, ref_info=None,
use_idle_mode=False, length_of_audio=0, use_blink=True,
result_dir='./results/'):
# Save paths
save_dir = os.path.join(result_dir, 'temp_' + str(uuid.uuid4())[:8])
os.makedirs(save_dir, exist_ok=True)
input_dir = os.path.join(save_dir, 'input')
os.makedirs(input_dir, exist_ok=True)
# Process image
pic_path = os.path.join(input_dir, 'pic.jpg')
        # Accept either a filesystem path or a PIL image
        if not isinstance(source_image, str):
            source_image.save(pic_path)
            source_image = pic_path
# First crop the image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
os.makedirs(first_frame_dir, exist_ok=True)
first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
source_image, first_frame_dir, preprocess, True, size)
if first_coeff_path is None:
raise ValueError("Can't get the coeffs of the input")
        # Audio to coeffs (stock get_data's keyword is ref_eyeblink_coeff_path)
        batch = get_data(first_coeff_path, driven_audio, self.device, ref_eyeblink_coeff_path=None,
                         still=still_mode, idlemode=use_idle_mode,
                         length_of_audio=length_of_audio, use_blink=use_blink)
        # Stock Audio2Coeff.generate takes (batch, save_dir, pose_style, ref_pose_coeff_path)
        coeff_path = self.audio2coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path=None)
        # Coeffs to video (pass the yaw/pitch/roll lists by keyword only, so
        # still_mode doesn't silently land in the input_yaw_list slot)
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, driven_audio,
                                   batch_size, input_yaw_list=None, input_pitch_list=None,
                                   input_roll_list=None, expression_scale=exp_scale,
                                   still_mode=still_mode, preprocess=preprocess, size=size)
result = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer=use_enhancer,
background_enhancer=None, preprocess=preprocess, img_size=size)
return result
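# Standalone usage sketch (hedged; 'examples/face.png' and 'examples/voice.wav'
# are illustrative placeholders, not files in this repo):
#   st = SadTalkerInference(device=DEVICE)
#   video_path = st.generate(source_image='examples/face.png',
#                            driven_audio='examples/voice.wav',
#                            preprocess='crop', size=512)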
# Initialize SadTalker
sadtalker = SadTalkerInference(device=DEVICE)
def generate(image: Image.Image, audio=None, text: str = ""):
tmp = Path(tempfile.mkdtemp())
img_path = tmp / "input.png"
image.convert("RGB").save(img_path)
    # 1️⃣ handle audio or text -> wav
    if audio is None and text.strip():
        wav_path = tmp / "speech.wav"
        # XTTS v2 needs a language plus a voice: a speaker_wav reference clip,
        # or one of its bundled speakers ("Claribel Dervla" assumed here)
        tts.tts_to_file(text=text, speaker="Claribel Dervla", language="en",
                        file_path=str(wav_path))
elif audio is not None:
wav_path = Path(audio)
else:
raise gr.Error("Provide either audio or text!")
# 2️⃣ run SadTalker
try:
result_path = sadtalker.generate(
source_image=str(img_path),
driven_audio=str(wav_path),
preprocess='crop',
still_mode=False,
use_enhancer=None,
size=512,
result_dir=str(tmp)
)
vid_path = result_path
except Exception as e:
raise gr.Error(f"SadTalker generation failed: {str(e)}")
    # 3️⃣ (optional) upscale with Real-ESRGAN; note its video script treats -o
    # as an output *folder* and names the result <input-stem>_<suffix>.mp4
    # (the suffix defaults to "out")
    up_dir = tmp / "up"
    up_dir.mkdir(exist_ok=True)
    upscaled_path = up_dir / f"{Path(vid_path).stem}_out.mp4"
    try:
        subprocess.run([
            "python", "/workspace/Real-ESRGAN/inference_realesrgan_video.py",
            "-n", "RealESRGAN_x4plus",
            "-i", str(vid_path),
            "-o", str(up_dir),
            "--fp32"
        ], check=True)
        if not upscaled_path.exists():
            upscaled_path = Path(vid_path)  # fallback to the raw SadTalker output
    except Exception:
        upscaled_path = Path(vid_path)  # fallback to the raw SadTalker output
    # 4️⃣ pad to 576×1024 vertical (scale=576:-2 keeps the height even, which
    # libx264 + yuv420p require)
    out_vid = tmp / "vertical.mp4"
    subprocess.run([
        "ffmpeg", "-y", "-i", str(upscaled_path),
        "-vf", "scale=576:-2,pad=576:1024:0:(oh-ih)/2:black,fps=30",
        "-c:v", "libx264", "-crf", "18", "-pix_fmt", "yuv420p", str(out_vid)
    ], check=True)
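    # Filter math for a 512×512 portrait: scale=576:-2 -> 576×576, then pad
    # centres it with (1024-576)/2 = 224 black rows top and bottom -> 576×1024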
return str(out_vid)
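# Output geometry sanity check (hedged shell example):
#   ffprobe -v error -select_streams v:0 -show_entries stream=width,height \
#           -of csv=p=0 vertical.mp4    # expect: 576,1024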
demo = gr.Interface(
fn=generate,
inputs=[
gr.Image(type="pil", label="Portrait 512×512"),
        gr.Audio(type="filepath", label="Voice (wav/mp3)"),  # optional by default in current Gradio
gr.Textbox(lines=2, placeholder="…or paste text", label="Text")
],
outputs=gr.Video(label="576×1024 MP4"),
title="ZeroGPU SadTalker 9:16"
)
if __name__ == "__main__":
    # Gradio 4 dropped queue(concurrency_count=...); default_concurrency_limit
    # is the current equivalent
    demo.queue(max_size=8, default_concurrency_limit=2).launch()