# Source: Avatar-Speech / backend/server.py (repo: agkavin/avatars, commit 9400b83)
"""
Unified Speech-to-Video Server
==============================
Single entry point that combines:
β€’ Avatar video pipeline β€” POST /speak {text} β†’ avatar lip-sync
β€’ Complete voice pipeline β€” user speaks into mic β†’ avatar replies with video
Model loading (lifespan β€” happens ONCE at startup, stays in VRAM):
- MuseTalk bundle (VAE + UNet + Whisper encoder + avatar latents)
- Kokoro TTS (ONNX, patched for int32 bug)
- faster-whisper (ASR, default size: "base")
- LLM client (httpx to llama-server :8080)
- UNet + TTS warmup passes
Session management (per /connect β†’ /disconnect cycle):
- CompletePipeline (5-stage: ASR β†’ LLM β†’ TTS β†’ Whisper β†’ UNet β†’ publish)
- AVPublisher (LiveKit video+audio tracks)
- rtc.Room (LiveKit connection)
Run:
cd backend
python server.py
# or: uvicorn server:app --host 0.0.0.0 --port 8767 --reload
"""
from __future__ import annotations
import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
# ── path setup ────────────────────────────────────────────────────────────────
# Make backend/ and the project root importable regardless of the cwd the
# server was launched from.
_backend_dir = Path(__file__).resolve().parent
_project_dir = _backend_dir.parent
for _candidate in (_backend_dir, _project_dir):
    if str(_candidate) not in sys.path:
        sys.path.insert(0, str(_candidate))
