# Source: Avatar-Speech / backend / e2e / server.py
# Author: agkavin
# Commit a4cc15e: Fix pipeline deadlocks, remove torch.compile,
#                 implement 3-queue parallel pipeline, optimize for 16fps
"""
Speech-to-Video Server β€” e2e (Three-Queue Parallel Pipeline)
=============================================================
Model loading is split from session management:
Server startup (lifespan)
└─ load MuseTalk bundle (VAE, UNet, Whisper, avatar latents) ← once, stays in VRAM
└─ load Kokoro TTS ← once
└─ torch.compile + UNet warmup ← once (cached to disk)
POST /connect (user opens the app)
└─ create Room, Publisher, MuseTalkWorker, Pipeline ← per-session
└─ connect to LiveKit ← per-session
└─ fast warmup (models already hot) ← per-session
POST /disconnect
└─ stop pipeline/publisher, disconnect room ← models stay loaded
Run:
python -m backend.e2e.server
# or:
cd backend && uvicorn e2e.server:app --host 0.0.0.0 --port 8767
"""
from __future__ import annotations
import asyncio
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
# ── path setup ────────────────────────────────────────────────────────────────
# Make both backend/ and the project root importable so the server can be run
# either as `python -m backend.e2e.server` or via uvicorn from backend/.
_current_file = Path(__file__).resolve()
_e2e_dir = _current_file.parent
_backend_dir = _e2e_dir.parent
_project_dir = _backend_dir.parent
for p in [_backend_dir, _project_dir]:
    if str(p) not in sys.path:
        # insert(0, ...) so these dirs win over any same-named installed packages
        sys.path.insert(0, str(p))
# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api
from config import (
HOST,
LIVEKIT_URL,
LIVEKIT_API_KEY,
LIVEKIT_API_SECRET,
LIVEKIT_ROOM_NAME,
VIDEO_FPS,
DEFAULT_AVATAR,
DEVICE,
PORT,
)
from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from e2e.pipeline import StreamingPipeline
import torch

# Allow TF32 matmuls on Ampere+ GPUs — faster inference at slightly reduced precision.
torch.set_float32_matmul_precision("high")
# Fall back to eager instead of raising on dynamo errors; torch.compile was
# removed from this pipeline (see module docstring), this guards residual paths.
torch._dynamo.config.suppress_errors = True

log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-7s %(name)s %(message)s",
)
# ── global model state (loaded once, lives for server lifetime) ───────────────
_musetalk_bundle: Optional[MuseTalkBundle] = None  # VAE + UNet + Whisper + avatar latents
_tts: Optional[KokoroTTS] = None                   # Kokoro text-to-speech engine
# ── session state (created/destroyed on connect/disconnect) ──────────────────
_pipeline: Optional[StreamingPipeline] = None  # active 3-queue streaming pipeline, if any
_room: Optional[rtc.Room] = None               # LiveKit room connection for the session
_publisher: Optional[AVPublisher] = None       # publishes A/V tracks into the room
# ── lifespan: load models once at startup ────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all heavy models once at startup; tear down any live session on exit.

    MuseTalk bundle and Kokoro TTS are stored in module globals for the whole
    server lifetime so /connect only builds cheap per-session objects.
    """
    global _musetalk_bundle, _tts
    t_start = time.monotonic()
    log.info("=== Speech-to-Video e2e Server Starting ===")
    log.info("Device: %s Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # ── 1. Load MuseTalk (VAE + UNet + Whisper + avatar latents) ─────────────
    # Blocking model load runs off the event loop.
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    # ── 2. Load Kokoro TTS ────────────────────────────────────────────────────
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # ── 3. UNet warmup (prime GPU caches, no torch.compile) ──────────────────
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done (%.1fs)", time.monotonic() - t0)
    finally:
        # release executor threads even if warmup raises — must not outlive warmup
        worker_tmp.shutdown()

    # TTS warmup is a blocking synchronous call — run it off the event loop so
    # startup of other async tasks is not stalled.
    await asyncio.to_thread(_tts.synthesize_full, "Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs — waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield  # ── server running ──────────────────────────────────────────────

    # ── shutdown: stop any live session; models die with the process ─────────
    global _pipeline, _room, _publisher
    if _pipeline:
        await _pipeline.stop()
    if _publisher:
        await _publisher.stop()
    if _room:
        await _room.disconnect()
    # Clear session globals so late /health calls report an inactive pipeline.
    _pipeline = _publisher = _room = None
    log.info("=== Server Shutdown ===")
# ── FastAPI app ───────────────────────────────────────────────────────────────
# lifespan handles one-time model loading/unloading (see above).
app = FastAPI(
    title="Speech-to-Video (e2e 3-queue)",
    description="Text β†’ Kokoro TTS β†’ Whisper β†’ MuseTalk β†’ LiveKit",
    version="2.0.0",
    lifespan=lifespan,
)
# Wide-open CORS: the browser frontend may be served from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── request models ────────────────────────────────────────────────────────────
class SpeakRequest(BaseModel):
    """Body for POST /speak."""
    text: str                      # text to synthesize and lip-sync
    voice: Optional[str] = None    # NOTE(review): accepted but not forwarded by /speak — confirm pipeline support
    speed: Optional[float] = None  # NOTE(review): accepted but not forwarded by /speak
class TokenRequest(BaseModel):
    """Body for POST /get-token: which room/identity to mint a LiveKit token for."""
    room_name: str = LIVEKIT_ROOM_NAME  # defaults to the server's configured room
    identity: str = "user"              # participant identity embedded in the JWT
# ── /health and /status ───────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Lightweight liveness probe reporting model-load and pipeline state."""
    models_ready = _musetalk_bundle is not None and _tts is not None
    pipeline_running = _pipeline is not None and getattr(_pipeline, "_running", False)
    return {
        "status": "ok",
        "models_loaded": models_ready,
        "pipeline_active": pipeline_running,
    }
