# Gemini Live Avatar — Gradio app: MuseTalk real-time lip-sync driven by Gemini Live audio.
| import gradio as gr | |
| import asyncio | |
| import numpy as np | |
| import os | |
| import shutil | |
| import warnings | |
| import time | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
# Suppress warnings
# Silence deprecation/user-warning chatter from the ML libraries at startup.
warnings.filterwarnings('ignore')
# --- 🛠️ SELF-HEALING MODEL DOWNLOADER (Run at Startup) ---
def check_and_download_models():
    """
    Check that every model file MuseTalk needs exists on disk; download any
    that are missing.

    Also heals the 'models/sd-vae' crash: a VAE folder that exists but lacks
    config.json is treated as broken and deleted before re-download (this is
    the folder that otherwise raises OSError at load time).

    Download failures are logged but never raised, so startup can proceed;
    the later model-loading step will surface the real error if files are
    still absent.
    """
    print("🔍 Checking Model Integrity...")

    # 1. VAE: a half-downloaded folder makes the loader crash with OSError.
    vae_path = "./models/sd-vae"
    vae_config = os.path.join(vae_path, "config.json")
    # Folder present but config.json missing => broken download; remove it.
    if os.path.exists(vae_path) and not os.path.exists(vae_config):
        print(f"⚠️ Found broken VAE folder at {vae_path}. Deleting to fix...")
        shutil.rmtree(vae_path)
    if not os.path.exists(vae_config):
        print("⬇️ Downloading VAE (sd-vae-ft-mse)...")
        try:
            os.makedirs(vae_path, exist_ok=True)
            hf_hub_download(repo_id="stabilityai/sd-vae-ft-mse", filename="config.json", local_dir=vae_path)
            hf_hub_download(repo_id="stabilityai/sd-vae-ft-mse", filename="diffusion_pytorch_model.bin", local_dir=vae_path)
            print("✅ VAE Downloaded")
        except Exception as e:
            print(f"❌ VAE Download Failed: {e}")

    # 2. MuseTalk weights (main lip-sync model)
    mt_path = "./Musetalk/models/musetalk"
    if not os.path.exists(os.path.join(mt_path, "pytorch_model.bin")):
        print("⬇️ Downloading MuseTalk Weights...")
        os.makedirs(mt_path, exist_ok=True)
        try:
            snapshot_download(
                repo_id="TMElyralab/MuseTalk",
                local_dir="./Musetalk/models",
                allow_patterns=["musetalk/*", "pytorch_model.bin"],
                local_dir_use_symlinks=False
            )
            print("✅ MuseTalk Weights Downloaded")
        except Exception as e:
            print(f"❌ MuseTalk Download Failed: {e}")

    # 3. DWPose (face/pose tracking)
    dw_path = "./Musetalk/models/dwpose"
    if not os.path.exists(os.path.join(dw_path, "dw-ll_ucoco_384.pth")):
        print("⬇️ Downloading DWPose...")
        os.makedirs(dw_path, exist_ok=True)
        try:
            hf_hub_download(repo_id="yzd-v/DWPose", filename="dw-ll_ucoco_384.pth", local_dir=dw_path)
            hf_hub_download(repo_id="yzd-v/DWPose", filename="yolox_l.onnx", local_dir=dw_path)
            print("✅ DWPose Downloaded")
        except Exception as e:
            print(f"❌ DWPose Download Failed: {e}")

    # 4. Face parser (segmentation used for mouth-region blending)
    fp_path = "./Musetalk/models/face-parse-bisent"
    if not os.path.exists(os.path.join(fp_path, "79999_iter.pth")):
        print("⬇️ Downloading Face Parser...")
        os.makedirs(fp_path, exist_ok=True)
        try:
            hf_hub_download(repo_id="leonelhs/faceparser", filename="79999_iter.pth", local_dir=fp_path)
            hf_hub_download(repo_id="leonelhs/faceparser", filename="resnet18-5c106cde.pth", local_dir=fp_path)
            print("✅ Face Parser Downloaded")
        except Exception as e:
            print(f"❌ Face Parser Download Failed: {e}")

    # 5. Whisper tiny checkpoint (audio feature extractor)
    wh_path = "./Musetalk/models/whisper"
    wh_file = os.path.join(wh_path, "tiny.pt")
    if not os.path.exists(wh_file):
        print("⬇️ Downloading Whisper...")
        os.makedirs(wh_path, exist_ok=True)
        try:
            # FIX: was `os.system("wget -nc -O ...")`. wget may not exist in
            # the container, and os.system never raises, so failures slipped
            # past the except silently. urllib is stdlib and raises on error.
            import urllib.request
            urllib.request.urlretrieve(
                "https://openaipublic.azureedge.net/main/whisper/models/"
                "65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
                wh_file,
            )
            print("✅ Whisper Downloaded")
        except Exception as e:
            print(f"❌ Whisper Download Failed: {e}")
