Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import math | |
| import uuid | |
| import shutil | |
| import glob | |
| import tempfile | |
| import subprocess | |
| import json | |
| import time | |
| import warnings | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import asyncio | |
| import librosa | |
| import soundfile as sf | |
| from pydub import AudioSegment | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
# Setup paths for Hugging Face Spaces: prefer /tmp when it exists,
# otherwise work relative to the current directory.
BASE = Path.cwd() if not os.path.exists("/tmp") else Path("/tmp")

# WORK holds per-run intermediates; OUT holds the finished videos.
WORK = BASE.joinpath("ai_avatar_work")
OUT = BASE.joinpath("ai_avatar_out")
WORK.mkdir(parents=True, exist_ok=True)
OUT.mkdir(parents=True, exist_ok=True)

# Setup SadTalker: directory the repository is cloned into.
SADTALKER_DIR = BASE.joinpath("SadTalker")
def setup_sadtalker():
    """Clone and provision SadTalker under ``SADTALKER_DIR`` if it is absent.

    When the directory does not exist yet, this clones the upstream
    repository, installs its ``requirements.txt`` (if present) and runs
    ``scripts/download_models.sh`` (if present). An existing directory is
    treated as "already set up" and nothing is executed.

    Returns:
        bool: ``True`` if the directory already existed or every step
        succeeded; ``False`` if a step failed or could not be started.
    """
    if SADTALKER_DIR.exists():
        return True

    print("Setting up SadTalker...")
    try:
        # Clone SadTalker
        subprocess.run(
            ["git", "clone", "https://github.com/OpenTalker/SadTalker.git",
             str(SADTALKER_DIR)],
            check=True, capture_output=True, text=True,
        )

        # Install requirements
        requirements_path = SADTALKER_DIR / "requirements.txt"
        if requirements_path.exists():
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-r", str(requirements_path)],
                check=True, capture_output=True, text=True,
            )

        # Download models
        download_script = SADTALKER_DIR / "scripts" / "download_models.sh"
        if download_script.exists():
            subprocess.run(
                ["bash", str(download_script)],
                cwd=str(SADTALKER_DIR), check=True, capture_output=True, text=True,
            )
    except subprocess.CalledProcessError as e:
        print(f"❌ SadTalker setup failed: {e}")
        print(f"stdout: {e.stdout}")
        print(f"stderr: {e.stderr}")
        return False
    except OSError as e:
        # Raised when git/bash/pip cannot be launched at all (e.g. the
        # executable is missing). This function is called at import time,
        # so letting it propagate would prevent the app from starting.
        print(f"❌ SadTalker setup failed: {e}")
        return False

    print("✅ SadTalker setup complete!")
    return True


# Initialize SadTalker on startup
setup_sadtalker()
| # -------------------- Configuration -------------------- | |
class AgentConfig:
    """Settings for one avatar-generation run.

    ``expr_scale`` and ``pose_scale`` are coerced to ``float`` and ``fps``
    to ``int``; ``language`` and ``grab_frame_at`` are stored as given.
    """

    def __init__(self, language="en", grab_frame_at=1.0,
                 expr_scale=1.0, pose_scale=1.0, fps=25):
        # Stored unchanged.
        self.language, self.grab_frame_at = language, grab_frame_at
        # Numeric knobs are normalised so later str()/arithmetic is uniform.
        self.expr_scale, self.pose_scale = float(expr_scale), float(pose_scale)
        self.fps = int(fps)
class AgentLogs:
    """Collect progress messages while echoing each one to stdout."""

    def __init__(self):
        self.lines = list()

    def log(self, msg):
        """Print ``msg`` and append it to ``self.lines``."""
        print(msg)
        self.lines += [msg]
| # -------------------- Utility Functions -------------------- | |
def run_cmd(cmd: list, check=True):
    """Run ``cmd``, echo its combined output, and optionally fail loudly.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell).
        check: When true, a non-zero exit status raises ``RuntimeError``.

    Returns:
        The ``subprocess.CompletedProcess``; its ``stdout`` holds stdout
        and stderr interleaved as text.

    Raises:
        RuntimeError: If ``check`` is true and the command exits non-zero.
            The message carries the exit code and the last lines of output
            so the failure can be diagnosed from the UI log alone.
    """
    # str() each part so path-like arguments do not break the join.
    printable = " ".join(str(part) for part in cmd)
    print("▶", printable)
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(p.stdout)
    if check and p.returncode != 0:
        # Keep only the tail: tools such as ffmpeg print long banners first.
        tail = "\n".join((p.stdout or "").strip().splitlines()[-20:])
        message = f"Command failed: {printable} (exit code {p.returncode})"
        if tail:
            message = f"{message}\n{tail}"
        raise RuntimeError(message)
    return p
