sadtalker-api / app.py
aiwithshweta's picture
update
9fe1333 verified
import os
import subprocess
import sys
import tempfile
import zipfile

import gdown
import gradio as gr
# ------------------------------------------------------
# 1) DOWNLOAD SadTalker ZIP FROM GOOGLE DRIVE (ONLY ONCE)
# ------------------------------------------------------
# Google Drive file ID of the zipped SadTalker source tree.
SADTALKER_ZIP_ID = "1ERCCvqt2YTNiwfqY1N95aQIIBm3maApu"
if not os.path.exists("SadTalker"):
    print("⬇️ Downloading SadTalker.zip...")
    downloaded = gdown.download(f"https://drive.google.com/uc?id={SADTALKER_ZIP_ID}", "sad.zip", quiet=False)
    # Fail fast with a clear message instead of a FileNotFoundError from zipfile.
    if downloaded is None or not os.path.exists("sad.zip"):
        raise RuntimeError("Could not download SadTalker.zip from Google Drive.")
    print("📦 Extracting ZIP...")
    with zipfile.ZipFile("sad.zip", 'r') as zip_ref:
        zip_ref.extractall("./")
    # The archive is not needed once extracted; don't leave it taking up disk.
    os.remove("sad.zip")
    # Rename folder if extracted as SadTalker-main
    if os.path.exists("SadTalker-main"):
        os.rename("SadTalker-main", "SadTalker")
    if not os.path.isdir("SadTalker"):
        raise RuntimeError("SadTalker.zip did not contain a 'SadTalker' or 'SadTalker-main' folder.")
    print("✔ SadTalker ready!")
# The checkpoint paths and inference.py used by this script are relative to
# the SadTalker checkout, so work from inside it.
os.chdir("SadTalker")
# ------------------------------------------------------
# 2) DOWNLOAD CHECKPOINT MODELS (ONLY IF MISSING)
# ------------------------------------------------------
def gd(file_id, out):
    """Fetch a Google Drive file to ``out`` unless it is already present.

    Args:
        file_id: Google Drive file ID (the ``id=`` value of a share URL).
        out: Local destination path. Missing parent directories are created.

    Returns:
        ``out``. Existing callers ignore the result, so returning it is optional sugar.

    Raises:
        RuntimeError: if the download did not produce ``out``.
    """
    # Treat a zero-byte file as missing so an earlier failed attempt is
    # retried instead of being skipped forever as "already exists".
    if os.path.exists(out) and os.path.getsize(out) > 0:
        print(f"✔ Already exists: {out}")
        return out

    parent = os.path.dirname(out)
    if parent:
        os.makedirs(parent, exist_ok=True)

    url = f"https://drive.google.com/uc?id={file_id}"
    print(f"⬇️ Downloading: {out}")
    result = gdown.download(url, out, quiet=False)
    # Depending on the gdown version, a failure may surface as a None return
    # rather than an exception. Check here so a missing checkpoint fails now
    # instead of later inside inference.
    if result is None or not os.path.exists(out):
        raise RuntimeError(f"Failed to download {out} (Google Drive id {file_id})")
    return out
# (Google Drive file ID, local destination) for each checkpoint file to fetch.
# gd() only downloads entries that are not already present on disk.
os.makedirs("checkpoints", exist_ok=True)
for _drive_id, _target in (
    ("1g3VIpU3yhpITMZtrWU2mbmyzLoF9K3Gz", "checkpoints/audio2exp_00300-model.pth"),
    ("1Jp4i_Qc-6qCms7v1kN61RE3qnT_9J8vg", "checkpoints/audio2pose_00140-model.pth"),
    ("1pTbKmpOeWRSA1NYQ3DnWe3W-9Qeyy7PK", "checkpoints/mapping_00109-model.pth.tar"),
    ("1QJins0e_hvvZM4d47gR9eoCxa0L4LpWj", "checkpoints/mapping_00229-model.pth.tar"),
    ("1caz3ESiy-aEzI0ptff8qpjcxdTf-L26E", "checkpoints/SadTalker_V0.0.2_256.safetensors"),
    ("12mj9EEs4KIS6-9FjF-PjEegF-TDKvvi6", "checkpoints/SadTalker_V0.0.2_512.safetensors"),
    ("1cag0u7e5RgdKoxBa7exY599VNGtYgNQb", "checkpoints/shape_predictor_68_face_landmarks.dat"),
):
    gd(_drive_id, _target)
print("✔ All checkpoints ready!")
# ------------------------------------------------------
# 3) GENERATE VIDEO (FAST + ALWAYS RETURNS OUTPUT)
# ------------------------------------------------------
def generate_video(image, audio):
    """Animate ``image`` with ``audio`` by running SadTalker's ``inference.py``.

    Each call works in its own fresh directory under ``results/``. A failed
    run therefore cannot return a video left over from an earlier request,
    and concurrent requests don't overwrite each other's input image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input.
        audio: Path to the uploaded audio file from ``gr.Audio(type="filepath")``.

    Returns:
        Path of the newest ``.mp4`` produced by this run.

    Raises:
        gr.Error: if an input is missing or the run produced no video. Gradio
            shows this message to the user, instead of trying to play a string
            in the video output.
    """
    if image is None or audio is None:
        raise gr.Error("Please upload both an image and an audio file.")

    os.makedirs("results", exist_ok=True)
    run_dir = tempfile.mkdtemp(prefix="run_", dir="results")
    source_image = os.path.join(run_dir, "input.png")
    image.save(source_image)

    cmd = [
        # Use this interpreter so the child process sees the same installed packages.
        sys.executable, "inference.py",
        "--driven_audio", audio,
        "--source_image", source_image,
        "--result_dir", run_dir,
        "--still",
        "--size", "512",
        # NOTE(review): this passes the literal string "None". Confirm that
        # inference.py treats it as "no enhancer" rather than as an enhancer name.
        "--enhancer", "None",
        "--expression_scale", "1.8",
        "--preprocess", "full",
    ]
    print("▶ Running inference...")
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        # Keep going: if the run still wrote a video, return it (best effort).
        print(f"⚠ inference.py exited with code {completed.returncode}")

    # Collect every mp4 this run wrote. Search recursively, because inference
    # may place outputs in subfolders of the result dir.
    mp4s = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(run_dir)
        for name in files
        if name.endswith(".mp4")
    ]
    if not mp4s:
        raise gr.Error(f"❌ No video generated (inference exit code {completed.returncode}).")

    # Pick the most recently written file
    latest = max(mp4s, key=os.path.getmtime)
    print("🎬 Video generated:", latest)
    return latest
# ------------------------------------------------------
# 4) GRADIO UI
# ------------------------------------------------------
# The portrait arrives as a PIL image and the audio as a file path, matching
# how generate_video() uses them. Its return value is shown in a video player.
portrait_input = gr.Image(type="pil")
audio_input = gr.Audio(type="filepath")
video_output = gr.Video()

demo = gr.Interface(
    fn=generate_video,
    inputs=[portrait_input, audio_input],
    outputs=video_output,
    title="SadTalker (Google Drive Version)",
    description="Fast loading + no duplicate downloads + auto video return",
)
demo.launch()