Avatar-Speech / backend /api /server.py
agkavin
avatars
9400b83
"""
Speech-to-Video Server (api/ — warm-load version)
=================================================
Models are loaded ONCE at server startup (lifespan), not at /connect.
This means /connect is instant for subsequent sessions.
Model loading split:
lifespan → MuseTalk bundle + Kokoro TTS + UNet warmup (stay in VRAM)
/connect → Room, Publisher, MuseTalkWorker, Pipeline (per-session)
/disconnect → session objects torn down; models stay loaded
Run:
cd backend && python api/server.py
# or: uvicorn api.server:app --host 0.0.0.0 --port 8767
"""
from __future__ import annotations
import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
# ── path setup ────────────────────────────────────────────────────────────────
# Resolve this file's location and walk up one directory at a time:
#   server.py -> api/ -> backend/ -> project root.
_current_file = Path(__file__).resolve()
_api_dir = _current_file.parent
_backend_dir = _api_dir.parent
_project_dir = _backend_dir.parent

# Put backend/ and the project root at the front of sys.path (skipping any
# entry that is already present) so the project-local imports that follow
# resolve regardless of the working directory the server is launched from.
for p in (_backend_dir, _project_dir):
    if str(p) in sys.path:
        continue
    sys.path.insert(0, str(p))
# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api
from config import (
HOST,
PORT,
LIVEKIT_URL,
LIVEKIT_API_KEY,
LIVEKIT_API_SECRET,
LIVEKIT_ROOM_NAME,
VIDEO_FPS,
DEFAULT_AVATAR,
DEVICE,
)
from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from api.pipeline import StreamingPipeline
import torch
# Process-wide PyTorch setting, applied at import time. "high" lets float32
# matmuls use faster, reduced-precision internals where the hardware supports
# it (see the torch.set_float32_matmul_precision documentation).
torch.set_float32_matmul_precision("high")
# NOTE(review): presumably set so that a TorchDynamo/torch.compile failure
# falls back to eager execution instead of raising - confirm against the
# model code that actually calls torch.compile.
torch._dynamo.config.suppress_errors = True
# Module logger used by every handler in this file.
log = logging.getLogger(__name__)
# Root logging configuration: INFO level, timestamped single-line records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-7s %(name)s %(message)s",
)
# ── global model state (loaded once, lives for server lifetime) ───────────────
# MuseTalk model bundle; assigned once by lifespan() at startup and left in
# place by /disconnect so later sessions can reuse it.
_musetalk_bundle: Optional[MuseTalkBundle] = None
# Kokoro TTS engine; assigned once by lifespan() at startup.
_tts: Optional[KokoroTTS] = None
# ── session state (created/destroyed on connect/disconnect) ──────────────────
# All three are assigned together at the end of a successful /connect and
# reset to None by /disconnect. Only one session exists at a time.
_pipeline: Optional[StreamingPipeline] = None
_room: Optional[rtc.Room] = None
_publisher: Optional[AVPublisher] = None
# ── lifespan: load models once at startup ────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models once at startup and tear down any live session at shutdown.

    Startup (before ``yield``):
      1. Load the MuseTalk bundle for ``DEFAULT_AVATAR`` on ``DEVICE`` in a
         worker thread.
      2. Construct the Kokoro TTS engine in a worker thread.
      3. Run one small feature-extraction + generation pass through a temporary
         ``MuseTalkWorker`` and one short TTS synthesis so the first real
         request does not pay the cold-start cost.

    Shutdown (after ``yield``): stop the pipeline, publisher and room of any
    session that is still active. Each step is attempted even if an earlier
    one fails, and the session globals are cleared afterwards.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used directly).
    """
    global _musetalk_bundle, _tts
    global _pipeline, _room, _publisher

    t_start = time.monotonic()
    log.info("=== Speech-to-Video Server Starting ===")
    log.info("Device: %s Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. Load MuseTalk (VAE + UNet + Whisper + avatar latents)
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    # 2. Load Kokoro TTS
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. UNet warmup — prime GPU caches with a throwaway worker.
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        # 0.32 s of silence at 24 kHz is enough to exercise the whole path.
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done (%.1fs)", time.monotonic() - t0)
    finally:
        # Release the temporary worker even if warm-up raised, so a failed
        # startup does not leave its resources behind.
        worker_tmp.shutdown()

    # Runs synchronously; the server is not accepting requests yet.
    _tts.synthesize_full("Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs — waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield  # server running

    # ── shutdown ──────────────────────────────────────────────────────────
    # Same order as /disconnect: pipeline, then publisher, then room.
    teardown = (
        ("pipeline", _pipeline.stop if _pipeline is not None else None),
        ("publisher", _publisher.stop if _publisher is not None else None),
        ("room", _room.disconnect if _room is not None else None),
    )
    for label, closer in teardown:
        if closer is None:
            continue
        try:
            await closer()
        except Exception:
            # Keep going: one failing component must not prevent the others
            # from being released.
            log.exception("Shutdown: failed to stop %s", label)
    _pipeline = _publisher = _room = None
    log.info("=== Server Shutdown ===")
# ── FastAPI app ───────────────────────────────────────────────────────────────
# Application object. `lifespan` loads the models before the first request is
# served. The title and description are shown in the generated OpenAPI docs.
app = FastAPI(
    title="Speech-to-Video (api — 3-queue)",
    description="Text → Kokoro TTS → Whisper → MuseTalk → LiveKit",
    version="2.0.0",
    lifespan=lifespan,
)
# NOTE(review): CORS is fully open (any origin, method and header). That is
# convenient for a local front-end; restrict `allow_origins` before exposing
# this server beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── request models ────────────────────────────────────────────────────────────
class SpeakRequest(BaseModel):
    """Request body for POST /speak."""
    # Text handed to the streaming pipeline.
    text: str
    # Accepted in the payload, but the /speak handler in this file does not
    # read either field.
    voice: Optional[str] = None
    speed: Optional[float] = None
class TokenRequest(BaseModel):
    """Request body for the token endpoints (/get-token and /livekit-token)."""
    # Room the issued token grants access to; defaults to the configured room.
    room_name: str = LIVEKIT_ROOM_NAME
    # Participant identity placed in the token (also used as its display name).
    identity: str = "user"
# ── /health and /status ───────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Report whether the models are loaded and whether a session is running.

    Returns:
        A dict with a constant ``"status": "ok"``, ``models_loaded`` (both the
        MuseTalk bundle and the TTS engine are present) and
        ``pipeline_active`` (a pipeline exists and its ``_running`` flag is set).
    """
    models_ready = not (_musetalk_bundle is None or _tts is None)
    session = _pipeline
    session_live = session is not None and getattr(session, "_running", False)
    return {
        "status": "ok",
        "models_loaded": models_ready,
        "pipeline_active": session_live,
    }
