File size: 14,980 Bytes
9400b83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""
Unified Speech-to-Video Server
==============================
Single entry point that combines:
  • Avatar video pipeline   — POST /speak {text} → avatar lip-sync
  • Complete voice pipeline — user speaks into mic → avatar replies with video

Model loading (lifespan — happens ONCE at startup, stays in VRAM):
  - MuseTalk bundle   (VAE + UNet + Whisper encoder + avatar latents)
  - Kokoro TTS        (ONNX, patched for int32 bug)
  - faster-whisper    (ASR, default size: "base")
  - LLM client        (httpx to llama-server :8080)
  - UNet + TTS warmup passes

Session management (per /connect → /disconnect cycle):
  - CompletePipeline  (5-stage: ASR → LLM → TTS → Whisper → UNet → publish)
  - AVPublisher       (LiveKit video+audio tracks)
  - rtc.Room          (LiveKit connection)

Run:
  cd backend
  python server.py
  # or: uvicorn server:app --host 0.0.0.0 --port 8767 --reload
"""
from __future__ import annotations

import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

# ── path setup ────────────────────────────────────────────────────────────────
# Resolve backend/ (this file's directory) and the project root from __file__,
# so imports work regardless of the current working directory.
_backend_dir = Path(__file__).resolve().parent
_project_dir = _backend_dir.parent

# Put backend/ and the project root on the import path, skipping entries that
# are already present. Each insert goes to index 0, so when both are newly
# added the project root ends up ahead of backend/.
for p in [_backend_dir, _project_dir]:
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

# Agent sub-modules live in backend/agent/ and use their own config
# NOTE(review): this insert is unconditional and lands at index 0, i.e. AHEAD
# of backend/. The file imports `agent.config` below, so backend/agent/ has a
# `config` module; with this ordering a bare `import config` — including the
# `from config import HOST, PORT, ...` below — resolves to backend/agent's
# config, not backend/config.py. Confirm this is intended (or that the agent
# config provides those names) before changing the order of these paths.
sys.path.insert(0, str(_backend_dir / "agent"))

# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api

# Avatar pipeline config (backend/config.py)
from config import (
    HOST,
    PORT,
    LIVEKIT_URL,
    LIVEKIT_API_KEY,
    LIVEKIT_API_SECRET,
    LIVEKIT_ROOM_NAME,
    VIDEO_FPS,
    DEFAULT_AVATAR,
    DEVICE,
    SYSTEM_PROMPT,
)
# Agent config (backend/agent/config.py)
# Only LLAMA_SERVER_URL and ASR_MODEL_SIZE are needed here;
# KokoroTTS() reads its own model paths from backend/config.py.
from agent.config import (
    LLAMA_SERVER_URL,
    ASR_MODEL_SIZE,
)

from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from e2e.complete_pipeline import CompletePipeline

import torch
torch.set_float32_matmul_precision("high")
torch._dynamo.config.suppress_errors = True

log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-7s  %(name)s  %(message)s",
)

# ── global model state ────────────────────────────────────────────────────────
# Populated once by lifespan() at startup and shared by every session.
_musetalk_bundle: Optional[MuseTalkBundle] = None
_tts:             Optional[KokoroTTS]      = None
_asr = None  # agent.asr.ASR instance (class imported lazily inside lifespan)
_llm = None  # agent.llm.LLM instance (class imported lazily inside lifespan)

# ── session state ─────────────────────────────────────────────────────────────
# Assigned by /connect and reset to None by /disconnect.
_room:      Optional[rtc.Room]         = None
_publisher: Optional[AVPublisher]      = None
_pipeline:  Optional[CompletePipeline] = None


# ── lifespan ──────────────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once at startup; tear down any open session at exit.

    Startup (before ``yield``) fills the module-level model globals in order:
    MuseTalk bundle, Kokoro TTS, faster-whisper ASR and the llama-server LLM
    client. It then runs one UNet batch and one TTS synthesis as warm-up.
    Blocking calls are pushed to worker threads with ``asyncio.to_thread`` so
    the event loop is never stalled.

    Shutdown (after ``yield``) stops the pipeline, publisher and LiveKit room if
    a session is still open. Each step is attempted even if an earlier one
    raises (failures are logged), and the session globals are cleared.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global _musetalk_bundle, _tts, _asr, _llm
    global _pipeline, _room, _publisher

    t_start = time.monotonic()
    log.info("=== Speech-to-Video Unified Server Starting ===")
    log.info("Device: %s  Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. MuseTalk bundle (VAE + UNet + Whisper encoder + avatar latents)
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded  (%.1fs)", time.monotonic() - t_start)

    # 2. Kokoro TTS (avatar pipeline)
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. faster-whisper ASR (voice pipeline)
    log.info("Loading faster-whisper ASR (size=%s)...", ASR_MODEL_SIZE)
    from agent.asr import ASR
    _asr = await asyncio.to_thread(ASR, ASR_MODEL_SIZE, DEVICE)
    log.info("ASR loaded")

    # 4. LLM client — httpx to llama-server, no GPU needed
    log.info("Initialising LLM client → %s", LLAMA_SERVER_URL)
    from agent.llm import LLM
    _llm = LLM(LLAMA_SERVER_URL)
    await asyncio.to_thread(_llm.warmup)
    log.info("LLM client ready")

    # 5. UNet warmup — prime GPU caches with one short batch of silence
    log.info("Warming up UNet...")
    _worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        # 0.32 s of silence; 24_000 is presumably the TTS sample rate — confirm
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await _worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await _worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done  (%.1fs)", time.monotonic() - t0)
    finally:
        # Release the throw-away worker even when warm-up fails.
        _worker_tmp.shutdown()

    # 6. TTS warmup — synthesize_full is synchronous, so keep it off the loop
    await asyncio.to_thread(_tts.synthesize_full, "Hello.")
    log.info("TTS warm-up done")

    log.info(
        "=== All models ready in %.1fs — waiting for /connect (port %d) ===",
        time.monotonic() - t_start,
        PORT,
    )

    yield  # ── server running ────────────────────────────────────────────────

    # ── shutdown ──────────────────────────────────────────────────────────────
    # Pipeline → publisher → room; one failing step must not skip the others.
    teardown = (
        ("pipeline", _pipeline.stop if _pipeline is not None else None),
        ("publisher", _publisher.stop if _publisher is not None else None),
        ("room", _room.disconnect if _room is not None else None),
    )
    for label, stop in teardown:
        if stop is None:
            continue
        try:
            await stop()
        except Exception:
            log.exception("Failed to stop %s during shutdown", label)
    _pipeline = _publisher = _room = None
    log.info("=== Server Shutdown ===")


# ── FastAPI app ───────────────────────────────────────────────────────────────

# Title/description are rendered in the OpenAPI docs, so they must be clean UTF-8
# text (the previous version carried mis-decoded "β€”"/"β†’" sequences).
app = FastAPI(
    title="Speech-to-Video — Complete Pipeline",
    description=(
        "User mic → ASR → LLM → Kokoro TTS → MuseTalk → LiveKit avatar video.\n"
        "POST /speak also works for direct text input (bypasses ASR/LLM)."
    ),
    version="3.0.0",
    lifespan=lifespan,
)

# Permissive CORS so a browser frontend on any origin can call this API.
# allow_credentials is left at its default (disabled).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── request models ────────────────────────────────────────────────────────────

class SpeakRequest(BaseModel):
    """Request body for POST /speak: text for the avatar to say."""
    text: str
    # NOTE(review): `voice` and `speed` are accepted, but speak() only forwards
    # `text` to the pipeline, so these two fields currently have no effect.
    voice: Optional[str] = None
    speed: Optional[float] = None

class TokenRequest(BaseModel):
    """Request body for the token endpoints: target room and participant identity."""
    room_name: str = LIVEKIT_ROOM_NAME  # same room /connect joins the backend agent to
    identity: str = "user"  # LiveKit participant identity to embed in the JWT


# ── /health ───────────────────────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Liveness probe: are all models loaded, and is a session currently running?"""
    models_ready = not any(
        model is None for model in (_musetalk_bundle, _tts, _asr, _llm)
    )
    session_live = (
        getattr(_pipeline, "_running", False) if _pipeline is not None else False
    )
    return {
        "status": "ok",
        "models_loaded": models_ready,
        "pipeline_active": session_live,
    }