# 🚀 EXECUTE CHECK IMMEDIATELY
# Run before the model-dependent imports below so missing weights are
# fetched prior to any code touching them.
check_and_download_models()
| # --- IMPORTS AFTER MODEL CHECK --- | |
| from LLM.GeminiLive import GeminiLiveClient | |
| from TFG.Streamer import AudioBuffer | |
# Try importing MuseTalk, but don't fail if it crashes on first load (rely on init_model)
try:
    from TFG.MuseTalk import MuseTalk_RealTime
except ImportError:
    # Sentinel: init_model() checks for None and reports the missing module.
    MuseTalk_RealTime = None
# --- CONFIGURATION ---
# Fallback avatar used when the user has not uploaded one.
DEFAULT_AVATAR_VIDEO = "./Musetalk/data/video/yongen_musev.mp4"
# Websocket bridge that proxies to the Gemini Live API.
WSS_URL = "wss://gemini-live-bridge-production.up.railway.app/ws"
DEFAULT_BBOX_SHIFT = 5  # Fixed default passed to prepare_material()
# --- GLOBAL STATE ---
client = None              # GeminiLiveClient websocket client (lazily created)
audio_buffer = None        # AudioBuffer of AI speech feeding the lip-sync window
musetalker = None          # MuseTalk_RealTime model instance (loaded on demand)
avatar_prepared = False    # True once prepare_material() has run for an avatar
current_avatar_path = None # Path of the avatar media currently prepared
is_streaming = False       # True while the Gemini Live session is active
# --- CORE FUNCTIONS ---
def init_audio_system():
    """Construct the shared Gemini client and mic audio buffer exactly once."""
    global client, audio_buffer
    client = GeminiLiveClient(websocket_url=WSS_URL) if client is None else client
    audio_buffer = (
        AudioBuffer(sample_rate=16000, context_size_seconds=0.2)
        if audio_buffer is None
        else audio_buffer
    )
def init_model():
    """Load the MuseTalk model on first call; return the cached instance after.

    Returns None (after logging) when the MuseTalk module failed to import.
    """
    global musetalker
    if musetalker is not None:
        return musetalker
    if MuseTalk_RealTime is None:
        print("❌ MuseTalk module not available")
        return None
    print("🚀 Loading MuseTalk Model...")
    musetalker = MuseTalk_RealTime()
    musetalker.init_model()
    print("✅ MuseTalk Model Loaded")
    return musetalker