def extract_audio_to_wav(video_path: str, wav_out: str, sr: int = 16000):
    """Decode the audio of ``video_path`` into a single-channel WAV file.

    ffmpeg is asked to downmix to one channel (``-ac 1``), resample to
    ``sr`` Hz (``-ar``) and overwrite ``wav_out`` if it exists (``-y``).
    """
    source_args = ["-i", video_path]
    format_args = ["-ac", "1", "-ar", str(sr)]
    run_cmd(["ffmpeg", "-y", *source_args, *format_args, wav_out])
def cut_audio_segment(in_wav: str, out_wav: str, start_s: float, dur_s: float):
    """Write ``dur_s`` seconds of ``in_wav`` starting at ``start_s`` to ``out_wav``.

    Both times are passed to ffmpeg with millisecond precision and the
    result is encoded as 16-bit PCM.
    """
    seek = format(start_s, ".3f")
    length = format(dur_s, ".3f")
    run_cmd([
        "ffmpeg", "-y",
        "-ss", seek, "-t", length,
        "-i", in_wav,
        "-acodec", "pcm_s16le",
        out_wav,
    ])
def ensure_exact_duration(in_wav: str, out_wav: str, target_sec: float = 20.0):
    """Trim or pad with silence so the audio lasts exactly ``target_sec``.

    Args:
        in_wav: Audio file to read (any format pydub/ffmpeg can decode).
        out_wav: Destination path; always written as WAV.
        target_sec: Desired duration in seconds.
    """
    audio = AudioSegment.from_file(in_wav)
    # Work in whole milliseconds (pydub's native unit). Converting to float
    # seconds and back can truncate the pad by 1 ms and miss the target.
    target_ms = int(round(target_sec * 1000))
    missing_ms = target_ms - len(audio)
    if missing_ms < 0:
        audio = audio[:target_ms]
    elif missing_ms > 0:
        # Generate the silence at the clip's own rate so appending it does
        # not have to resample either side.
        audio = audio + AudioSegment.silent(duration=missing_ms,
                                            frame_rate=audio.frame_rate)
    audio.export(out_wav, format="wav")
