# personaxgemini / webui.py
# Author: eshwar06 — commit 3ee4df5 (verified): "Update webui.py"
import gradio as gr
import asyncio
import numpy as np
import os
import shutil
import warnings
import time
from huggingface_hub import hf_hub_download, snapshot_download
# Suppress all warnings globally — keeps the HF Space console readable, but
# note this also hides deprecation notices from dependencies.
warnings.filterwarnings('ignore')
# --- 🛠️ SELF-HEALING MODEL DOWNLOADER (Run at Startup) ---
def check_and_download_models():
    """
    Checks if required models exist. If not, downloads them.
    Also handles the 'models/sd-vae' crash by ensuring the folder is correct.

    Each model is fetched best-effort: a failed download is logged and startup
    continues, so the UI can still come up with partial weights.
    """
    print("🔍 Checking Model Integrity...")
    _ensure_vae()
    _ensure_musetalk_weights()
    _ensure_dwpose()
    _ensure_face_parser()
    _ensure_whisper()


def _hub_grab(repo_id, filenames, local_dir):
    """Download each of *filenames* from HF repo *repo_id* into *local_dir*.

    Raises on any failure so the caller's try/except can log it.
    """
    os.makedirs(local_dir, exist_ok=True)
    for filename in filenames:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


def _ensure_vae():
    """Heal and (re)download the SD VAE loaded from ./models/sd-vae."""
    vae_path = "./models/sd-vae"
    vae_config = os.path.join(vae_path, "config.json")
    # If the folder exists but config.json is missing, a previous download was
    # interrupted; diffusers then raises OSError at load time ("broken VAE
    # folder" crash). Delete the folder so it gets re-downloaded cleanly.
    if os.path.exists(vae_path) and not os.path.exists(vae_config):
        print(f"⚠️ Found broken VAE folder at {vae_path}. Deleting to fix...")
        shutil.rmtree(vae_path)
    if os.path.exists(vae_config):
        return
    print("⬇️ Downloading VAE (sd-vae-ft-mse)...")
    try:
        _hub_grab("stabilityai/sd-vae-ft-mse",
                  ["config.json", "diffusion_pytorch_model.bin"], vae_path)
        print("✅ VAE Downloaded")
    except Exception as e:
        print(f"❌ VAE Download Failed: {e}")


def _ensure_musetalk_weights():
    """Fetch the main MuseTalk weights via snapshot_download."""
    mt_path = "./Musetalk/models/musetalk"
    if os.path.exists(os.path.join(mt_path, "pytorch_model.bin")):
        return
    print("⬇️ Downloading MuseTalk Weights...")
    os.makedirs(mt_path, exist_ok=True)
    try:
        snapshot_download(
            repo_id="TMElyralab/MuseTalk",
            local_dir="./Musetalk/models",
            allow_patterns=["musetalk/*", "pytorch_model.bin"],
            local_dir_use_symlinks=False  # real files, not symlinks into the HF cache
        )
        print("✅ MuseTalk Weights Downloaded")
    except Exception as e:
        print(f"❌ MuseTalk Download Failed: {e}")


def _ensure_dwpose():
    """Fetch DWPose checkpoints (used for face tracking)."""
    dw_path = "./Musetalk/models/dwpose"
    if os.path.exists(os.path.join(dw_path, "dw-ll_ucoco_384.pth")):
        return
    print("⬇️ Downloading DWPose...")
    try:
        _hub_grab("yzd-v/DWPose", ["dw-ll_ucoco_384.pth", "yolox_l.onnx"], dw_path)
        print("✅ DWPose Downloaded")
    except Exception as e:
        print(f"❌ DWPose Download Failed: {e}")


