# Commit a4cc15e (agkavin): Fix pipeline deadlocks, remove torch.compile,
# implement 3-queue parallel pipeline, optimize for 16fps.
"""
Speech-to-Video Server — e2e (Three-Queue Parallel Pipeline)
============================================================

Model loading is split from session management:

Server startup (lifespan)
  ├─ load MuseTalk bundle (VAE, UNet, Whisper, avatar latents) — once, stays in VRAM
  ├─ load Kokoro TTS — once
  └─ UNet warmup (no torch.compile) — once

POST /connect (user opens the app)
  ├─ create Room, Publisher, MuseTalkWorker, Pipeline — per-session
  ├─ connect to LiveKit — per-session
  └─ fast warmup (models already hot) — per-session

POST /disconnect
  └─ stop pipeline/publisher, disconnect room — models stay loaded

Run:
    python -m backend.e2e.server
    # or:
    cd backend && uvicorn e2e.server:app --host 0.0.0.0 --port 8767
"""
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import Optional | |
# ── path setup ────────────────────────────────────────────────────────────────
# Make both the backend package and the project root importable regardless of
# how this module is launched (python -m, uvicorn, direct script).
_current_file = Path(__file__).resolve()
_e2e_dir = _current_file.parent
_backend_dir = _e2e_dir.parent
_project_dir = _backend_dir.parent
for _dir in (_backend_dir, _project_dir):
    _dir_str = str(_dir)
    if _dir_str not in sys.path:
        sys.path.insert(0, _dir_str)
| # ββ imports βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import numpy as np | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from livekit import rtc | |
| from livekit import api as lk_api | |
| from config import ( | |
| HOST, | |
| LIVEKIT_URL, | |
| LIVEKIT_API_KEY, | |
| LIVEKIT_API_SECRET, | |
| LIVEKIT_ROOM_NAME, | |
| VIDEO_FPS, | |
| DEFAULT_AVATAR, | |
| DEVICE, | |
| PORT, | |
| ) | |
| from tts.kokoro_tts import KokoroTTS | |
| from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle | |
| from publisher.livekit_publisher import AVPublisher | |
| from e2e.pipeline import StreamingPipeline | |
import torch

# Allow TF32 matmuls on Ampere+ GPUs — large speedup for the UNet/VAE
# forward passes at negligible quality cost.
torch.set_float32_matmul_precision("high")
# Vestige of the removed torch.compile path: keep any residual dynamo tracing
# failing soft instead of crashing.  torch._dynamo is a private namespace and
# may be absent in stripped/CPU-only builds, so guard the access.
try:
    torch._dynamo.config.suppress_errors = True
except AttributeError:  # pragma: no cover — build without dynamo
    pass
