"""
Speech-to-Video Server (api/ — warm-load version)
====================================================
Models are loaded ONCE at server startup (lifespan), not at /connect.
This means /connect is instant for subsequent sessions.

Model loading split:
    lifespan    → MuseTalk bundle + Kokoro TTS + UNet warmup (stay in VRAM)
    /connect    → Room, Publisher, MuseTalkWorker, Pipeline (per-session)
    /disconnect → session objects torn down; models stay loaded

Run:
    cd backend && python api/server.py
    # or: uvicorn api.server:app --host 0.0.0.0 --port 8767
"""
| from __future__ import annotations |
|
|
| import asyncio |
| import logging |
| import sys |
| import time |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
| from typing import Optional |
|
|
| |
# Make `backend/` and the project root importable regardless of the CWD the
# server was launched from (enables both `config` and `api.*` imports below).
_current_file = Path(__file__).resolve()
_api_dir = _current_file.parent
_backend_dir = _api_dir.parent
_project_dir = _backend_dir.parent

for _root in (_backend_dir, _project_dir):
    _root_str = str(_root)
    if _root_str not in sys.path:
        sys.path.insert(0, _root_str)
|
|
| |
| import numpy as np |
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from livekit import rtc |
| from livekit import api as lk_api |
|
|
| from config import ( |
| HOST, |
| PORT, |
| LIVEKIT_URL, |
| LIVEKIT_API_KEY, |
| LIVEKIT_API_SECRET, |
| LIVEKIT_ROOM_NAME, |
| VIDEO_FPS, |
| DEFAULT_AVATAR, |
| DEVICE, |
| ) |
| from tts.kokoro_tts import KokoroTTS |
| from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle |
| from publisher.livekit_publisher import AVPublisher |
| from api.pipeline import StreamingPipeline |
|
|
import torch
# Allow TF32 matmuls on supported GPUs — faster with negligible precision loss.
torch.set_float32_matmul_precision("high")
# If torch.compile / dynamo tracing fails, fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True


log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-7s %(name)s %(message)s",
)
|
|
| |
# --- Warm-loaded models: created once in lifespan(), reused across sessions ---
_musetalk_bundle: Optional[MuseTalkBundle] = None
_tts: Optional[KokoroTTS] = None