@app.get("/status")
async def status():
    """Return a detailed status snapshot, including GPU memory usage.

    Returns:
        A dict with the pipeline label, ``models_loaded``, ``pipeline_active``,
        the configured avatar and device, and ``vram``. ``vram`` holds the
        allocated/reserved CUDA memory in GiB rounded to two decimals, or is
        an empty dict when CUDA is unavailable.
    """
    vram = {}
    if torch.cuda.is_available():
        gib = 1024**3
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    return {
        "pipeline": "api-3-queue",
        # Same definition as /health and the /connect readiness gate: the
        # server is ready only when both MuseTalk and TTS are loaded.
        "models_loaded": _musetalk_bundle is not None and _tts is not None,
        "pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }
# ── /connect ──────────────────────────────────────────────────────────────────
@app.post("/connect")
async def connect():
    """Start a streaming session: join the LiveKit room and start the pipeline.

    The models loaded at startup are reused; only the per-session objects
    (room, publisher, MuseTalk worker, pipeline) are created here. On success
    they are stored in the module-level session globals.

    If any step fails, everything that was already started is stopped again
    (in reverse order) before the error is reported, so a failed attempt does
    not leave a half-open session behind.

    Returns:
        ``{"status": "connected", "room": ..., "url": ...}``.

    Raises:
        HTTPException: 503 if the models are not loaded yet; 400 if a session
            is already running; 500 if session bring-up fails (``detail``
            carries the underlying error message).
    """
    global _room, _publisher, _pipeline
    if _musetalk_bundle is None or _tts is None:
        raise HTTPException(status_code=503, detail="Server still loading models")
    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")
    # NOTE(review): the session globals are assigned only after the awaits
    # below, so two overlapping /connect requests can both pass this check.

    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()

    # (label, zero-argument async closer) for each component that has been
    # started successfully; unwound in reverse order on failure.
    started = []
    musetalk_worker = None
    try:
        # Publish at the avatar's native frame size.
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]

        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))
        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )
        # MuseTalkWorker wraps the already-loaded bundle — no model reload
        musetalk_worker = MuseTalkWorker(_musetalk_bundle)
        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
        )

        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        started.append(("room", room.disconnect))
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)
        await publisher.start()
        started.append(("publisher", publisher.stop))
        await pipeline.start()
        started.append(("pipeline", pipeline.stop))

        # Fast warmup (models already hot in VRAM)
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

        # Publish the session only once every step has succeeded.
        _room = room
        _publisher = publisher
        _pipeline = pipeline
        log.info("/connect done in %.1fs", time.monotonic() - t0)
        return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}
    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        # Best-effort rollback: pipeline, then publisher, then room.
        for label, closer in reversed(started):
            try:
                await closer()
            except Exception:
                log.exception("Cleanup after failed /connect: could not stop %s", label)
        if musetalk_worker is not None:
            # The worker is per-session; the shared model bundle is untouched.
            try:
                musetalk_worker.shutdown()
            except Exception:
                log.exception("Cleanup after failed /connect: could not shut down MuseTalk worker")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ── /disconnect ───────────────────────────────────────────────────────────────