def grab_frame_from_video(video_path: str, out_img: str, at_sec: float = 1.0):
    """Extract a single frame from a video at time ``at_sec`` and save it.

    If the requested position cannot be read (for example the clip is
    shorter than ``at_sec``), the first frame is used instead.

    Raises:
        RuntimeError: If the video cannot be opened, no frame can be read,
            or the image cannot be written to ``out_img``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN; fall back to a nominal 25 fps.
        if not fps or not math.isfinite(fps) or fps < 0:
            fps = 25
        frame_no = max(0, int(at_sec * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
        ok, frame = cap.read()
        if not ok:
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ok, frame = cap.read()
    finally:
        # Release the decoder even when seeking/reading raises.
        cap.release()
    if not ok:
        raise RuntimeError("Could not extract frame from video.")
    # imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_img, frame):
        raise RuntimeError(f"Could not write frame to {out_img}")
| # -------------------- Speech VAD -------------------- | |
def find_voice_reference_chunk(wav_path: str, sr: int = 16000, target_dur: float = 6.0) -> Tuple[float, float]:
    """Use Silero VAD to find a ~``target_dur`` chunk that starts on speech.

    The window starts at the beginning of the longest detected speech
    segment and lasts ``target_dur`` seconds, clamped to the end of the
    file (so it may include non-speech audio after that segment). If no
    speech is detected, or VAD itself fails, a window centred in the file
    is returned instead.

    Args:
        wav_path: Audio file to analyse.
        sr: Sample rate the audio is loaded at and VAD is run with.
        target_dur: Desired window length in seconds.

    Returns:
        ``(start_seconds, duration_seconds)``.
    """
    try:
        wav, _ = librosa.load(wav_path, sr=sr, mono=True)
    except Exception as e:
        # Length unknown: keep the fixed best-effort guess.
        print(f"VAD failed: {e}, using fallback")
        return 1.0, min(6.0, target_dur)

    total = len(wav) / sr

    try:
        vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                          model='silero_vad', force_reload=False)
        get_speech_timestamps = utils[0]
        speech_timestamps = get_speech_timestamps(
            torch.tensor(wav, dtype=torch.float32), vad_model, sampling_rate=sr
        )
    except Exception as e:
        # The audio length is known here, so fall through to the centred
        # window rather than a fixed offset that may lie past the end.
        print(f"VAD failed: {e}, using fallback")
        speech_timestamps = []

    if not speech_timestamps:
        mid_start = max(0.0, (total / 2) - (target_dur / 2))
        return mid_start, min(target_dur, total - mid_start)

    # Timestamps are sample offsets; the longest segment anchors the window.
    best = max(speech_timestamps, key=lambda x: x['end'] - x['start'])
    start = best['start'] / sr
    return start, min(target_dur, total - start)
| # -------------------- Script Generation -------------------- | |
def generate_20s_script(user_text: Optional[str] = None, language: str = "en") -> str:
    """Return the user's script if given, else a canned script for ``language``.

    Args:
        user_text: Script typed by the user; ``None``, empty or
            whitespace-only means "not provided".
        language: Language code used to pick the canned script. Codes
            without a canned script fall back to English.

    Returns:
        The stripped user text, or the canned script.
    """
    if user_text and user_text.strip():
        return user_text.strip()

    # For Hugging Face deployment, we'll use a simple fallback
    # You can add API keys as Hugging Face Secrets if needed
    fallback_scripts = {
        "en": "Hello! I'm your AI avatar. I'm here to demonstrate lifelike speech, natural lip-sync, and subtle head movement. This short clip shows voice cloning from a brief sample, then animates a still image for a realistic talking experience. Thanks for watching!",
        "es": "¡Hola! Soy tu avatar de IA. Estoy aquí para demostrar un habla realista, sincronización labial natural y movimientos sutiles de cabeza. Este breve clip muestra clonación de voz a partir de una muestra breve.",
        "fr": "Bonjour! Je suis votre avatar IA. Je suis ici pour démontrer une parole réaliste, une synchronisation labiale naturelle et des mouvements subtils de la tête. Ce court clip montre le clonage vocal.",
    }
    return fallback_scripts.get(language, fallback_scripts["en"])
| # -------------------- TTS -------------------- | |
def tts_20s_voice_clone(script_text: str, ref_wav: str, out_wav: str, language: str = "en") -> str:
    """Synthesize ``script_text`` to a 20-second WAV, trying three backends.

    Backends are tried in order until one succeeds:

    1. XTTS v2, cloning the voice in ``ref_wav`` in ``language``.
    2. Coqui ``tts_models/en/ljspeech/tacotron2-DDC`` (no cloning; the
       ``language`` argument is not used).
    3. edge-tts with the ``en-US-JennyNeural`` voice (no cloning).

    The raw synthesis is then trimmed/padded to exactly 20 seconds.

    Returns:
        ``out_wav``.

    Raises:
        RuntimeError: If every backend fails (chained to the last error).
    """
    tmp = str(WORK / f"tts_raw_{uuid.uuid4().hex}.wav")

    def _xtts_clone():
        from TTS.api import TTS
        xtts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
        xtts.tts_to_file(text=script_text, speaker_wav=ref_wav,
                         language=language, file_path=tmp)

    def _coqui_plain():
        from TTS.api import TTS
        tts = TTS("tts_models/en/ljspeech/tacotron2-DDC")
        tts.tts_to_file(text=script_text, file_path=tmp)

    def _edge():
        import edge_tts

        async def gen_edge():
            communicate = edge_tts.Communicate(script_text, voice="en-US-JennyNeural")
            await communicate.save(tmp)

        asyncio.run(gen_edge())

    # (backend, message on success, message prefix on failure or None)
    backends = (
        (_xtts_clone, "✅ XTTS v2 voice clone success", "XTTS v2 clone failed:"),
        (_coqui_plain, "✅ XTTS fallback (non-clone) success", "XTTS non-clone failed:"),
        (_edge, "✅ edge-tts fallback success", None),
    )

    try:
        last_error = None
        for synthesize, ok_msg, fail_msg in backends:
            try:
                synthesize()
            except Exception as e:
                last_error = e
                if fail_msg is not None:
                    print(fail_msg, e)
                continue
            print(ok_msg)
            break
        else:
            # No backend reached `break`.
            raise RuntimeError(f"All TTS fallbacks failed: {last_error}") from last_error

        ensure_exact_duration(tmp, out_wav, 20.0)
    finally:
        # The raw synthesis is only an intermediate; do not let it pile up.
        if os.path.exists(tmp):
            os.remove(tmp)
    return out_wav
| # -------------------- SadTalker with Fallback -------------------- | |
def run_sadtalker(source_img: str, driven_wav: str, out_dir: str,
                  expr_scale: float = 1.0, pose_scale: float = 1.0, fps: int = 25) -> str:
    """Animate ``source_img`` with ``driven_wav`` via SadTalker, with fallback.

    SadTalker's ``inference.py`` is run as a subprocess from inside
    ``SADTALKER_DIR``. If SadTalker is unavailable, fails, or writes no
    ``.mp4``, a static image-plus-audio video is produced instead.

    Args:
        source_img: Still image to animate.
        driven_wav: Audio that drives the animation.
        out_dir: Directory for results (created if missing).
        expr_scale: Forwarded as ``--expression_scale``.
        pose_scale: Forwarded as ``--pose_scale``.
        fps: Forwarded as ``--fps`` and used by the static fallback.

    Returns:
        Path of the newest ``.mp4`` under ``out_dir``, or of the fallback video.

    Raises:
        RuntimeError: Only if the static fallback itself cannot be created.
    """
    # The subprocess runs with a different working directory, so make every
    # path absolute, and create out_dir before any fallback writes into it.
    out_dir = os.path.abspath(out_dir)
    source_img = os.path.abspath(source_img)
    driven_wav = os.path.abspath(driven_wav)
    os.makedirs(out_dir, exist_ok=True)

    if not SADTALKER_DIR.exists() and not setup_sadtalker():
        return create_static_video_fallback(source_img, driven_wav, out_dir, fps)

    inference_script = SADTALKER_DIR / "inference.py"
    if not inference_script.exists():
        print("❌ SadTalker inference script not found, using fallback")
        return create_static_video_fallback(source_img, driven_wav, out_dir, fps)

    try:
        # NOTE(review): upstream SadTalker's inference.py may not define
        # --pose_scale or --fps; an unrecognised flag makes argparse exit
        # non-zero, which would send every run to the static fallback.
        # Confirm against the checked-out version.
        args = [
            sys.executable, str(inference_script),
            "--driven_audio", driven_wav,
            "--source_image", source_img,
            "--preprocess", "full",
            "--still",
            "--enhancer", "gfpgan",
            "--expression_scale", str(expr_scale),
            "--pose_scale", str(pose_scale),
            "--result_dir", out_dir,
            "--fps", str(fps),
        ]
        # Use cwd= instead of os.chdir(): chdir changes the directory for
        # the whole process, including other requests served concurrently.
        print("▶", " ".join(args))
        proc = subprocess.run(args, cwd=str(SADTALKER_DIR),
                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        print(proc.stdout)
        if proc.returncode != 0:
            raise RuntimeError(f"Command failed: {' '.join(args)}")

        mp4s = sorted(glob.glob(os.path.join(out_dir, "**", "*.mp4"), recursive=True),
                      key=os.path.getmtime)
        if mp4s:
            return mp4s[-1]
        print("❌ SadTalker produced no output, using fallback")
    except Exception as e:
        print(f"❌ SadTalker failed: {e}, using fallback")

    # Single fallback exit, outside the try, so a failing fallback is not
    # caught and attempted a second time.
    return create_static_video_fallback(source_img, driven_wav, out_dir, fps)
def create_static_video_fallback(source_img: str, driven_wav: str, out_dir: str, fps: int = 25) -> str:
    """Create a static video with the image and audio as fallback.

    Args:
        source_img: Image shown for the whole clip.
        driven_wav: Audio track; its length sets the clip duration.
        out_dir: Directory receiving ``fallback_output.mp4`` (created if
            missing).
        fps: Output frame rate.

    Returns:
        Path to the written ``fallback_output.mp4``.

    Raises:
        RuntimeError: If ffmpeg fails to produce the video.
    """
    # Callers may reach this before the output directory exists.
    os.makedirs(out_dir, exist_ok=True)
    output_path = os.path.join(out_dir, "fallback_output.mp4")

    # Get audio duration
    audio = AudioSegment.from_file(driven_wav)
    duration = len(audio) / 1000.0  # Convert to seconds

    # Create video with static image and audio
    cmd = [
        "ffmpeg", "-y",
        "-loop", "1", "-i", source_img,
        "-i", driven_wav,
        # libx264 with yuv420p needs even width/height; round odd sizes down.
        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
        "-c:v", "libx264", "-tune", "stillimage", "-c:a", "aac",
        "-b:a", "192k", "-pix_fmt", "yuv420p",
        "-shortest", "-r", str(fps),
        "-t", str(duration),
        output_path,
    ]
    try:
        run_cmd(cmd)
    except Exception as e:
        raise RuntimeError(f"Even fallback video creation failed: {e}") from e
    print(f"✅ Created fallback static video: {output_path}")
    return output_path
| # -------------------- Final Muxing -------------------- | |
def mux_audio_video(video_path: str, audio_wav: str, final_mp4: str, fps: int = 25,
                    duration_sec: float = 20.0):
    """Replace the video's audio with ``audio_wav`` and write ``final_mp4``.

    The picture comes from the first video stream of ``video_path`` and the
    sound from the first audio stream of ``audio_wav``. The result is
    re-encoded as H.264/AAC at ``fps`` and capped at ``duration_sec``.

    Args:
        video_path: Source of the picture.
        audio_wav: Source of the sound.
        final_mp4: Output path (overwritten if present).
        fps: Output frame rate.
        duration_sec: Maximum output length in seconds (default 20).
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-i", audio_wav,
        "-map", "0:v:0", "-map", "1:a:0",
        # ffmpeg applies an option to the next file named on the command
        # line, so -t/-r must follow both inputs to act on the output.
        "-t", f"{duration_sec:.3f}",
        "-r", str(fps),
        "-c:v", "libx264", "-pix_fmt", "yuv420p",
        "-c:a", "aac", "-shortest", final_mp4,
    ]
    run_cmd(cmd)