@app.get("/status")
async def status():
    """Detailed status: per-model load flags, session state, avatar, device and VRAM."""
    gib = 1024**3
    if torch.cuda.is_available():
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    else:
        vram = {}

    loaded = {
        "musetalk": _musetalk_bundle is not None,
        "tts": _tts is not None,
        "asr": _asr is not None,
        "llm": _llm is not None,
    }
    active = _pipeline is not None and getattr(_pipeline, "_running", False)

    return {
        "pipeline": "complete-5-stage",
        "models_loaded": loaded,
        "pipeline_active": active,
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }


# ── /connect ──────────────────────────────────────────────────────────────────

@app.post("/connect")
async def connect():
    """
    Create a session:
    1. Instantiate CompletePipeline (no model loading — models already in VRAM)
    2. Connect backend-agent to LiveKit room
    3. Start publisher + pipeline (pipeline auto-subscribes to mic audio tracks)
    4. Return LiveKit connection info

    If a step fails, whatever was already started is stopped again in reverse
    order (pipeline, publisher, room). A failed attempt therefore does not leave
    a connected agent or running publisher behind while the session globals
    still say "not connected".

    Raises:
        HTTPException 503: models are still loading.
        HTTPException 400: a session is already running.
        HTTPException 500: session setup failed (detail is the error text).
    """
    global _room, _publisher, _pipeline

    if any(m is None for m in [_musetalk_bundle, _tts, _asr, _llm]):
        raise HTTPException(status_code=503, detail="Server still loading models")

    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")

    log.info("Creating new session...")
    t0 = time.monotonic()

    # Async stop callables for each resource brought up so far, in start order;
    # unwound in reverse if a later step fails.
    started = []
    musetalk_worker = None
    pipeline_started = False

    try:
        # Determine actual video dimensions from precomputed avatar frames
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]
        log.info("Avatar frame size: %dx%d", actual_w, actual_h)

        # LiveKit room + JWT
        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))

        # Publisher
        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )

        # MuseTalkWorker wraps the already-loaded bundle — no GPU reload
        musetalk_worker = MuseTalkWorker(_musetalk_bundle)

        # Complete pipeline (5-stage)
        pipeline = CompletePipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
            asr=_asr,
            llm=_llm,
            system_prompt=SYSTEM_PROMPT,
        )

        # Connect → publish tracks → start pipeline
        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        started.append(room.disconnect)
        log.info("Connected to LiveKit room: %s", LIVEKIT_ROOM_NAME)

        await publisher.start()
        started.append(publisher.stop)

        await pipeline.start(room)  # pipeline subscribes to audio here
        started.append(pipeline.stop)
        pipeline_started = True

        # Fast warmup (models already hot)
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

        _room      = room
        _publisher = publisher
        _pipeline  = pipeline

        log.info("/connect done in %.1fs", time.monotonic() - t0)
        return {
            "status":   "connected",
            "room":     LIVEKIT_ROOM_NAME,
            "url":      LIVEKIT_URL,
            "pipeline": "complete-5-stage",
        }

    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        # Undo in reverse start order: pipeline → publisher → room.
        for stop in reversed(started):
            try:
                await stop()
            except Exception:
                log.warning("Cleanup after failed /connect raised", exc_info=True)
        # NOTE(review): once pipeline.start() has succeeded, the worker is assumed
        # to be released via pipeline.stop(); before that, shut it down here (as
        # lifespan does for its temporary worker) so it is not leaked.
        if musetalk_worker is not None and not pipeline_started:
            try:
                musetalk_worker.shutdown()
            except Exception:
                log.warning(
                    "MuseTalkWorker shutdown after failed /connect raised",
                    exc_info=True,
                )
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# ── /disconnect ───────────────────────────────────────────────────────────────