@app.post("/disconnect")
async def disconnect():
    """Tear down the current session while keeping the models loaded.

    Stops the pipeline, the publisher and the room connection, in that order.
    Every step is attempted even if an earlier one fails, and the session
    globals are always cleared, so a failed teardown cannot leave the server
    stuck in a half-connected state.

    Returns:
        ``{"status": "disconnected"}``.

    Raises:
        HTTPException: 400 if there is no session; 500 if one or more
            components failed to stop (the session state is cleared anyway).
    """
    global _room, _publisher, _pipeline
    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")

    log.info("Disconnecting...")
    steps = (
        ("pipeline", _pipeline.stop),
        ("publisher", _publisher.stop if _publisher is not None else None),
        ("room", _room.disconnect if _room is not None else None),
    )
    failed = []
    try:
        for label, closer in steps:
            if closer is None:
                continue
            try:
                await closer()
            except Exception:
                log.exception("Disconnect: failed to stop %s", label)
                failed.append(label)
    finally:
        _room = _publisher = _pipeline = None
        # NOTE: _musetalk_bundle and _tts are intentionally NOT cleared —
        # models stay in VRAM so the next /connect is instant.

    if failed:
        raise HTTPException(
            status_code=500,
            detail="Disconnected with errors while stopping: " + ", ".join(failed),
        )
    log.info("Disconnected — models remain loaded for next session")
    return {"status": "disconnected"}
# ── /speak ────────────────────────────────────────────────────────────────────
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Queue text for synthesis on the active session.

    Only ``request.text`` is forwarded to the pipeline; ``voice`` and
    ``speed`` are accepted in the body but not used here.

    Returns:
        ``{"status": "processing", "latency_ms": ...}`` where ``latency_ms``
        is the time spent handing the text to the pipeline, in milliseconds
        rounded to one decimal.

    Raises:
        HTTPException: 400 if no session is running.
    """
    session = _pipeline
    session_live = session is not None and getattr(session, "_running", False)
    if not session_live:
        raise HTTPException(status_code=400, detail="Not connected")

    started_at = time.monotonic()
    await session.push_text(request.text)
    elapsed_ms = (time.monotonic() - started_at) * 1000
    return {"status": "processing", "latency_ms": round(elapsed_ms, 1)}
# ── /get-token ────────────────────────────────────────────────────────────────
@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
    """Issue a LiveKit access token for a client.

    Registered under both POST /get-token and GET /livekit-token. An empty
    ``room_name`` falls back to the configured room and an empty ``identity``
    falls back to ``"frontend-user"``. The identity is also used as the
    participant's display name.

    Returns:
        ``{"token": <JWT>, "url": <LiveKit URL>, "room": <room name>}``.
    """
    target_room = request.room_name if request.room_name else LIVEKIT_ROOM_NAME
    participant = request.identity if request.identity else "frontend-user"

    access = lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
    access = access.with_identity(participant)
    access = access.with_name(participant)

    # The token may join the room and both publish and subscribe.
    grants = lk_api.VideoGrants(
        room_join=True,
        room=target_room,
        can_publish=True,
        can_subscribe=True,
    )
    access.with_grants(grants)

    return {"token": access.to_jwt(), "url": LIVEKIT_URL, "room": target_room}
# ── entry point ───────────────────────────────────────────────────────────────
# Direct-execution entry point (`python api/server.py`): serve `app` on the
# configured HOST/PORT with auto-reload disabled.
if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")