"""
Speech-to-Video Server (api/ — warm-load version)
====================================================

Models are loaded ONCE at server startup (lifespan), not at /connect.
This means /connect is instant for subsequent sessions.

Model loading split:
    lifespan    → MuseTalk bundle + Kokoro TTS + UNet warmup (stay in VRAM)
    /connect    → Room, Publisher, MuseTalkWorker, Pipeline (per-session)
    /disconnect → session objects torn down; models stay loaded

Run:
    cd backend && python api/server.py
    # or: uvicorn api.server:app --host 0.0.0.0 --port 8767
"""

from __future__ import annotations

import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

# ── path setup ────────────────────────────────────────────────────────────────
_current_file = Path(__file__).resolve()
_api_dir = _current_file.parent      # backend/api/
_backend_dir = _api_dir.parent       # backend/
_project_dir = _backend_dir.parent   # speech_to_video/

# Make `config`, `tts`, `musetalk`, `publisher` and `api` importable when the
# file is run directly (python api/server.py).
for p in [_backend_dir, _project_dir]:
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api

from config import (
    HOST,
    PORT,
    LIVEKIT_URL,
    LIVEKIT_API_KEY,
    LIVEKIT_API_SECRET,
    LIVEKIT_ROOM_NAME,
    VIDEO_FPS,
    DEFAULT_AVATAR,
    DEVICE,
)
from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from api.pipeline import StreamingPipeline

import torch

torch.set_float32_matmul_precision("high")
torch._dynamo.config.suppress_errors = True

log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-7s  %(name)s  %(message)s",
)

# ── warm-up parameters ────────────────────────────────────────────────────────
_WARMUP_AUDIO_SECONDS = 0.32   # length of the silent clip fed through the models
_WARMUP_SAMPLE_RATE = 24_000   # samples per second of the silent clip
_WARMUP_MAX_FRAMES = 8         # upper bound on frames generated during warm-up

# ── global model state (loaded once, lives for server lifetime) ───────────────
_musetalk_bundle: Optional[MuseTalkBundle] = None
_tts: Optional[KokoroTTS] = None

# ── session state (created/destroyed on connect/disconnect) ──────────────────
_pipeline: Optional[StreamingPipeline] = None
_room: Optional[rtc.Room] = None
_publisher: Optional[AVPublisher] = None


# ── internal helpers ──────────────────────────────────────────────────────────
def _models_loaded() -> bool:
    """Return True once both the MuseTalk bundle and the TTS engine are loaded."""
    return _musetalk_bundle is not None and _tts is not None


def _pipeline_running() -> bool:
    """Return True when a session pipeline exists and reports itself running."""
    return _pipeline is not None and bool(getattr(_pipeline, "_running", False))


def _build_token(identity: str, name: str, room: str) -> str:
    """Return a signed LiveKit JWT allowing *identity* to join *room*.

    The grant allows joining, publishing and subscribing.
    """
    token = (
        lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        .with_identity(identity)
        .with_name(name)
    )
    token.with_grants(lk_api.VideoGrants(
        room_join=True,
        room=room,
        can_publish=True,
        can_subscribe=True,
    ))
    return token.to_jwt()


async def _warm_up_worker(worker: MuseTalkWorker, bundle: MuseTalkBundle) -> float:
    """Push a short silent clip through *worker* to prime the models.

    Returns:
        Seconds spent inside ``generate_batch`` (feature extraction excluded).
    """
    dummy_audio = np.zeros(
        int(_WARMUP_AUDIO_SECONDS * _WARMUP_SAMPLE_RATE), dtype=np.float32
    )
    feats, _ = await worker.extract_features(dummy_audio)
    n = min(_WARMUP_MAX_FRAMES, len(bundle.avatar_assets.frame_list))
    t0 = time.monotonic()
    await worker.generate_batch(feats, 0, n)
    return time.monotonic() - t0


async def _teardown_session(pipeline, publisher, room) -> list[Exception]:
    """Stop pipeline → publisher → room, attempting every step.

    A failure in one step is logged and does not prevent the remaining steps
    from running, so a broken pipeline can no longer leave the LiveKit room
    connected.

    Returns:
        The exceptions raised by individual steps (empty on a clean teardown).
    """
    errors: list[Exception] = []
    steps = (
        ("pipeline", pipeline, "stop"),
        ("publisher", publisher, "stop"),
        ("room", room, "disconnect"),
    )
    for label, obj, method_name in steps:
        if obj is None:
            continue
        try:
            await getattr(obj, method_name)()
        except Exception as exc:  # best-effort: keep tearing down the rest
            log.error("Failed to tear down %s: %s", label, exc, exc_info=True)
            errors.append(exc)
    return errors