def prepare_avatar_internal(avatar_path):
    """Prepare (or re-prepare) the lip-sync material for *avatar_path*.

    Returns True on success, False on any failure (errors are logged,
    never raised).
    """
    global avatar_prepared, current_avatar_path, musetalker
    try:
        talker = init_model()
        if talker is None:
            return False
        init_audio_system()
        # If an avatar was already prepared, discard its state before switching.
        if avatar_prepared:
            avatar_prepared = False
            if audio_buffer:
                audio_buffer.clear()
            if hasattr(talker, 'input_latent_list_cycle'):
                talker.input_latent_list_cycle = None
        print(f"🎭 Preparing Avatar: {os.path.basename(avatar_path)}")
        # Use default bbox shift of 5
        talker.prepare_material(avatar_path, DEFAULT_BBOX_SHIFT)
        current_avatar_path = avatar_path
        avatar_prepared = True
        if audio_buffer:
            audio_buffer.clear()
    except Exception as e:
        print(f"❌ Error Preparing Avatar: {e}")
        return False
    return True
async def toggle_streaming(current_state):
    """
    Start or stop the Gemini Live stream, yielding UI updates.

    This is an async *generator* (it contains `yield`), so it must never use
    `return <value>` — that is a SyntaxError in async generators and was
    crashing the app at import time.

    Yields:
        (button_label, status_text, streaming_state) tuples consumed by
        the start button, status textbox, and streaming state.
    """
    global is_streaming, client, avatar_prepared
    # If currently streaming -> Stop
    if current_state:
        is_streaming = False
        if client:
            await client.close()
        # FIX: was `return (...)` — illegal in an async generator.
        yield "▶️ Start Streaming", "Stream Stopped", False
        return
    # If Not streaming -> Start
    # 1. Ensure Avatar is Ready (use the default video if none selected)
    if not avatar_prepared:
        yield "⏳ Preparing Avatar...", "Initializing Default Avatar...", True
        if not prepare_avatar_internal(DEFAULT_AVATAR_VIDEO):
            yield "▶️ Start Streaming", "❌ Avatar Init Failed", False
            return
    # 2. Connect to Gemini Live
    init_audio_system()
    yield "⏳ Connecting...", "Connecting to Gemini...", True
    try:
        success = await client.connect()
        if success:
            is_streaming = True
            yield "⏹️ Stop Streaming", "✅ LIVE: Connected to Gemini", True
        else:
            is_streaming = False
            yield "▶️ Start Streaming", "❌ Connection Failed", False
    except Exception as e:
        is_streaming = False
        yield "▶️ Start Streaming", f"❌ Error: {str(e)}", False
def change_avatar_handler(file):
    """
    Gradio handler for the avatar upload field.

    Args:
        file: filepath from the gr.File input, or None if nothing selected.

    Returns:
        (status_message, upload_row_update) — the upload row is hidden on
        success and left unchanged on failure so the user can retry.
    """
    if file is None:
        return "❌ No file selected", gr.update()
    # FIX: removed dead `msg` local and its stale comment — this handler
    # returns (it does not yield), so no interim message can be shown.
    if prepare_avatar_internal(file):
        return f"✅ Avatar Changed: {os.path.basename(file)}", gr.update(visible=False)
    return "❌ Failed to change avatar", gr.update()
async def stream_loop(audio_data):
    """
    Per-tick handler driven by the hidden microphone stream (~every 40 ms).

    Args:
        audio_data: (sample_rate, np.ndarray) mic chunk from gradio, or None.

    Returns:
        (frame, audio) for the image and speaker outputs; either may be None
        when there is nothing new to show/play this tick.
    """
    global is_streaming
    ret_frame, ret_audio = None, None
    # Bail out cheaply until the whole pipeline (client + avatar + model) is live.
    if not is_streaming or not client or not client.running or not avatar_prepared or not musetalker:
        return None, None
    # Send Mic Audio
    if audio_data is not None:
        sr, y = audio_data
        if y.size > 0:
            await client.send_audio(y, original_sr=sr)
    # Receive AI Audio & buffer it: drain everything queued since last tick.
    new_audio_chunks = []
    while not client.output_queue.empty():
        try:
            chunk = client.output_queue.get_nowait()
            if chunk is not None:
                audio_buffer.push(chunk)
                new_audio_chunks.append(chunk)
        except asyncio.QueueEmpty:
            break
    if new_audio_chunks:
        # NOTE(review): assumes Gemini audio chunks are 16 kHz — confirm
        # against GeminiLiveClient's output format.
        ret_audio = (16000, np.concatenate(new_audio_chunks))
    # Generate Video Frame from the current audio context window.
    current_window = audio_buffer.get_window()
    if current_window is not None:
        try:
            ret_frame = musetalker.inference_streaming(
                audio_buffer_16k=current_window,
                return_frame_only=False
            )
        except Exception:
            # Best-effort: a failed frame just leaves the previous image up.
            pass
    return ret_frame, ret_audio