@app.post("/disconnect")
async def disconnect():
    """Tear down the current session while keeping all models loaded.

    The session globals are detached *before* any teardown is awaited. As a
    result, a concurrent /speak or /disconnect sees "Not connected" instead of
    acting on a half-stopped session. Each teardown step (pipeline → publisher
    → room) is attempted even if an earlier one raises; failures are logged.

    Raises:
        HTTPException 400: no session exists.
    """
    global _room, _publisher, _pipeline

    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")

    log.info("Disconnecting session...")

    # Detach first so no other request can reach the objects being stopped.
    pipeline, publisher, room = _pipeline, _publisher, _room
    _room = _publisher = _pipeline = None

    steps = (
        ("pipeline", pipeline.stop),
        ("publisher", publisher.stop if publisher is not None else None),
        ("room", room.disconnect if room is not None else None),
    )
    for label, stop in steps:
        if stop is None:
            continue
        try:
            await stop()
        except Exception:
            log.exception("Failed to stop %s during /disconnect", label)

    # Models intentionally NOT cleared — stay in VRAM for instant reconnect
    log.info("Session disconnected — models remain loaded")
    return {"status": "disconnected"}


# ── /speak  (text bypass — works alongside live voice) ────────────────────────

@app.post("/speak")
async def speak(request: SpeakRequest):
    """
    Directly inject text into the avatar pipeline, bypassing ASR + LLM.
    Useful for testing or mixing programmatic responses with live voice.

    Only ``request.text`` is forwarded; ``voice`` and ``speed`` are accepted by
    the request model but currently ignored.

    Returns:
        ``{"status": "processing", "latency_ms": ...}``, where ``latency_ms``
        measures only how long the ``push_text`` call took.

    Raises:
        HTTPException 400: no running session, or ``text`` is empty/whitespace.
    """
    if _pipeline is None or not getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Not connected")
    # Nothing to speak — don't send blank input down the TTS/lip-sync pipeline.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Empty text")

    t0 = time.monotonic()
    await _pipeline.push_text(request.text)
    return {"status": "processing", "latency_ms": round((time.monotonic() - t0) * 1000, 1)}