# ── lifespan: load models once at startup ────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load and warm up the models at startup; tear down any session at shutdown."""
    global _musetalk_bundle, _tts, _pipeline, _room, _publisher

    t_start = time.monotonic()
    log.info("=== Speech-to-Video Server Starting ===")
    log.info("Device: %s  Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. Load MuseTalk (VAE + UNet + Whisper + avatar latents)
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    # 2. Load Kokoro TTS
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. UNet warmup — prime GPU caches
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        unet_secs = await _warm_up_worker(worker_tmp, _musetalk_bundle)
    finally:
        # Release the throw-away worker even if warm-up fails.
        worker_tmp.shutdown()
    log.info("UNet warm-up done (%.1fs)", unet_secs)

    _tts.synthesize_full("Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs — waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield  # ── server running ────────────────────────────────────────────────

    # ── shutdown ──────────────────────────────────────────────────────────────
    # Best-effort: every session object is attempted; failures are only logged.
    session = (_pipeline, _publisher, _room)
    _room = _publisher = _pipeline = None
    await _teardown_session(*session)
    log.info("=== Server Shutdown ===")


# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
    title="Speech-to-Video (api — 3-queue)",
    description="Text → Kokoro TTS → Whisper → MuseTalk → LiveKit",
    version="2.0.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── request models ────────────────────────────────────────────────────────────
class SpeakRequest(BaseModel):
    """Body of POST /speak."""

    text: str
    voice: Optional[str] = None
    speed: Optional[float] = None


class TokenRequest(BaseModel):
    """Body of the token endpoints; both fields have defaults."""

    room_name: str = LIVEKIT_ROOM_NAME
    identity: str = "user"


# ── /health and /status ───────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Report whether models are loaded and a pipeline is currently running."""
    return {
        "status": "ok",
        "models_loaded": _models_loaded(),
        "pipeline_active": _pipeline_running(),
    }


@app.get("/status")
async def status():
    """Report model/pipeline state, the configured avatar and device, and VRAM use."""
    vram = {}
    if torch.cuda.is_available():
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / 1024**3, 2),
        }
    return {
        "pipeline": "api-3-queue",
        # Same definition as /health: both MuseTalk and TTS must be loaded.
        "models_loaded": _models_loaded(),
        "pipeline_active": _pipeline_running(),
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }


# ── /connect ──────────────────────────────────────────────────────────────────
@app.post("/connect")
async def connect():
    """Create a session: join the LiveKit room and start publisher and pipeline.

    Raises:
        HTTPException: 503 while models are still loading, 400 when a session
            is already running, 500 when any step of the setup fails (anything
            already started is torn down again before the error is returned).
    """
    global _room, _publisher, _pipeline

    if not _models_loaded():
        raise HTTPException(status_code=503, detail="Server still loading models")
    if _pipeline_running():
        raise HTTPException(status_code=400, detail="Already connected")

    # A previous session that is no longer running would otherwise be
    # overwritten below, leaking its room connection and publisher.
    stale = (_pipeline, _publisher, _room)
    if any(obj is not None for obj in stale):
        log.warning("Tearing down stale session before reconnecting")
        _room = _publisher = _pipeline = None
        await _teardown_session(*stale)

    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()

    bundle = _musetalk_bundle
    # Track what has been (or is being) started so a failure can undo it.
    started_pipeline = started_publisher = started_room = None
    musetalk_worker = None

    try:
        first_frame = bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]

        room = rtc.Room()
        jwt = _build_token("backend-agent", "Speech-to-Video Agent", LIVEKIT_ROOM_NAME)

        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )

        # MuseTalkWorker wraps the already-loaded bundle — no model reload
        musetalk_worker = MuseTalkWorker(bundle)

        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=bundle.avatar_assets,
        )

        started_room = room
        await room.connect(url=LIVEKIT_URL, token=jwt)
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)

        started_publisher = publisher
        await publisher.start()
        started_pipeline = pipeline
        await pipeline.start()

        # Fast warmup (models already hot in VRAM)
        await _warm_up_worker(musetalk_worker, bundle)
        log.info("Session warm-up done")
    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        await _teardown_session(started_pipeline, started_publisher, started_room)
        # NOTE(review): assumes a started pipeline owns its worker and that
        # pipeline.stop() releases it — confirm. When the pipeline was never
        # started, release the worker here the same way lifespan does.
        if started_pipeline is None and musetalk_worker is not None:
            try:
                musetalk_worker.shutdown()
            except Exception:
                log.warning("Failed to shut down MuseTalk worker", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Publish the session only after every step succeeded.
    _room = room
    _publisher = publisher
    _pipeline = pipeline

    log.info("/connect done in %.1fs", time.monotonic() - t0)
    return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}


# ── /disconnect ───────────────────────────────────────────────────────────────
@app.post("/disconnect")
async def disconnect():
    """Tear down the current session while keeping the models loaded.

    Raises:
        HTTPException: 400 when there is no session, 500 when a teardown step
            failed (session state is cleared and the remaining steps are still
            attempted in that case).
    """
    global _room, _publisher, _pipeline

    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")

    log.info("Disconnecting...")

    # Clear the globals first so a failing stop() cannot leave the server
    # stuck holding a half-dead session.
    session = (_pipeline, _publisher, _room)
    _room = _publisher = _pipeline = None
    errors = await _teardown_session(*session)

    # NOTE: _musetalk_bundle and _tts are intentionally NOT cleared —
    # models stay in VRAM so the next /connect is instant.
    if errors:
        raise HTTPException(status_code=500, detail=str(errors[0])) from errors[0]

    log.info("Disconnected — models remain loaded for next session")
    return {"status": "disconnected"}


# ── /speak ────────────────────────────────────────────────────────────────────
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Queue ``request.text`` on the running pipeline.

    Raises:
        HTTPException: 400 when no pipeline is running.
    """
    if not _pipeline_running():
        raise HTTPException(status_code=400, detail="Not connected")

    # NOTE(review): request.voice and request.speed are accepted but not
    # forwarded — confirm whether push_text supports them.
    t0 = time.monotonic()
    await _pipeline.push_text(request.text)
    return {"status": "processing", "latency_ms": round((time.monotonic() - t0) * 1000, 1)}


# ── /get-token ────────────────────────────────────────────────────────────────
@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
    """Issue a LiveKit join token for a client.

    Empty ``room_name`` / ``identity`` values fall back to the configured room
    and ``"frontend-user"`` respectively.
    """
    room = request.room_name or LIVEKIT_ROOM_NAME
    identity = request.identity or "frontend-user"
    return {
        "token": _build_token(identity, identity, room),
        "url": LIVEKIT_URL,
        "room": room,
    }


# ── entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")