# Configure root logging once, then grab this module's logger.
logging.basicConfig(
    format="%(asctime)s %(levelname)-7s %(name)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)
# ── global model state (loaded once, lives for server lifetime) ───────────────
_musetalk_bundle: Optional[MuseTalkBundle] = None  # VAE + UNet + Whisper + avatar latents
_tts: Optional[KokoroTTS] = None                   # Kokoro TTS engine
# ── session state (created/destroyed on connect/disconnect) ───────────────────
_pipeline: Optional[StreamingPipeline] = None  # 3-queue TTS→MuseTalk→publish pipeline
_room: Optional[rtc.Room] = None               # LiveKit room of the active session
_publisher: Optional[AVPublisher] = None       # publishes A/V tracks into _room
# ── lifespan: load models once at startup ─────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all heavy models once at startup; tear down any session on shutdown.

    Must be decorated with ``@asynccontextmanager``: FastAPI drives the
    ``lifespan=`` callable as an async context manager, and a bare async
    generator function would fail at startup with "object does not support
    the asynchronous context manager protocol".

    Populates the module globals ``_musetalk_bundle`` and ``_tts``; session
    globals are created later by /connect.
    """
    global _musetalk_bundle, _tts, _pipeline, _room, _publisher
    t_start = time.monotonic()
    log.info("=== Speech-to-Video e2e Server Starting ===")
    log.info("Device: %s Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. Load MuseTalk (VAE + UNet + Whisper + avatar latents) off the event loop.
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    # 2. Load Kokoro TTS.
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. UNet warmup: prime GPU kernels/caches with a dummy batch (no torch.compile).
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
    feats, _ = await worker_tmp.extract_features(dummy_audio)
    t0 = time.monotonic()
    n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
    await worker_tmp.generate_batch(feats, 0, n)
    log.info("UNet warm-up done (%.1fs)", time.monotonic() - t0)
    worker_tmp.shutdown()  # release executor threads — must not outlive warmup

    # TTS warmup pays the one-time synthesis init cost; run it off the event
    # loop so startup does not block the loop thread.
    await asyncio.to_thread(_tts.synthesize_full, "Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs — waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield  # ── server running ───────────────────────────────────────────────

    # ── shutdown: stop any live session; models die with the process ─────────
    if _pipeline:
        await _pipeline.stop()
    if _publisher:
        await _publisher.stop()
    if _room:
        await _room.disconnect()
    log.info("=== Server Shutdown ===")
# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
    lifespan=lifespan,
    title="Speech-to-Video (e2e 3-queue)",
    description="Text → Kokoro TTS → Whisper → MuseTalk → LiveKit",
    version="2.0.0",
)
# Wide-open CORS: the frontend may be served from any origin during development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── request models ────────────────────────────────────────────────────────────
class SpeakRequest(BaseModel):
    """Body for POST /speak."""
    text: str                      # text to synthesize and lip-sync
    # NOTE(review): voice and speed are never read by the /speak handler —
    # confirm whether the pipeline is meant to receive them.
    voice: Optional[str] = None    # Kokoro voice override (None → default)
    speed: Optional[float] = None  # speech-rate override (None → default)
class TokenRequest(BaseModel):
    """Body for POST /get-token."""
    room_name: str = LIVEKIT_ROOM_NAME  # LiveKit room the token grants access to
    identity: str = "user"              # participant identity embedded in the token
# ── /health and /status ───────────────────────────────────────────────────────
@app.get("/health")  # was never registered on the app — route decorator added
async def health():
    """Liveness probe: model-load and pipeline-activity flags."""
    return {
        "status": "ok",
        "models_loaded": _musetalk_bundle is not None and _tts is not None,
        "pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
    }
@app.get("/status")  # was never registered on the app — route decorator added
async def status():
    """Detailed state: pipeline mode, model/session flags, device, VRAM usage."""
    vram = {}
    if torch.cuda.is_available():
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / 1024**3, 2),
        }
    return {
        "pipeline": "e2e-3-queue",
        "models_loaded": _musetalk_bundle is not None,
        "pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }
# ── /connect ──────────────────────────────────────────────────────────────────
@app.post("/connect")  # was never registered on the app — route decorator added
async def connect():
    """Create a per-session Room/Publisher/Worker/Pipeline and join LiveKit.

    No model loading happens here — lifespan() already holds the heavy models
    in VRAM, so this only builds lightweight session objects and runs a fast
    warmup.

    Raises:
        HTTPException 503: models still loading.
        HTTPException 400: a session is already active.
        HTTPException 500: any failure while connecting; partially created
            resources are torn down best-effort so a live room is not leaked.
    """
    global _room, _publisher, _pipeline
    if _musetalk_bundle is None or _tts is None:
        raise HTTPException(status_code=503, detail="Server still loading models")
    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")
    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()
    room = publisher = pipeline = None  # tracked for cleanup on partial failure
    try:
        # ── session objects (no model loading here) ──────────────────────────
        # Publisher resolution must match the avatar frames exactly.
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]

        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))

        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )
        # MuseTalkWorker wraps the already-loaded bundle — no model reload.
        musetalk_worker = MuseTalkWorker(_musetalk_bundle)
        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
        )

        # ── connect LiveKit → start pipeline ─────────────────────────────────
        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)
        await publisher.start()
        await pipeline.start()

        # ── fast warmup (models already hot in VRAM) ─────────────────────────
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

        # Publish session state only after everything succeeded.
        _room = room
        _publisher = publisher
        _pipeline = pipeline
        log.info("/connect done in %.1fs", time.monotonic() - t0)
        return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}
    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        # Best-effort teardown (same order as /disconnect) so a failed connect
        # does not leak a running pipeline or a joined LiveKit room.
        if pipeline is not None:
            try:
                await pipeline.stop()
            except Exception:
                log.warning("pipeline cleanup failed after /connect error", exc_info=True)
        if publisher is not None:
            try:
                await publisher.stop()
            except Exception:
                log.warning("publisher cleanup failed after /connect error", exc_info=True)
        if room is not None:
            try:
                await room.disconnect()
            except Exception:
                log.warning("room cleanup failed after /connect error", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ── /disconnect ───────────────────────────────────────────────────────────────
@app.post("/disconnect")  # was never registered on the app — route decorator added
async def disconnect():
    """Tear down the active session; heavy models stay resident in VRAM.

    Raises:
        HTTPException 400: no active session.
    """
    global _room, _publisher, _pipeline
    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")
    log.info("Disconnecting...")
    # Stop in dependency order: pipeline feeds publisher, publisher uses room.
    if _pipeline:
        await _pipeline.stop()
    if _publisher:
        await _publisher.stop()
    if _room:
        await _room.disconnect()
    _room = _publisher = _pipeline = None
    # NOTE: _musetalk_bundle and _tts are intentionally NOT cleared —
    # models stay in VRAM so the next /connect is instant.
    log.info("Disconnected — models remain loaded for next session")
    return {"status": "disconnected"}
# ── /speak ────────────────────────────────────────────────────────────────────
@app.post("/speak")  # was never registered on the app — route decorator added
async def speak(request: SpeakRequest):
    """Enqueue text into the running pipeline (TTS → lip-sync → LiveKit).

    NOTE(review): request.voice and request.speed are currently ignored —
    confirm whether StreamingPipeline.push_text should receive them.

    Raises:
        HTTPException 400: no active, running session.
    """
    if _pipeline is None or not getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Not connected")
    t0 = time.monotonic()
    await _pipeline.push_text(request.text)
    # latency_ms covers only the enqueue, not synthesis/rendering time.
    return {"status": "processing", "latency_ms": round((time.monotonic() - t0) * 1000, 1)}
# ── /get-token ────────────────────────────────────────────────────────────────
@app.post("/get-token")  # was never registered on the app — route decorator added
async def get_token(request: TokenRequest = TokenRequest()):
    """Mint a LiveKit access token for a frontend participant.

    Falls back to the configured room / a generic identity when the request
    fields are empty.
    """
    room = request.room_name or LIVEKIT_ROOM_NAME
    identity = request.identity or "frontend-user"
    token = (
        lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        .with_identity(identity)
        .with_name(identity)
    )
    token.with_grants(lk_api.VideoGrants(
        room_join=True,
        room=room,
        can_publish=True,
        can_subscribe=True,
    ))
    return {"token": token.to_jwt(), "url": LIVEKIT_URL, "room": room}
# ── entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # reload=False on purpose: autoreload would re-run the heavy model-loading
    # lifespan on every file change.
    uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")