| # -------------------- Main Agent Function -------------------- | |
def run_agent(video_path: str,
              maybe_image_path: Optional[str],
              user_script_text: Optional[str],
              cfg: AgentConfig):
    """Main agent orchestrator function.

    Pipeline: extract audio -> pick a speech reference chunk -> obtain the
    script -> synthesize 20 s of speech -> choose a still image -> animate
    (or static fallback) -> mux the final MP4 into ``OUT``.

    Args:
        video_path: Uploaded video (voice reference and optional frame source).
        maybe_image_path: Optional still image; when empty a frame is
            grabbed from the video at ``cfg.grab_frame_at``.
        user_script_text: Optional script; blank selects the canned script.
        cfg: Run settings.

    Returns:
        ``(final_mp4_path, log_text)``; the path is ``""`` when any step
        failed, and the error is the last line of the log.
    """
    logs = AgentLogs()
    session = None
    try:
        # Check SadTalker setup first
        logs.log("Checking SadTalker setup...")
        if not SADTALKER_DIR.exists():
            logs.log("Setting up SadTalker (first run may take a few minutes)...")
            if not setup_sadtalker():
                logs.log("⚠️ SadTalker setup failed, will use static video fallback")
        else:
            logs.log("✅ SadTalker ready")

        video_path = str(video_path)
        vid_name = Path(video_path).stem
        session = WORK / f"run_{uuid.uuid4().hex[:8]}_{vid_name}"
        session.mkdir(parents=True, exist_ok=True)

        full_audio = str(session / "audio.wav")
        ref_audio = str(session / "ref.wav")
        tts_audio = str(session / "tts_20s.wav")
        still_img = str(session / "still.jpg")
        sadtalker_out = str(session / "sadtalker_out")
        final_mp4 = str(OUT / f"{vid_name}_avatar_20s.mp4")

        logs.log("Step 1) Extracting audio...")
        extract_audio_to_wav(video_path, full_audio, sr=16000)

        logs.log("Step 2) Finding speech reference (~6s) via VAD...")
        start, dur = find_voice_reference_chunk(full_audio, sr=16000, target_dur=6.0)
        logs.log(f" - ref start: {start:.2f}s, dur: {dur:.2f}s")
        cut_audio_segment(full_audio, ref_audio, start, dur)

        logs.log("Step 3) Generating 20s script text...")
        script_text = generate_20s_script(user_script_text, cfg.language)
        logs.log(" - script preview: " + (script_text[:140] + ("..." if len(script_text) > 140 else "")))

        logs.log("Step 4) TTS voice cloning to 20s...")
        tts_20s_voice_clone(script_text, ref_audio, tts_audio, language=cfg.language)
        logs.log(f" - TTS saved: {tts_audio}")

        logs.log("Step 5) Prepare still image...")
        if maybe_image_path and str(maybe_image_path).strip():
            shutil.copy(maybe_image_path, still_img)
            logs.log(" - Using provided still image.")
        else:
            grab_frame_from_video(video_path, still_img, at_sec=cfg.grab_frame_at)
            logs.log(f" - Grabbed frame at {cfg.grab_frame_at}s from video.")

        logs.log("Step 6) Run SadTalker animation (or fallback)...")
        raw_video = run_sadtalker(still_img, tts_audio, sadtalker_out,
                                  expr_scale=cfg.expr_scale,
                                  pose_scale=cfg.pose_scale,
                                  fps=cfg.fps)
        logs.log(f" - Video output: {raw_video}")

        logs.log("Step 7) Mux final MP4 (20s, audio + avatar)...")
        mux_audio_video(raw_video, tts_audio, final_mp4, fps=cfg.fps)

        logs.log(f"✅ DONE: {final_mp4}")
        return final_mp4, "\n".join(logs.lines)
    except Exception as e:
        logs.log(f"❌ ERROR: {e}")
        return "", "\n".join(logs.lines)
    finally:
        # The final MP4 lives in OUT; everything under the session directory
        # is an intermediate. Remove it so scratch space does not grow with
        # every request.
        if session is not None:
            shutil.rmtree(session, ignore_errors=True)