def _ensure_face_parser():
    """Fetch BiSeNet face-parsing weights."""
    fp_path = "./Musetalk/models/face-parse-bisent"
    if os.path.exists(os.path.join(fp_path, "79999_iter.pth")):
        return
    print("⬇️ Downloading Face Parser...")
    try:
        _hub_grab("leonelhs/faceparser",
                  ["79999_iter.pth", "resnet18-5c106cde.pth"], fp_path)
        print("✅ Face Parser Downloaded")
    except Exception as e:
        print(f"❌ Face Parser Download Failed: {e}")


def _ensure_whisper():
    """Fetch the OpenAI Whisper 'tiny' checkpoint (audio feature extractor)."""
    wh_path = "./Musetalk/models/whisper"
    target = os.path.join(wh_path, "tiny.pt")
    if os.path.exists(target):
        return
    print("⬇️ Downloading Whisper...")
    os.makedirs(wh_path, exist_ok=True)
    try:
        # BUGFIX: the original shelled out to `wget` via os.system and ignored
        # the exit status, so a failed download (or a missing wget binary) was
        # never caught by this try/except. urllib downloads in-process and
        # raises on failure, making the error message below reachable.
        import urllib.request
        url = ("https://openaipublic.azureedge.net/main/whisper/models/"
               "65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt")
        urllib.request.urlretrieve(url, target)
        print("✅ Whisper Downloaded")
    except Exception as e:
        print(f"❌ Whisper Download Failed: {e}")
# 🚀 EXECUTE CHECK IMMEDIATELY
# Runs at import time so all weights exist before the MuseTalk import below.
check_and_download_models()
# --- IMPORTS AFTER MODEL CHECK ---
from LLM.GeminiLive import GeminiLiveClient
from TFG.Streamer import AudioBuffer
# Try importing MuseTalk, but don't fail if it crashes on first load (rely on init_model)
try:
    from TFG.MuseTalk import MuseTalk_RealTime
except Exception:
    # Broadened from ImportError: the stated intent above is to survive ANY
    # first-load crash, and import-time side effects (CUDA init, weight
    # loading) can raise types other than ImportError.
    MuseTalk_RealTime = None
# --- CONFIGURATION ---
DEFAULT_AVATAR_VIDEO = "./Musetalk/data/video/yongen_musev.mp4"  # fallback avatar when the user hasn't uploaded one
WSS_URL = "wss://gemini-live-bridge-production.up.railway.app/ws"  # Gemini Live websocket bridge endpoint
DEFAULT_BBOX_SHIFT = 5  # Fixed default passed to prepare_material (face bbox vertical shift)
# --- GLOBAL STATE ---
# Lazily-initialised singletons shared across the Gradio event handlers below.
client = None  # GeminiLiveClient — created on first use by init_audio_system()
audio_buffer = None  # AudioBuffer — rolling staging buffer for AI audio
musetalker = None  # MuseTalk_RealTime instance, loaded once by init_model()
avatar_prepared = False  # True once prepare_material() has run for current_avatar_path
current_avatar_path = None  # path of the currently prepared avatar video/image
is_streaming = False  # True while a live Gemini session is active
# --- CORE FUNCTIONS ---
def init_audio_system():
    """Create the Gemini Live client and the shared audio buffer on first use.

    Idempotent: instances that already exist are kept as-is.
    """
    global client, audio_buffer
    client = client if client is not None else GeminiLiveClient(websocket_url=WSS_URL)
    audio_buffer = (
        audio_buffer
        if audio_buffer is not None
        else AudioBuffer(sample_rate=16000, context_size_seconds=0.2)
    )
def init_model():
    """Load the MuseTalk model on first call and cache it in the module global.

    Returns the cached MuseTalk_RealTime instance, or None when the MuseTalk
    module could not be imported at startup.
    """
    global musetalker
    if musetalker is not None:
        # Already loaded — reuse the cached instance.
        return musetalker
    if MuseTalk_RealTime is None:
        print("❌ MuseTalk module not available")
        return None
    print("🚀 Loading MuseTalk Model...")
    musetalker = MuseTalk_RealTime()
    musetalker.init_model()
    print("✅ MuseTalk Model Loaded")
    return musetalker
