"""Gradio app: a real-time talking-head avatar driven by Gemini Live audio.

Pipeline: browser microphone -> Gemini Live (via a websocket bridge) ->
AI audio chunks -> MuseTalk lip-sync inference -> video frames streamed
back to the browser alongside the AI audio.
"""

import asyncio
import os
import shutil
import time
import warnings

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download

# Suppress warnings
warnings.filterwarnings('ignore')


# --- đŸ› ī¸ SELF-HEALING MODEL DOWNLOADER (Run at Startup) ---
def check_and_download_models():
    """
    Checks if required models exist. If not, downloads them.
    Also handles the 'models/sd-vae' crash by ensuring the folder is correct.

    Each download is wrapped in its own try/except so one failure does not
    block the remaining downloads; failures are logged and startup continues.
    """
    print("🔍 Checking Model Integrity...")

    # 1. FIX: The specific VAE folder causing the startup crash.
    vae_path = "./models/sd-vae"
    vae_config = os.path.join(vae_path, "config.json")

    # If folder exists but is empty/broken, delete it (heals the "OSError").
    if os.path.exists(vae_path) and not os.path.exists(vae_config):
        print(f"âš ī¸ Found broken VAE folder at {vae_path}. Deleting to fix...")
        shutil.rmtree(vae_path)

    # Download VAE if missing.
    if not os.path.exists(vae_config):
        print("âŦ‡ī¸ Downloading VAE (sd-vae-ft-mse)...")
        try:
            os.makedirs(vae_path, exist_ok=True)
            hf_hub_download(repo_id="stabilityai/sd-vae-ft-mse",
                            filename="config.json", local_dir=vae_path)
            hf_hub_download(repo_id="stabilityai/sd-vae-ft-mse",
                            filename="diffusion_pytorch_model.bin", local_dir=vae_path)
            print("✅ VAE Downloaded")
        except Exception as e:
            print(f"❌ VAE Download Failed: {e}")

    # 2. MuseTalk Weights (Main Model)
    mt_path = "./Musetalk/models/musetalk"
    if not os.path.exists(os.path.join(mt_path, "pytorch_model.bin")):
        print("âŦ‡ī¸ Downloading MuseTalk Weights...")
        os.makedirs(mt_path, exist_ok=True)
        try:
            snapshot_download(
                repo_id="TMElyralab/MuseTalk",
                local_dir="./Musetalk/models",
                allow_patterns=["musetalk/*", "pytorch_model.bin"],
                # NOTE: deprecated in newer huggingface_hub; kept for
                # compatibility with the pinned version this project uses.
                local_dir_use_symlinks=False
            )
            print("✅ MuseTalk Weights Downloaded")
        except Exception as e:
            print(f"❌ MuseTalk Download Failed: {e}")

    # 3. DWPose (For Face Tracking)
    dw_path = "./Musetalk/models/dwpose"
    if not os.path.exists(os.path.join(dw_path, "dw-ll_ucoco_384.pth")):
        print("âŦ‡ī¸ Downloading DWPose...")
        os.makedirs(dw_path, exist_ok=True)
        try:
            hf_hub_download(repo_id="yzd-v/DWPose",
                            filename="dw-ll_ucoco_384.pth", local_dir=dw_path)
            hf_hub_download(repo_id="yzd-v/DWPose",
                            filename="yolox_l.onnx", local_dir=dw_path)
            print("✅ DWPose Downloaded")
        except Exception as e:
            print(f"❌ DWPose Download Failed: {e}")

    # 4. Face Parser
    fp_path = "./Musetalk/models/face-parse-bisent"
    if not os.path.exists(os.path.join(fp_path, "79999_iter.pth")):
        print("âŦ‡ī¸ Downloading Face Parser...")
        os.makedirs(fp_path, exist_ok=True)
        try:
            hf_hub_download(repo_id="leonelhs/faceparser",
                            filename="79999_iter.pth", local_dir=fp_path)
            hf_hub_download(repo_id="leonelhs/faceparser",
                            filename="resnet18-5c106cde.pth", local_dir=fp_path)
            print("✅ Face Parser Downloaded")
        except Exception as e:
            print(f"❌ Face Parser Download Failed: {e}")

    # 5. Whisper (Audio Feature Extractor)
    wh_path = "./Musetalk/models/whisper"
    if not os.path.exists(os.path.join(wh_path, "tiny.pt")):
        print("âŦ‡ī¸ Downloading Whisper...")
        os.makedirs(wh_path, exist_ok=True)
        try:
            # Using external URL for Whisper tiny.pt
            os.system(f"wget -nc -O {wh_path}/tiny.pt https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt")
            print("✅ Whisper Downloaded")
        except Exception as e:
            print(f"❌ Whisper Download Failed: {e}")