# ── /get-token ────────────────────────────────────────────────────────────────

@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
    """Issue a LiveKit JWT for the frontend (viewer) or external clients.

    The token grants join/publish/subscribe on ``request.room_name`` (falls back
    to ``LIVEKIT_ROOM_NAME`` when empty) as ``request.identity`` (falls back to
    ``"frontend-user"`` when empty). A GET request normally carries no body, so
    /livekit-token will typically use the ``TokenRequest`` defaults.

    Raises:
        HTTPException 400: the requested identity is "backend-agent", which
            /connect uses for the server-side avatar participant. LiveKit
            disconnects an existing participant when another joins with the same
            identity, so issuing it would let any client evict the avatar.
    """
    room     = request.room_name or LIVEKIT_ROOM_NAME
    identity = request.identity  or "frontend-user"

    if identity == "backend-agent":
        raise HTTPException(
            status_code=400, detail="Identity 'backend-agent' is reserved"
        )

    token = (
        lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        .with_identity(identity)
        .with_name(identity)
    )
    token.with_grants(lk_api.VideoGrants(
        room_join=True,
        room=room,
        can_publish=True,
        can_subscribe=True,
    ))
    return {"token": token.to_jwt(), "url": LIVEKIT_URL, "room": room}


# ── entry point ───────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Serve the in-process app object. Uvicorn's reload needs an import string,
    # so reload stays off here; use the uvicorn CLI form from the module docs for it.
    uvicorn.run(
        app,
        host=HOST,
        port=PORT,
        log_level="info",
        reload=False,
    )