def prepare_avatar_internal(avatar_path):
    """Prepare *avatar_path* for real-time lip-sync rendering.

    Loads the model and audio system if needed, discards any previously
    prepared avatar state, then runs prepare_material. Returns True on
    success, False on any failure (errors are printed, never raised).
    """
    global avatar_prepared, current_avatar_path, musetalker
    try:
        talker = init_model()
        if talker is None:
            return False
        init_audio_system()
        # Switching avatars: drop the old prepared state first.
        if avatar_prepared:
            avatar_prepared = False
            if audio_buffer:
                audio_buffer.clear()
            if hasattr(talker, 'input_latent_list_cycle'):
                talker.input_latent_list_cycle = None
        print(f"🎭 Preparing Avatar: {os.path.basename(avatar_path)}")
        # Bbox shift is fixed at the module-level default (5).
        talker.prepare_material(avatar_path, DEFAULT_BBOX_SHIFT)
        current_avatar_path = avatar_path
        avatar_prepared = True
        if audio_buffer:
            audio_buffer.clear()
        return True
    except Exception as e:
        print(f"❌ Error Preparing Avatar: {e}")
        return False
async def toggle_streaming(current_state):
    """Gradio handler: start or stop the live Gemini stream.

    Async generator that yields (button_label, status_text, streaming_state)
    tuples so the UI updates progressively while connecting.

    BUGFIX: the original did `return <value>` from the stop branch, which is
    a SyntaxError inside an async generator (a function containing `yield`).
    The final UI state is now yielded, followed by a bare `return`.
    """
    global is_streaming, client, avatar_prepared
    # If currently streaming -> Stop
    if current_state:
        is_streaming = False
        if client:
            await client.close()
        yield "▶️ Start Streaming", "Stream Stopped", False
        return
    # If Not streaming -> Start
    # 1. Ensure Avatar is Ready (Use default if none selected)
    if not avatar_prepared:
        yield "⏳ Preparing Avatar...", "Initializing Default Avatar...", True
        if not prepare_avatar_internal(DEFAULT_AVATAR_VIDEO):
            yield "▶️ Start Streaming", "❌ Avatar Init Failed", False
            return
    # 2. Connect to Gemini Live
    init_audio_system()
    yield "⏳ Connecting...", "Connecting to Gemini...", True
    try:
        if await client.connect():
            is_streaming = True
            yield "⏹️ Stop Streaming", "✅ LIVE: Connected to Gemini", True
        else:
            is_streaming = False
            yield "▶️ Start Streaming", "❌ Connection Failed", False
    except Exception as e:
        is_streaming = False
        yield "▶️ Start Streaming", f"❌ Error: {str(e)}", False
def change_avatar_handler(file):
    """Gradio handler: prepare a newly uploaded avatar file.

    Returns (status_message, upload_row_update); the upload row is hidden on
    success. Removed the unused `msg` local and its stale comment — this is a
    plain function, not a generator, so nothing could be "yielded immediately".
    """
    if file is None:
        return "❌ No file selected", gr.update()
    if prepare_avatar_internal(file):
        return f"✅ Avatar Changed: {os.path.basename(file)}", gr.update(visible=False)  # Hide upload row
    return "❌ Failed to change avatar", gr.update()
