# sadtalker-api / app.py
# SadTalker API Space (uploader: Madras1, commit fe0a2e1) — generates
# talking-head videos from a face image plus an audio clip on ZeroGPU.
import gradio as gr
import spaces
import subprocess
import tempfile
import base64
import os
import shutil
import sys
# SadTalker path: where the upstream repo is cloned at first use
# (a writable location on the HF Spaces container).
SADTALKER_DIR = "/home/user/SadTalker"
def setup_sadtalker():
    """Clone SadTalker, install its requirements, and download checkpoints.

    Idempotent: the clone + pip install run only when ``SADTALKER_DIR`` is
    missing, and the checkpoint download runs only when the checkpoints
    directory is absent or empty, so repeated calls (one per GPU request)
    are cheap.

    Returns:
        bool: always ``True`` on success.

    Raises:
        subprocess.CalledProcessError: if the git clone or pip install fails.
    """
    if not os.path.exists(SADTALKER_DIR):
        print("Cloning SadTalker...")
        subprocess.run(
            [
                "git", "clone", "--depth", "1",
                "https://github.com/OpenTalker/SadTalker.git",
                SADTALKER_DIR,
            ],
            check=True,
        )
        # Install SadTalker requirements. check=True so a broken install
        # surfaces here rather than as a confusing inference failure later.
        print("Installing SadTalker requirements...")
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install", "-q", "-r",
                f"{SADTALKER_DIR}/requirements.txt",
            ],
            check=True,
        )
    # Download checkpoints from HuggingFace only when not already present,
    # so a warm container doesn't re-hit the Hub on every request.
    checkpoints_dir = f"{SADTALKER_DIR}/checkpoints"
    if not os.path.isdir(checkpoints_dir) or not os.listdir(checkpoints_dir):
        print("Downloading checkpoints...")
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="vinthony/SadTalker",
            local_dir=checkpoints_dir,
            local_dir_use_symlinks=False,
        )
    return True
@spaces.GPU(duration=120)
def generate_video_gpu(image_path: str, audio_path: str, output_dir: str) -> str:
    """Run SadTalker inference on GPU and return the generated video path.

    Args:
        image_path: path to the source face image.
        audio_path: path to the driving audio file.
        output_dir: directory where SadTalker writes its results.

    Returns:
        Absolute path of the first ``.mp4`` found under ``output_dir``.

    Raises:
        RuntimeError: if the inference subprocess exits non-zero or no
            ``.mp4`` file is produced.
    """
    setup_sadtalker()
    # Add SadTalker to path so its internal imports resolve.
    if SADTALKER_DIR not in sys.path:
        sys.path.insert(0, SADTALKER_DIR)
    # Run SadTalker inference as a subprocess; --still + crop preprocessing
    # are the fastest/most stable settings for a single portrait.
    cmd = [
        sys.executable, f"{SADTALKER_DIR}/inference.py",
        "--driven_audio", audio_path,
        "--source_image", image_path,
        "--result_dir", output_dir,
        "--still",
        "--preprocess", "crop",
    ]
    print(f"Running: {' '.join(cmd)}")
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=SADTALKER_DIR,
    )
    print(f"STDOUT: {result.stdout}")
    if result.stderr:
        print(f"STDERR: {result.stderr}")
    if result.returncode != 0:
        # RuntimeError instead of bare Exception; callers catching
        # Exception still see it.
        raise RuntimeError(f"SadTalker failed: {result.stderr}")
    # Find the generated video (SadTalker nests results in timestamped dirs).
    for root, _dirs, files in os.walk(output_dir):
        for fname in files:
            if fname.endswith(".mp4"):
                return os.path.join(root, fname)
    raise RuntimeError("No video generated")
def gradio_generate(image, audio):
    """Gradio interface wrapper: stage inputs, run inference, return video.

    Args:
        image: filepath string (type="filepath"), file object, or numpy
            array, depending on how Gradio delivered the component value.
        audio: filepath string, file object, or ``(sample_rate, data)`` tuple.

    Returns:
        Path to the generated video, copied outside the temp directory.

    Raises:
        gr.Error: on missing inputs or any staging/generation failure.
    """
    # Debug logging
    print("=== RECEIVED ===")
    print(f"Image type: {type(image)}, value: {image}")
    print(f"Audio type: {type(audio)}, value: {audio}")
    # Validate with `is None`, not truthiness: `not image` raises
    # "truth value of an array is ambiguous" when image is a numpy array.
    if image is None:
        raise gr.Error("Envie uma imagem primeiro!")
    if audio is None:
        raise gr.Error("Envie um áudio primeiro!")
    with tempfile.TemporaryDirectory() as tmpdir:
        image_path = os.path.join(tmpdir, "input.png")
        audio_path = os.path.join(tmpdir, "input.wav")
        output_dir = os.path.join(tmpdir, "output")
        os.makedirs(output_dir, exist_ok=True)
        # Handle image: accept a path, a file-like object, or a raw array.
        try:
            if isinstance(image, str) and os.path.exists(image):
                shutil.copy(image, image_path)
            elif hasattr(image, 'name'):  # File object
                shutil.copy(image.name, image_path)
            else:
                from PIL import Image as PILImage
                PILImage.fromarray(image).save(image_path)
        except Exception as e:
            raise gr.Error(f"Erro ao processar imagem: {e}")
        # Handle audio: accept a path, a file-like object, or (sr, data).
        try:
            if isinstance(audio, str) and os.path.exists(audio):
                shutil.copy(audio, audio_path)
            elif hasattr(audio, 'name'):  # File object
                shutil.copy(audio.name, audio_path)
            elif isinstance(audio, tuple):
                import scipy.io.wavfile as wav
                sr, data = audio
                wav.write(audio_path, sr, data)
            else:
                raise gr.Error(f"Formato de áudio não reconhecido: {type(audio)}")
        except gr.Error:
            # Don't let the catch-all below re-wrap our own user-facing error
            # into a misleading "Erro ao processar áudio" message.
            raise
        except Exception as e:
            raise gr.Error(f"Erro ao processar áudio: {e}")
        print(f"Image saved: {image_path}, exists: {os.path.exists(image_path)}")
        print(f"Audio saved: {audio_path}, exists: {os.path.exists(audio_path)}")
        # Generate video, then copy it out of the TemporaryDirectory before
        # the context manager deletes it.
        try:
            video_path = generate_video_gpu(image_path, audio_path, output_dir)
            final_path = "/tmp/sadtalker_output.mp4"
            shutil.copy(video_path, final_path)
            return final_path
        except gr.Error:
            raise
        except Exception as e:
            raise gr.Error(f"Erro na geração: {e}")
# Create Gradio app: two-column layout with inputs on the left and the
# rendered video on the right.
with gr.Blocks(title="SadTalker API") as demo:
    gr.Markdown("# 🎭 SadTalker API")
    gr.Markdown("Generate talking head videos from image + audio (ZeroGPU)")
    with gr.Row():
        with gr.Column():
            # type="filepath" means gradio_generate receives path strings.
            image_input = gr.Image(label="Face Image", type="filepath")
            audio_input = gr.Audio(label="Audio", type="filepath")
            generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
            gr.Markdown("⏱️ Takes ~30-60 seconds with GPU")
    # Wire the button to the generation pipeline defined above.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[image_input, audio_input],
        outputs=video_output
    )
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)