@app.get("/status")
async def status():
    """Detailed status: pipeline mode, model state, device, and VRAM usage."""
    vram = {}
    if torch.cuda.is_available():
        gib = 1024 ** 3
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    return {
        "pipeline": "e2e-3-queue",
        "models_loaded": _musetalk_bundle is not None,
        "pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }
# ── /connect ──────────────────────────────────────────────────────────────────
@app.post("/connect")
async def connect():
    """Create a session: LiveKit room + publisher + MuseTalk worker + pipeline.

    No model loading happens here — models were loaded once in `lifespan`;
    this only builds per-session objects and runs a fast warmup pass.

    Raises:
        HTTPException 503: models still loading.
        HTTPException 400: a session is already active.
        HTTPException 500: any failure while connecting/starting; partially
            started objects are torn down so a retry starts clean.
    """
    global _room, _publisher, _pipeline
    if _musetalk_bundle is None or _tts is None:
        raise HTTPException(status_code=503, detail="Server still loading models")
    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")

    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()
    room = publisher = pipeline = None  # tracked for rollback on failure
    try:
        # ── session objects (no model loading here) ───────────────────────────
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]
        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))
        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )
        # MuseTalkWorker wraps the already-loaded bundle — no model reload
        musetalk_worker = MuseTalkWorker(_musetalk_bundle)
        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
        )

        # ── connect LiveKit β†’ start pipeline ──────────────────────────────────
        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)
        await publisher.start()
        await pipeline.start()

        # ── fast warmup (models already hot in VRAM) ──────────────────────────
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

        # Publish session globals only after everything succeeded.
        _room = room
        _publisher = publisher
        _pipeline = pipeline
        log.info("/connect done in %.1fs", time.monotonic() - t0)
        return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}
    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        # Roll back whatever was partially started (otherwise the room stays
        # connected and the publisher keeps running after a failed /connect).
        # Best-effort: cleanup errors must not mask the original failure.
        for closer in (
            pipeline and pipeline.stop,
            publisher and publisher.stop,
            room and room.disconnect,
        ):
            if closer:
                try:
                    await closer()
                except Exception:
                    log.warning("Cleanup step failed during /connect rollback",
                                exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ── /disconnect ───────────────────────────────────────────────────────────────
@app.post("/disconnect")
async def disconnect():
    """Tear down the active session; loaded models stay resident in VRAM."""
    global _room, _publisher, _pipeline
    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")
    log.info("Disconnecting...")
    # Stop in pipeline → publisher → room order (reverse of startup).
    for stoppable in (_pipeline, _publisher):
        if stoppable:
            await stoppable.stop()
    if _room:
        await _room.disconnect()
    _room = _publisher = _pipeline = None
    # NOTE: _musetalk_bundle and _tts are intentionally NOT cleared —
    # models stay in VRAM so the next /connect is instant.
    log.info("Disconnected — models remain loaded for next session")
    return {"status": "disconnected"}
# ── /speak ────────────────────────────────────────────────────────────────────
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Enqueue text into the running pipeline; returns as soon as it is queued."""
    active = _pipeline
    if active is None or not getattr(active, "_running", False):
        raise HTTPException(status_code=400, detail="Not connected")
    started = time.monotonic()
    await active.push_text(request.text)
    elapsed_ms = (time.monotonic() - started) * 1000
    return {"status": "processing", "latency_ms": round(elapsed_ms, 1)}
# ── /get-token ────────────────────────────────────────────────────────────────
@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
    """Mint a LiveKit access token for a frontend participant."""
    target_room = request.room_name or LIVEKIT_ROOM_NAME
    who = request.identity or "frontend-user"
    grants = lk_api.VideoGrants(
        room_join=True,
        room=target_room,
        can_publish=True,
        can_subscribe=True,
    )
    token = (
        lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        .with_identity(who)
        .with_name(who)
    )
    token.with_grants(grants)
    return {"token": token.to_jwt(), "url": LIVEKIT_URL, "room": target_room}
# ── entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # reload=False: auto-reload would restart the process and re-trigger the
    # expensive model load in lifespan on every code change.
    uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")