async def stream_loop(audio_data):
    """Per-tick streaming callback wired to `mic.stream` in the UI.

    Forwards captured mic audio to Gemini, drains Gemini's output queue into
    the rolling buffer, and renders one lip-synced frame from the buffered
    window. Returns (frame_or_None, (16000, audio)_or_None).
    """
    global is_streaming
    frame_out = None
    audio_out = None
    session_ready = (
        is_streaming
        and client
        and client.running
        and avatar_prepared
        and musetalker
    )
    if not session_ready:
        return None, None
    # Forward microphone audio to Gemini
    if audio_data is not None:
        sample_rate, samples = audio_data
        if samples.size > 0:
            await client.send_audio(samples, original_sr=sample_rate)
    # Drain the AI audio queue into the rolling buffer
    received = []
    while not client.output_queue.empty():
        try:
            piece = client.output_queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        if piece is not None:
            audio_buffer.push(piece)
            received.append(piece)
    if received:
        audio_out = (16000, np.concatenate(received))
    # Render the next video frame from the buffered audio window
    window = audio_buffer.get_window()
    if window is not None:
        try:
            frame_out = musetalker.inference_streaming(
                audio_buffer_16k=window,
                return_frame_only=False
            )
        except Exception:
            # Best-effort: drop this frame on inference error, keep streaming.
            pass
    return frame_out, audio_out
# --- CSS STYLING ---
css = """
.video-container {
height: 600px !important;
width: 600px !important;
border-radius: 20px;
overflow: hidden;
background-color: #000;
margin: 0 auto;
box-shadow: 0 10px 30px rgba(0,0,0,0.5);
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
border: none !important;
font-size: 1.2em !important;
padding: 12px !important;
}
.btn-secondary {
background: #4a5568 !important;
color: white !important;
font-size: 1em !important;
}
"""
# --- UI LAYOUT ---
with gr.Blocks(title="Gemini Live Avatar", theme=gr.themes.Soft(), css=css) as demo:
    # UI-side mirror of the global is_streaming flag, passed to toggle_streaming
    streaming_state = gr.State(False)
    with gr.Column(elem_classes="main-container"):
        # 1. MAIN VIDEO FEED (Centered & Large) — receives numpy frames from stream_loop
        with gr.Row():
            avatar_img = gr.Image(
                label="AI Video Feed",
                interactive=False,
                streaming=True,
                show_label=False,
                type="numpy",
                elem_classes="video-container"
            )
        # 2. STATUS & CONTROLS (Clean Row)
        with gr.Row(elem_id="controls"):
            status_display = gr.Textbox(
                value="Ready",
                show_label=False,
                container=False,
                interactive=False,
                text_align="center"
            )
        with gr.Row():
            # LEFT: Start/Stop Stream
            start_btn = gr.Button("▶️ Start Streaming", variant="primary", scale=2, elem_classes="btn-primary")
            # RIGHT: Change Avatar
            change_avatar_btn = gr.Button("👤 Change Avatar", variant="secondary", scale=1, elem_classes="btn-secondary")
        # 3. HIDDEN UPLOAD FOR "CHANGE AVATAR" — revealed by the Change Avatar button
        with gr.Row(visible=False) as upload_row:
            avatar_file_input = gr.File(label="Upload New Avatar Image/Video", type="filepath")
        # 4. HIDDEN AUDIO I/O — mic capture in, AI speech out, no visible widgets
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, visible=False)
        speaker = gr.Audio(visible=False, autoplay=True, streaming=True)
    # --- EVENTS ---
    # Start/Stop Streaming — toggle_streaming yields progressive status updates
    start_btn.click(
        toggle_streaming,
        inputs=[streaming_state],
        outputs=[start_btn, status_display, streaming_state]
    )
    # Change Avatar Flow
    change_avatar_btn.click(
        lambda: gr.update(visible=True),  # Show upload row
        outputs=[upload_row]
    )
    avatar_file_input.change(
        change_avatar_handler,
        inputs=[avatar_file_input],
        outputs=[status_display, upload_row]  # Hide upload row on success
    )
    # Main Stream Loop — poll the mic every 40 ms (~25 fps); 300 s session cap
    mic.stream(
        stream_loop,
        inputs=[mic],
        outputs=[avatar_img, speaker],
        stream_every=0.04,
        time_limit=300
    )
if __name__ == "__main__":
    # queue() enables generator/streaming handlers; 0.0.0.0:7860 is the
    # standard HF Spaces binding. allowed_paths lets Gradio serve model/data
    # files from those local folders.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=["./Musetalk", "./models"]
    )