# Agent sub-modules live in backend/agent/ and use their own config
sys.path.insert(0, str(_backend_dir / "agent"))
# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api
# Avatar pipeline config (backend/config.py)
from config import (
HOST,
PORT,
LIVEKIT_URL,
LIVEKIT_API_KEY,
LIVEKIT_API_SECRET,
LIVEKIT_ROOM_NAME,
VIDEO_FPS,
DEFAULT_AVATAR,
DEVICE,
SYSTEM_PROMPT,
)
# Agent config (backend/agent/config.py)
# Only LLAMA_SERVER_URL and ASR_MODEL_SIZE are needed here;
# KokoroTTS() reads its own model paths from backend/config.py.
from agent.config import (
LLAMA_SERVER_URL,
ASR_MODEL_SIZE,
)
from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from e2e.complete_pipeline import CompletePipeline
import torch
# Allow TF32 matmul kernels on supporting GPUs — higher throughput at a small
# float32 precision cost (fine for inference; see set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")
# If torch.compile / dynamo tracing fails anywhere, fall back to eager
# execution instead of raising at runtime.
torch._dynamo.config.suppress_errors = True
# Module-level logger plus root logging configuration for the whole process.
log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-7s %(name)s %(message)s",
)
# ── global model state ────────────────────────────────────────────────────────
# Loaded once at startup (lifespan); survive across connect/disconnect cycles.
_musetalk_bundle: Optional[MuseTalkBundle] = None  # VAE + UNet + Whisper encoder + avatar latents
_tts: Optional[KokoroTTS] = None  # Kokoro TTS engine
_asr = None  # agent.asr.ASR — imported lazily inside lifespan
_llm = None  # agent.llm.LLM — imported lazily inside lifespan
# ── session state ─────────────────────────────────────────────────────────────
# Created at /connect, torn down at /disconnect.
_pipeline: Optional[CompletePipeline] = None  # 5-stage processing pipeline
_room: Optional[rtc.Room] = None  # backend agent's LiveKit connection
_publisher: Optional[AVPublisher] = None  # avatar video+audio track publisher
# ── lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models once at startup; tear down any live session at shutdown.

    Heavy, blocking loads run in worker threads via asyncio.to_thread so the
    event loop stays responsive during startup.  Models remain resident in
    VRAM for the whole server lifetime; only per-session objects (pipeline,
    publisher, room) are cleaned up on exit.
    """
    global _musetalk_bundle, _tts, _asr, _llm
    t_start = time.monotonic()
    log.info("=== Speech-to-Video Unified Server Starting ===")
    log.info("Device: %s Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. MuseTalk bundle (VAE + UNet + Whisper encoder + avatar latents)
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    # 2. Kokoro TTS (avatar pipeline)
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. faster-whisper ASR (voice pipeline) — imported lazily so the heavy
    # dependency is only pulled in once the server actually starts.
    log.info("Loading faster-whisper ASR (size=%s)...", ASR_MODEL_SIZE)
    from agent.asr import ASR
    _asr = await asyncio.to_thread(ASR, ASR_MODEL_SIZE, DEVICE)
    log.info("ASR loaded")

    # 4. LLM client — httpx to llama-server, no GPU needed
    log.info("Initialising LLM client → %s", LLAMA_SERVER_URL)
    from agent.llm import LLM
    _llm = LLM(LLAMA_SERVER_URL)
    await asyncio.to_thread(_llm.warmup)
    log.info("LLM client ready")

    # 5. UNet warmup — prime GPU caches.  try/finally guarantees the
    # temporary worker is shut down even if a warmup pass raises.
    log.info("Warming up UNet...")
    _worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await _worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await _worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done (%.1fs)", time.monotonic() - t0)
    finally:
        _worker_tmp.shutdown()

    # 6. TTS warmup — synthesize_full is a blocking call, so run it off the
    # event loop like every other heavy operation above.
    await asyncio.to_thread(_tts.synthesize_full, "Hello.")
    log.info("TTS warm-up done")

    log.info(
        "=== All models ready in %.1fs — waiting for /connect (port %d) ===",
        time.monotonic() - t_start,
        PORT,
    )

    yield  # ── server running ────────────────────────────────────────────────

    # ── shutdown: tear down any session that is still live ────────────────────
    global _pipeline, _room, _publisher
    if _pipeline:
        await _pipeline.stop()
    if _publisher:
        await _publisher.stop()
    if _room:
        await _room.disconnect()
    log.info("=== Server Shutdown ===")
# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
title="Speech-to-Video β€” Complete Pipeline",
description=(
"User mic β†’ ASR β†’ LLM β†’ Kokoro TTS β†’ MuseTalk β†’ LiveKit avatar video.\n"
"POST /speak also works for direct text input (bypasses ASR/LLM)."
),
version="3.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# ── request models ────────────────────────────────────────────────────────────
class SpeakRequest(BaseModel):
    """Request body for POST /speak — text injected straight into the avatar
    pipeline, bypassing ASR and the LLM."""

    text: str
    # Optional overrides; None means "use the pipeline defaults".
    # NOTE(review): /speak currently forwards only `text` to push_text() —
    # voice/speed appear unused by this server; confirm against the pipeline.
    voice: Optional[str] = None
    speed: Optional[float] = None
class TokenRequest(BaseModel):
    """Request body for the token endpoints — which room and identity the
    minted LiveKit JWT should be issued for."""

    room_name: str = LIVEKIT_ROOM_NAME  # defaults to the server's configured room
    identity: str = "user"              # LiveKit participant identity
# ── /health ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
return {
"status": "ok",
"models_loaded": all(
m is not None for m in [_musetalk_bundle, _tts, _asr, _llm]
),
"pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
}
@app.get("/status")
async def status():
vram = {}
if torch.cuda.is_available():
vram = {
"allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
"reserved_gb": round(torch.cuda.memory_reserved() / 1024**3, 2),
}
return {
"pipeline": "complete-5-stage",
"models_loaded": {
"musetalk": _musetalk_bundle is not None,
"tts": _tts is not None,
"asr": _asr is not None,
"llm": _llm is not None,
},
"pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
"avatar": DEFAULT_AVATAR,
"device": DEVICE,
"vram": vram,
}
# ── /connect ──────────────────────────────────────────────────────────────────
@app.post("/connect")
async def connect():
"""
Create a session:
1. Instantiate CompletePipeline (no model loading β€” models already in VRAM)
2. Connect backend-agent to LiveKit room
3. Start publisher + pipeline (pipeline auto-subscribes to mic audio tracks)
4. Return LiveKit connection info
"""
global _room, _publisher, _pipeline
if any(m is None for m in [_musetalk_bundle, _tts, _asr, _llm]):
raise HTTPException(status_code=503, detail="Server still loading models")
if _pipeline is not None and getattr(_pipeline, "_running", False):
raise HTTPException(status_code=400, detail="Already connected")
log.info("Creating new session...")
t0 = time.monotonic()
try:
# Determine actual video dimensions from precomputed avatar frames
first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
actual_h, actual_w = first_frame.shape[:2]
log.info("Avatar frame size: %dx%d", actual_w, actual_h)
# LiveKit room + JWT
room = rtc.Room()
token = (
lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
.with_identity("backend-agent")
.with_name("Speech-to-Video Agent")
)
token.with_grants(lk_api.VideoGrants(
room_join=True,
room=LIVEKIT_ROOM_NAME,
can_publish=True,
can_subscribe=True,
))
# Publisher
publisher = AVPublisher(
room,
video_width=actual_w,
video_height=actual_h,
video_fps=VIDEO_FPS,
)
# MuseTalkWorker wraps the already-loaded bundle β€” no GPU reload
musetalk_worker = MuseTalkWorker(_musetalk_bundle)
# Complete pipeline (5-stage)
pipeline = CompletePipeline(
tts=_tts,
musetalk=musetalk_worker,
publisher=publisher,
avatar_assets=_musetalk_bundle.avatar_assets,
asr=_asr,
llm=_llm,
system_prompt=SYSTEM_PROMPT,
)
# Connect β†’ publish tracks β†’ start pipeline
await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
log.info("Connected to LiveKit room: %s", LIVEKIT_ROOM_NAME)
await publisher.start()
await pipeline.start(room) # pipeline subscribes to audio here
# Fast warmup (models already hot)
dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
feats, _ = await musetalk_worker.extract_features(dummy_audio)
n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
await musetalk_worker.generate_batch(feats, 0, n)
log.info("Session warm-up done")
_room = room
_publisher = publisher
_pipeline = pipeline
log.info("/connect done in %.1fs", time.monotonic() - t0)
return {
"status": "connected",
"room": LIVEKIT_ROOM_NAME,
"url": LIVEKIT_URL,
"pipeline": "complete-5-stage",
}
except Exception as exc:
log.error("Connection failed: %s", exc, exc_info=True)
raise HTTPException(status_code=500, detail=str(exc))
# ── /disconnect ───────────────────────────────────────────────────────────────
@app.post("/disconnect")
async def disconnect():
global _room, _publisher, _pipeline
if _pipeline is None:
raise HTTPException(status_code=400, detail="Not connected")
log.info("Disconnecting session...")
if _pipeline:
await _pipeline.stop()
if _publisher:
await _publisher.stop()
if _room:
await _room.disconnect()
_room = _publisher = _pipeline = None
# Models intentionally NOT cleared β€” stay in VRAM for instant reconnect
log.info("Session disconnected β€” models remain loaded")
return {"status": "disconnected"}
# ── /speak (text bypass β€” works alongside live voice) ────────────────────────
@app.post("/speak")
async def speak(request: SpeakRequest):
"""
Directly inject text into the avatar pipeline, bypassing ASR + LLM.
Useful for testing or mixing programmatic responses with live voice.
"""
if _pipeline is None or not getattr(_pipeline, "_running", False):
raise HTTPException(status_code=400, detail="Not connected")
t0 = time.monotonic()
await _pipeline.push_text(request.text)
return {"status": "processing", "latency_ms": round((time.monotonic() - t0) * 1000, 1)}
# ── /get-token ────────────────────────────────────────────────────────────────
@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
"""Issue a LiveKit JWT for the frontend (viewer) or external clients."""
room = request.room_name or LIVEKIT_ROOM_NAME
identity = request.identity or "frontend-user"
token = (
lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
.with_identity(identity)
.with_name(identity)
)
token.with_grants(lk_api.VideoGrants(
room_join=True,
room=room,
can_publish=True,
can_subscribe=True,
))
return {"token": token.to_jwt(), "url": LIVEKIT_URL, "room": room}
# ── entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")