Spaces:
Running
Running
import os
import subprocess
import sys
import tempfile
import zipfile

import gdown
import gradio as gr
# ------------------------------------------------------
# 1) DOWNLOAD SadTalker ZIP FROM GOOGLE DRIVE (ONLY ONCE)
# ------------------------------------------------------
# Google Drive file id of the SadTalker source archive.
SADTALKER_ZIP_ID = "1ERCCvqt2YTNiwfqY1N95aQIIBm3maApu"
# Local name for the downloaded archive; it is deleted after extraction.
SADTALKER_ZIP_PATH = "sad.zip"

if not os.path.exists("SadTalker"):
    print("⬇️ Downloading SadTalker.zip...")
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={SADTALKER_ZIP_ID}",
        SADTALKER_ZIP_PATH,
        quiet=False,
    )
    # Depending on the gdown version, a failed download may return None
    # instead of raising; stop here with a clear message rather than letting
    # zipfile fail later on a missing archive.
    if downloaded is None or not os.path.isfile(SADTALKER_ZIP_PATH):
        raise RuntimeError(
            f"Could not download SadTalker archive (Google Drive id {SADTALKER_ZIP_ID})."
        )
    print("📦 Extracting ZIP...")
    with zipfile.ZipFile(SADTALKER_ZIP_PATH, "r") as zip_ref:
        zip_ref.extractall("./")
    # Rename folder if extracted as SadTalker-main
    if os.path.exists("SadTalker-main"):
        os.rename("SadTalker-main", "SadTalker")
    # The archive is not needed once extracted; don't leave it on disk.
    os.remove(SADTALKER_ZIP_PATH)

# The checkpoint paths and inference.py used later in this file are relative
# to the SadTalker folder, so fail early if extraction did not produce it.
if not os.path.isdir("SadTalker"):
    raise RuntimeError(
        "SadTalker folder not found after extraction; the archive layout may have changed."
    )
print("✔ SadTalker ready!")
os.chdir("SadTalker")
# ------------------------------------------------------
# 2) DOWNLOAD CHECKPOINT MODELS (ONLY IF MISSING)
# ------------------------------------------------------
def gd(file_id, out):
    """Download a Google Drive file to *out* unless a non-empty copy exists.

    Args:
        file_id: Google Drive file id (the ``id=`` value of a ``uc`` URL).
        out: Destination path, relative to the current working directory.
            Missing parent directories are created.

    Returns:
        *out* when the file is present (already existed or was downloaded),
        or None when the download failed. A failure is reported on stdout
        instead of raising, so one unavailable file does not abort startup.
    """
    # A zero-byte file cannot be a usable checkpoint; treat it as missing so
    # it is fetched again instead of being reported as present.
    if os.path.isfile(out) and os.path.getsize(out) > 0:
        print(f"✔ Already exists: {out}")
        return out

    parent = os.path.dirname(out)
    if parent:
        os.makedirs(parent, exist_ok=True)

    url = f"https://drive.google.com/uc?id={file_id}"
    print(f"⬇️ Downloading: {out}")
    result = gdown.download(url, out, quiet=False)

    # Depending on the gdown version, failures may come back as None rather
    # than an exception; check the file too so the failure is not silent.
    if result is None or not os.path.isfile(out) or os.path.getsize(out) == 0:
        print(f"❌ Failed to download: {out} (Google Drive id {file_id})")
        return None
    return out
# (Google Drive file id, local destination) for each checkpoint fetched at
# startup, in download order.
CHECKPOINT_FILES = (
    ("1g3VIpU3yhpITMZtrWU2mbmyzLoF9K3Gz", "checkpoints/audio2exp_00300-model.pth"),
    ("1Jp4i_Qc-6qCms7v1kN61RE3qnT_9J8vg", "checkpoints/audio2pose_00140-model.pth"),
    ("1pTbKmpOeWRSA1NYQ3DnWe3W-9Qeyy7PK", "checkpoints/mapping_00109-model.pth.tar"),
    ("1QJins0e_hvvZM4d47gR9eoCxa0L4LpWj", "checkpoints/mapping_00229-model.pth.tar"),
    ("1caz3ESiy-aEzI0ptff8qpjcxdTf-L26E", "checkpoints/SadTalker_V0.0.2_256.safetensors"),
    ("12mj9EEs4KIS6-9FjF-PjEegF-TDKvvi6", "checkpoints/SadTalker_V0.0.2_512.safetensors"),
    ("1cag0u7e5RgdKoxBa7exY599VNGtYgNQb", "checkpoints/shape_predictor_68_face_landmarks.dat"),
)

os.makedirs("checkpoints", exist_ok=True)
for drive_id, local_path in CHECKPOINT_FILES:
    gd(drive_id, local_path)
print("✔ All checkpoints ready!")
# ------------------------------------------------------
# 3) GENERATE VIDEO (ONE ISOLATED RESULTS DIR PER REQUEST)
# ------------------------------------------------------
def _find_latest_mp4(root):
    """Return the most recently modified ``.mp4`` under *root* (recursive), or None."""
    mp4s = [
        os.path.join(dirpath, name)
        for dirpath, _dirs, files in os.walk(root)
        for name in files
        if name.endswith(".mp4")
    ]
    if not mp4s:
        return None
    return max(mp4s, key=os.path.getmtime)


def generate_video(image, audio):
    """Run SadTalker's ``inference.py`` on one image/audio pair.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input.
        audio: Path to the uploaded audio, from ``gr.Audio(type="filepath")``.

    Returns:
        Path to the newest ``.mp4`` produced by this request's run.

    Raises:
        gr.Error: if an input is missing or no video was produced. Gradio
            shows the message to the user; previously an error string was
            returned to the ``gr.Video`` output, which is not a video path.
    """
    if image is None or not audio:
        raise gr.Error("Please provide both an image and an audio file.")

    # Give every request its own directory. A failed run then cannot return
    # a video left over from an earlier request, and concurrent requests do
    # not overwrite each other's input image.
    os.makedirs("results", exist_ok=True)
    run_dir = tempfile.mkdtemp(prefix="run_", dir="results")
    image_path = os.path.join(run_dir, "input.png")
    image.save(image_path)

    cmd = [
        # Use the interpreter running this app so inference.py sees the same
        # installed packages; a bare "python3" on PATH may be a different one.
        sys.executable, "inference.py",
        "--driven_audio", audio,
        "--source_image", image_path,
        "--result_dir", run_dir,
        "--still",
        "--size", "512",
        # Intentionally no "--enhancer": passing "None" would hand argparse
        # the literal (truthy) string "None", not Python None. Omitting the
        # flag leaves inference.py's own default in effect.
        "--expression_scale", "1.8",
        "--preprocess", "full",
    ]
    print("▶ Running inference...")
    completed = subprocess.run(cmd)

    # inference.py may leave a usable video behind even when it exits
    # non-zero, so look for output before deciding the run failed.
    latest = _find_latest_mp4(run_dir)
    if latest is None:
        raise gr.Error(
            f"❌ No video generated (inference exit code {completed.returncode})."
        )
    print("🎬 Video generated:", latest)
    return latest
# ------------------------------------------------------
# 4) GRADIO UI
# ------------------------------------------------------
# Inputs map positionally onto generate_video(image, audio).
input_components = [
    gr.Image(type="pil"),
    gr.Audio(type="filepath"),
]
output_component = gr.Video()

demo = gr.Interface(
    fn=generate_video,
    inputs=input_components,
    outputs=output_component,
    title="SadTalker (Google Drive Version)",
    description="Fast loading + no duplicate downloads + auto video return",
)
demo.launch()