| # -------------------- Gradio Interface -------------------- | |
def ui_run(video, image, script_text, language, grab_t, expr, pose, fps):
    """Gradio click handler: validate the upload, run the agent, return (video, logs)."""
    if video is None:
        return None, "Please upload a ~30s video."

    settings = AgentConfig(
        language=language,
        grab_frame_at=grab_t,
        expr_scale=expr,
        pose_scale=pose,
        fps=int(fps),
    )
    result_path, log_text = run_agent(video, image, script_text, settings)
    # An empty path means the run failed; hand the video component None.
    return (result_path or None), log_text
# Create Gradio interface
# Layout: intro text, then inputs (video + optional image on the left,
# script and tuning controls on the right), the run button, and finally the
# result video beside the log box. Component creation order defines the layout.
with gr.Blocks(title="AI Avatar Agentic Flow", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎬🧠 AI Avatar Agentic Flow — Voice Clone + SadTalker (20s MP4)

    Upload a short (~30s) video to clone the voice, then generate a 20-second talking avatar with lip-sync and head movements.

    **Features:**
    - Automatic voice reference detection using VAD
    - Voice cloning with TTS fallbacks (XTTS v2 → edge-tts)
    - Animated avatar with SadTalker
    - 20-second output with perfect audio sync
    """)

    with gr.Row():
        with gr.Column():
            video = gr.Video(
                label="Upload ~30s Video (used for voice reference + optional frame)",
                height=300
            )
            image = gr.Image(
                type="filepath",
                label="Optional Still Image (else a frame is grabbed from the video)",
                height=300
            )
        with gr.Column():
            script_text = gr.Textbox(
                lines=4,
                label="Optional 20s Script Text (leave blank to auto-generate)",
                placeholder="Enter your custom script here..."
            )
            with gr.Row():
                language = gr.Dropdown(
                    choices=["en", "es", "fr", "de", "it", "hi", "ur"],
                    value="en",
                    label="Language"
                )
                grab_t = gr.Slider(
                    0.0, 5.0, value=1.0, step=0.1,
                    label="Grab Frame at (sec) if no image"
                )
            with gr.Row():
                expr = gr.Slider(
                    0.5, 2.0, value=1.0, step=0.1,
                    label="Expression Scale"
                )
                pose = gr.Slider(
                    0.5, 2.0, value=1.0, step=0.1,
                    label="Pose Scale"
                )
                fps = gr.Slider(
                    15, 30, value=25, step=1,
                    label="FPS"
                )

    run_btn = gr.Button("🚀 Run Agent", variant="primary", size="lg")

    with gr.Row():
        out_video = gr.Video(label="Final 20s MP4", height=400)
        logs = gr.Textbox(label="Logs", lines=20, max_lines=30)

    # Input order must match ui_run's positional parameters.
    run_btn.click(
        ui_run,
        inputs=[video, image, script_text, language, grab_t, expr, pose, fps],
        outputs=[out_video, logs]
    )

    gr.Markdown("""
    ## 🧰 Tips & Troubleshooting

    - **Processing Time:** First run may take longer due to model downloads
    - **Audio Length:** Output is enforced to exactly 20 seconds
    - **Voice Reference:** Auto-finds ~6s speech chunk using Silero VAD
    - **Language Support:** XTTS v2 supports multiple languages
    - **Fallbacks:** Script generation and TTS have multiple fallback options
    """)

if __name__ == "__main__":
    demo.launch()