# 🚀 EXECUTE CHECK IMMEDIATELY (must run before importing modules that
# load the weights from disk).
check_and_download_models()

# --- IMPORTS AFTER MODEL CHECK ---
from LLM.GeminiLive import GeminiLiveClient
from TFG.Streamer import AudioBuffer

# Try importing MuseTalk, but don't fail if it crashes on first load
# (rely on init_model to report the problem at use time).
try:
    from TFG.MuseTalk import MuseTalk_RealTime
except ImportError:
    MuseTalk_RealTime = None

# --- CONFIGURATION ---
DEFAULT_AVATAR_VIDEO = "./Musetalk/data/video/yongen_musev.mp4"
WSS_URL = "wss://gemini-live-bridge-production.up.railway.app/ws"
DEFAULT_BBOX_SHIFT = 5  # Fixed default

# --- GLOBAL STATE ---
client = None               # GeminiLiveClient websocket wrapper
audio_buffer = None         # rolling 16 kHz audio window fed to MuseTalk
musetalker = None           # MuseTalk_RealTime model instance
avatar_prepared = False     # True once prepare_material has run
current_avatar_path = None  # path of the currently prepared avatar
is_streaming = False        # True while connected to Gemini Live


# --- CORE FUNCTIONS ---
def init_audio_system():
    """Lazily create the Gemini client and the rolling audio buffer (idempotent)."""
    global client, audio_buffer
    if client is None:
        client = GeminiLiveClient(websocket_url=WSS_URL)
    if audio_buffer is None:
        audio_buffer = AudioBuffer(sample_rate=16000, context_size_seconds=0.2)


def init_model():
    """Lazily load the MuseTalk model.

    Returns:
        The loaded MuseTalk_RealTime instance, or None if the module
        could not be imported at startup.
    """
    global musetalker
    if musetalker is None:
        if MuseTalk_RealTime is None:
            print("❌ MuseTalk module not available")
            return None
        print("🚀 Loading MuseTalk Model...")
        musetalker = MuseTalk_RealTime()
        musetalker.init_model()
        print("✅ MuseTalk Model Loaded")
    return musetalker


def prepare_avatar_internal(avatar_path):
    """Prepare (or re-prepare) avatar material for the given image/video path.

    Resets any previously prepared state first, then runs MuseTalk's
    prepare_material with the fixed default bbox shift.

    Returns:
        True on success, False on any failure (errors are logged, not raised).
    """
    global avatar_prepared, current_avatar_path, musetalker
    try:
        model = init_model()
        if model is None:
            return False
        init_audio_system()

        # Reset state if already prepped
        if avatar_prepared:
            avatar_prepared = False
            if audio_buffer:
                audio_buffer.clear()
            if hasattr(model, 'input_latent_list_cycle'):
                model.input_latent_list_cycle = None

        print(f"🎭 Preparing Avatar: {os.path.basename(avatar_path)}")
        # Use default bbox shift of 5
        model.prepare_material(avatar_path, DEFAULT_BBOX_SHIFT)

        current_avatar_path = avatar_path
        avatar_prepared = True
        if audio_buffer:
            audio_buffer.clear()
        return True
    except Exception as e:
        print(f"❌ Error Preparing Avatar: {e}")
        return False


async def toggle_streaming(current_state):
    """Start or stop the live stream, yielding progressive UI updates.

    Args:
        current_state: current value of the `streaming_state` gr.State.

    Yields:
        (button_label, status_text, new_streaming_state) tuples for the
        start button, status textbox, and state component.
    """
    global is_streaming, client, avatar_prepared

    # If currently streaming -> Stop
    if current_state:
        is_streaming = False
        if client:
            await client.close()
        # FIX: this function is an async generator (it uses `yield` below);
        # `return <value>` is a SyntaxError there, so yield then bare return.
        yield "â–ļī¸ Start Streaming", "Stream Stopped", False
        return

    # If not streaming -> Start
    # 1. Ensure Avatar is Ready (Use default if none selected)
    if not avatar_prepared:
        yield "âŗ Preparing Avatar...", "Initializing Default Avatar...", True
        success = prepare_avatar_internal(DEFAULT_AVATAR_VIDEO)
        if not success:
            yield "â–ļī¸ Start Streaming", "❌ Avatar Init Failed", False
            return

    # 2. Connect to Gemini Live
    init_audio_system()
    yield "âŗ Connecting...", "Connecting to Gemini...", True
    try:
        success = await client.connect()
        if success:
            is_streaming = True
            yield "âšī¸ Stop Streaming", "✅ LIVE: Connected to Gemini", True
        else:
            is_streaming = False
            yield "â–ļī¸ Start Streaming", "❌ Connection Failed", False
    except Exception as e:
        is_streaming = False
        yield "â–ļī¸ Start Streaming", f"❌ Error: {str(e)}", False