# --- CSS STYLING ---
# Injected into gr.Blocks(css=...); the classes are attached to components
# below via elem_classes.
css = """
.video-container {
    height: 600px !important;
    width: 600px !important;
    border-radius: 20px;
    overflow: hidden;
    background-color: #000;
    margin: 0 auto;
    box-shadow: 0 10px 30px rgba(0,0,0,0.5);
}
.btn-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    border: none !important;
    font-size: 1.2em !important;
    padding: 12px !important;
}
.btn-secondary {
    background: #4a5568 !important;
    color: white !important;
    font-size: 1em !important;
}
"""
# --- UI LAYOUT ---
with gr.Blocks(title="Gemini Live Avatar", theme=gr.themes.Soft(), css=css) as demo:
    # State to track if we are currently streaming
    streaming_state = gr.State(False)
    with gr.Column(elem_classes="main-container"):
        # 1. MAIN VIDEO FEED (Centered & Large)
        with gr.Row():
            avatar_img = gr.Image(
                label="AI Video Feed",
                interactive=False,
                streaming=True,
                show_label=False,
                type="numpy",
                elem_classes="video-container"
            )
        # 2. STATUS & CONTROLS (Clean Row)
        with gr.Row(elem_id="controls"):
            status_display = gr.Textbox(
                value="Ready",
                show_label=False,
                container=False,
                interactive=False,
                text_align="center"
            )
        with gr.Row():
            # LEFT: Start/Stop Stream
            start_btn = gr.Button("▶️ Start Streaming", variant="primary", scale=2, elem_classes="btn-primary")
            # RIGHT: Change Avatar
            change_avatar_btn = gr.Button("👤 Change Avatar", variant="secondary", scale=1, elem_classes="btn-secondary")
        # 3. HIDDEN UPLOAD FOR "CHANGE AVATAR" (revealed by the button below)
        with gr.Row(visible=False) as upload_row:
            avatar_file_input = gr.File(label="Upload New Avatar Image/Video", type="filepath")
        # 4. HIDDEN AUDIO I/O (mic drives the stream loop; speaker plays AI audio)
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, visible=False)
        speaker = gr.Audio(visible=False, autoplay=True, streaming=True)
    # --- EVENTS ---
    # Start/Stop Streaming (async generator handler -> streamed UI updates)
    start_btn.click(
        toggle_streaming,
        inputs=[streaming_state],
        outputs=[start_btn, status_display, streaming_state]
    )
    # Change Avatar Flow
    change_avatar_btn.click(
        lambda: gr.update(visible=True),  # Show upload row
        outputs=[upload_row]
    )
    avatar_file_input.change(
        change_avatar_handler,
        inputs=[avatar_file_input],
        outputs=[status_display, upload_row]  # Hide upload row on success
    )
    # Main Stream Loop: fires every 40 ms (~25 ticks/sec) while the mic streams
    mic.stream(
        stream_loop,
        inputs=[mic],
        outputs=[avatar_img, speaker],
        stream_every=0.04,
        time_limit=300
    )
if __name__ == "__main__":
    # queue() is required for streaming/generator event handlers;
    # allowed_paths lets gradio serve media from the local model/data folders.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=["./Musetalk", "./models"]
    )