# --- Per-session objects: created in /connect, destroyed in /disconnect ---
_pipeline: Optional[StreamingPipeline] = None
_room: Optional[rtc.Room] = None
_publisher: Optional[AVPublisher] = None
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: warm-load heavy models once at startup.

    Startup:
      * loads the MuseTalk bundle and Kokoro TTS in worker threads
        (``asyncio.to_thread``) so the event loop is never blocked,
      * runs a short UNet + TTS warm-up so the first real session is fast.

    Shutdown:
      * tears down any session objects still alive; the warm-loaded models
        simply die with the process.
    """
    global _musetalk_bundle, _tts

    t_start = time.monotonic()
    log.info("=== Speech-to-Video Server Starting ===")
    log.info("Device: %s Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # Heavy model loads run off the event loop so startup stays responsive.
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded (%.1fs)", time.monotonic() - t_start)

    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # UNet warm-up: push one dummy batch through a throwaway worker so GPU
    # kernels / compile caches are hot before the first /connect.
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done (%.1fs)", time.monotonic() - t0)
    finally:
        # Always release the warm-up worker, even if warm-up itself fails.
        worker_tmp.shutdown()

    # TTS warm-up. synthesize_full() is synchronous — run it in a thread so
    # the event loop is not blocked (it was called inline before).
    await asyncio.to_thread(_tts.synthesize_full, "Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs β waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield

    # --- Shutdown: mirror /disconnect for any session still active. ---
    global _pipeline, _room, _publisher
    if _pipeline:
        await _pipeline.stop()
    if _publisher:
        await _publisher.stop()
    if _room:
        await _room.disconnect()
    log.info("=== Server Shutdown ===")
|
|
|
|
| |
|
|
# FastAPI application; models are warm-loaded via the lifespan hook above.
# (title/description previously contained mojibake "β" for "→" / "—".)
app = FastAPI(
    title="Speech-to-Video (api — 3-queue)",
    description="Text → Kokoro TTS → Whisper → MuseTalk → LiveKit",
    version="2.0.0",
    lifespan=lifespan,
)


# CORS is wide open ("*") — acceptable for local development only;
# restrict allow_origins before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| |
|
|
class SpeakRequest(BaseModel):
    """Request body for POST /speak."""
    # Text to synthesize and lip-sync.
    text: str
    # NOTE(review): voice/speed are accepted but not read by /speak in this
    # file — presumably consumed by the pipeline elsewhere; confirm.
    voice: Optional[str] = None
    speed: Optional[float] = None
|
|
class TokenRequest(BaseModel):
    """Request body for the LiveKit token endpoints."""
    # Room to join; defaults to the server-configured room.
    room_name: str = LIVEKIT_ROOM_NAME
    # Participant identity embedded in the token.
    identity: str = "user"
|
|
|
|
| |
|
|
@app.get("/health")
async def health():
    """Lightweight liveness/readiness probe."""
    models_ready = _musetalk_bundle is not None and _tts is not None
    session_live = _pipeline is not None and getattr(_pipeline, "_running", False)
    return {
        "status": "ok",
        "models_loaded": models_ready,
        "pipeline_active": session_live,
    }
|
|
@app.get("/status")
async def status():
    """Verbose status report, including CUDA memory stats when available."""
    vram = {}
    if torch.cuda.is_available():
        gib = float(1024 ** 3)
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    session_live = _pipeline is not None and getattr(_pipeline, "_running", False)
    return {
        "pipeline": "api-3-queue",
        "models_loaded": _musetalk_bundle is not None,
        "pipeline_active": session_live,
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }
|
|
|
|
| |
|
|
@app.post("/connect")
async def connect():
    """Create the per-session LiveKit room, publisher, worker and pipeline.

    Returns 503 while models are still loading, 400 if a session is already
    active. On any failure mid-setup, partially created session objects are
    torn down (previously they leaked: e.g. a failing ``pipeline.start()``
    left the room connected and the MuseTalk worker alive).
    """
    global _room, _publisher, _pipeline

    if _musetalk_bundle is None or _tts is None:
        raise HTTPException(status_code=503, detail="Server still loading models")

    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")

    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()

    # Track partial state so the except-path can clean up exactly what exists.
    room = publisher = pipeline = musetalk_worker = None
    try:
        # Video track dimensions come from the avatar's first frame.
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]

        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))

        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )

        musetalk_worker = MuseTalkWorker(_musetalk_bundle)

        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
        )

        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)

        await publisher.start()
        await pipeline.start()

        # Per-session warm-up: one dummy batch through the fresh worker.
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

        # Publish session state only after everything succeeded.
        _room = room
        _publisher = publisher
        _pipeline = pipeline

        log.info("/connect done in %.1fs", time.monotonic() - t0)
        return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}

    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        # Best-effort teardown of whatever was created before the failure so
        # a failed /connect leaks neither the worker nor the LiveKit room.
        if pipeline is not None:
            try:
                await pipeline.stop()
            except Exception:
                log.warning("pipeline cleanup failed", exc_info=True)
        if musetalk_worker is not None:
            try:
                # NOTE(review): assumed idempotent if pipeline.stop() already
                # shut the worker down — confirm against MuseTalkWorker.
                musetalk_worker.shutdown()
            except Exception:
                log.warning("worker cleanup failed", exc_info=True)
        if publisher is not None:
            try:
                await publisher.stop()
            except Exception:
                log.warning("publisher cleanup failed", exc_info=True)
        if room is not None:
            try:
                await room.disconnect()
            except Exception:
                log.warning("room cleanup failed", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
| |
|
|
@app.post("/disconnect")
async def disconnect():
    """Tear down the current session; warm-loaded models remain in VRAM."""
    global _room, _publisher, _pipeline

    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")

    log.info("Disconnecting...")

    pipeline, publisher, room = _pipeline, _publisher, _room
    if pipeline:
        await pipeline.stop()
    if publisher:
        await publisher.stop()
    if room:
        await room.disconnect()

    _room = _publisher = _pipeline = None

    log.info("Disconnected β models remain loaded for next session")
    return {"status": "disconnected"}
|
|
|
|
| |
|
|
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Enqueue text for TTS + lip-sync; returns as soon as the text is queued."""
    pipeline = _pipeline
    if pipeline is None or not getattr(pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Not connected")

    started = time.monotonic()
    await pipeline.push_text(request.text)
    elapsed_ms = round((time.monotonic() - started) * 1000, 1)
    return {"status": "processing", "latency_ms": elapsed_ms}
|
|
|
|
| |
|
|
@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: Optional[TokenRequest] = None):
    """Mint a LiveKit access token for a frontend participant.

    Accepts an optional JSON body ({room_name, identity}); configured
    defaults are used when the body is absent or fields are empty.
    """
    # Construct the default per call rather than sharing one mutable
    # TokenRequest() instance across all requests (mutable-default pitfall).
    if request is None:
        request = TokenRequest()

    room = request.room_name or LIVEKIT_ROOM_NAME
    identity = request.identity or "frontend-user"

    token = (
        lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        .with_identity(identity)
        .with_name(identity)
    )
    token.with_grants(lk_api.VideoGrants(
        room_join=True,
        room=room,
        can_publish=True,
        can_subscribe=True,
    ))
    return {"token": token.to_jwt(), "url": LIVEKIT_URL, "room": room}
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Direct launch: `python api/server.py`. reload is explicitly off —
    # NOTE(review): presumably to avoid re-running the heavy lifespan model
    # load on every file change; confirm before enabling auto-reload.
    uvicorn.run(app, host=HOST, port=PORT, reload=False, log_level="info")
|
|