def change_avatar_handler(file):
    """Switch the active avatar to an uploaded file.

    Returns:
        (status_message, upload_row_update) — hides the upload row only
        on success.
    """
    if file is None:
        return "❌ No file selected", gr.update()
    success = prepare_avatar_internal(file)
    if success:
        # Hide upload row
        return f"✅ Avatar Changed: {os.path.basename(file)}", gr.update(visible=False)
    else:
        return "❌ Failed to change avatar", gr.update()


async def stream_loop(audio_data):
    """Per-tick streaming handler driven by `mic.stream`.

    Sends the latest microphone chunk to Gemini, drains any AI audio from
    the client's output queue into the rolling buffer, and renders one
    lip-synced frame from the current buffer window.

    Args:
        audio_data: (sample_rate, samples) tuple from gr.Audio, or None.

    Returns:
        (frame, (16000, audio)) — either element may be None when there is
        nothing new to show/play this tick.
    """
    global is_streaming
    ret_frame, ret_audio = None, None
    if not is_streaming or not client or not client.running or not avatar_prepared or not musetalker:
        return None, None

    # Send Mic Audio
    if audio_data is not None:
        sr, y = audio_data
        if y.size > 0:
            await client.send_audio(y, original_sr=sr)

    # Receive AI Audio & buffer it
    new_audio_chunks = []
    while not client.output_queue.empty():
        try:
            chunk = client.output_queue.get_nowait()
            if chunk is not None:
                audio_buffer.push(chunk)
                new_audio_chunks.append(chunk)
        except asyncio.QueueEmpty:
            break
    if new_audio_chunks:
        ret_audio = (16000, np.concatenate(new_audio_chunks))

    # Generate Video Frame from buffer
    current_window = audio_buffer.get_window()
    if current_window is not None:
        try:
            ret_frame = musetalker.inference_streaming(
                audio_buffer_16k=current_window,
                return_frame_only=False
            )
        except Exception:
            # Best-effort: a failed frame just means the UI keeps the
            # previous one for this tick.
            pass

    return ret_frame, ret_audio


# --- CSS STYLING ---
css = """
.video-container { height: 600px !important; width: 600px !important; border-radius: 20px; overflow: hidden; background-color: #000; margin: 0 auto; box-shadow: 0 10px 30px rgba(0,0,0,0.5); }
.btn-primary { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; color: white !important; border: none !important; font-size: 1.2em !important; padding: 12px !important; }
.btn-secondary { background: #4a5568 !important; color: white !important; font-size: 1em !important; }
"""

# --- UI LAYOUT ---
with gr.Blocks(title="Gemini Live Avatar", theme=gr.themes.Soft(), css=css) as demo:
    # State to track if we are currently streaming
    streaming_state = gr.State(False)

    with gr.Column(elem_classes="main-container"):
        # 1. MAIN VIDEO FEED (Centered & Large)
        with gr.Row():
            avatar_img = gr.Image(
                label="AI Video Feed",
                interactive=False,
                streaming=True,
                show_label=False,
                type="numpy",
                elem_classes="video-container"
            )

        # 2. STATUS & CONTROLS (Clean Row)
        with gr.Row(elem_id="controls"):
            status_display = gr.Textbox(
                value="Ready",
                show_label=False,
                container=False,
                interactive=False,
                text_align="center"
            )
        with gr.Row():
            # LEFT: Start/Stop Stream
            start_btn = gr.Button("â–ļī¸ Start Streaming", variant="primary",
                                  scale=2, elem_classes="btn-primary")
            # RIGHT: Change Avatar
            change_avatar_btn = gr.Button("👤 Change Avatar", variant="secondary",
                                          scale=1, elem_classes="btn-secondary")

        # 3. HIDDEN UPLOAD FOR "CHANGE AVATAR"
        with gr.Row(visible=False) as upload_row:
            avatar_file_input = gr.File(label="Upload New Avatar Image/Video", type="filepath")

        # 4. HIDDEN AUDIO I/O
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, visible=False)
        speaker = gr.Audio(visible=False, autoplay=True, streaming=True)

    # --- EVENTS ---
    # Start/Stop Streaming
    start_btn.click(
        toggle_streaming,
        inputs=[streaming_state],
        outputs=[start_btn, status_display, streaming_state]
    )

    # Change Avatar Flow
    change_avatar_btn.click(
        lambda: gr.update(visible=True),  # Show upload row
        outputs=[upload_row]
    )
    avatar_file_input.change(
        change_avatar_handler,
        inputs=[avatar_file_input],
        outputs=[status_display, upload_row]  # Hide upload row on success
    )

    # Main Stream Loop
    mic.stream(
        stream_loop,
        inputs=[mic],
        outputs=[avatar_img, speaker],
        stream_every=0.04,
        time_limit=300
    )

if __name__ == "__main__":
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=["./Musetalk", "./models"]
    )