File size: 12,059 Bytes
249e06d
9400b83
 
 
 
 
 
 
 
 
 
 
 
 
249e06d
 
 
 
 
 
 
 
 
 
 
9400b83
249e06d
9400b83
 
 
249e06d
9400b83
249e06d
 
 
9400b83
 
249e06d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9400b83
249e06d
 
 
 
9400b83
249e06d
 
 
 
 
 
 
 
9400b83
 
 
249e06d
9400b83
 
 
 
 
 
 
249e06d
 
 
9400b83
 
 
249e06d
9400b83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249e06d
9400b83
 
 
 
249e06d
 
 
 
 
 
 
 
 
9400b83
 
249e06d
9400b83
 
 
249e06d
 
 
 
 
 
 
 
 
 
 
9400b83
 
249e06d
 
 
 
 
 
9400b83
249e06d
 
 
9400b83
249e06d
 
 
 
 
9400b83
 
249e06d
 
 
 
 
 
 
 
9400b83
249e06d
 
9400b83
 
 
249e06d
 
 
 
 
 
9400b83
 
249e06d
 
 
9400b83
 
 
 
 
249e06d
9400b83
249e06d
9400b83
 
249e06d
9400b83
 
 
249e06d
9400b83
 
 
 
 
249e06d
 
 
 
 
 
 
 
 
 
 
 
 
9400b83
 
 
 
249e06d
9400b83
249e06d
 
9400b83
249e06d
9400b83
 
 
 
249e06d
 
9400b83
 
 
 
 
 
 
 
 
249e06d
9400b83
 
 
 
 
 
 
 
249e06d
 
9400b83
 
249e06d
 
 
9400b83
249e06d
 
9400b83
249e06d
9400b83
249e06d
 
 
 
 
 
9400b83
 
 
 
 
249e06d
 
 
9400b83
 
249e06d
 
9400b83
249e06d
9400b83
 
249e06d
9400b83
 
249e06d
9400b83
249e06d
 
 
 
9400b83
 
 
 
 
 
 
 
249e06d
 
 
 
 
 
9400b83
249e06d
 
9400b83
249e06d
 
9400b83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""
Speech-to-Video Server  (api/ — warm-load version)
====================================================
Models are loaded ONCE at server startup (lifespan), not at /connect.
Every /connect (including the first) therefore skips model loading and only
builds the per-session objects and runs a short warm-up.

Model loading split:
  lifespan  → MuseTalk bundle + Kokoro TTS + UNet warmup (stay in VRAM)
  /connect  → Room, Publisher, MuseTalkWorker, Pipeline (per-session)
  /disconnect → session objects torn down; models stay loaded

Run:
  cd backend && python api/server.py
  # or: uvicorn api.server:app --host 0.0.0.0 --port 8767
"""
from __future__ import annotations

import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

# ── path setup ────────────────────────────────────────────────────────────────
# Make `backend/` and the project root importable no matter which working
# directory the server is launched from (the anchor path is absolute).
# Each directory is pushed to the FRONT of sys.path, so when both are newly
# added the project root (inserted second) is searched before backend/.
_current_file = Path(__file__).resolve()
_api_dir = _current_file.parent        # backend/api/
_backend_dir = _api_dir.parent         # backend/
_project_dir = _backend_dir.parent     # speech_to_video/

for p in (_backend_dir, _project_dir):
    if str(p) in sys.path:
        continue  # already importable; avoid duplicate entries
    sys.path.insert(0, str(p))

# ── imports ───────────────────────────────────────────────────────────────────
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from livekit import rtc
from livekit import api as lk_api

from config import (
    HOST,
    PORT,
    LIVEKIT_URL,
    LIVEKIT_API_KEY,
    LIVEKIT_API_SECRET,
    LIVEKIT_ROOM_NAME,
    VIDEO_FPS,
    DEFAULT_AVATAR,
    DEVICE,
)
from tts.kokoro_tts import KokoroTTS
from musetalk.worker import load_musetalk_models, MuseTalkWorker, MuseTalkBundle
from publisher.livekit_publisher import AVPublisher
from api.pipeline import StreamingPipeline

import torch
# Let float32 matmuls use faster reduced-precision internal math (e.g. TF32)
# where the hardware supports it — a small accuracy trade for throughput.
torch.set_float32_matmul_precision("high")
# Make TorchDynamo (torch.compile) fall back to eager execution instead of
# raising when compilation fails. NOTE(review): nothing in this file calls
# torch.compile — presumably it is used inside the MuseTalk worker; confirm.
# This is a private (torch._dynamo) setting and may change between releases.
torch._dynamo.config.suppress_errors = True

# Configure the root logger once for the whole process (INFO and above,
# timestamped), then grab this module's own logger.
logging.basicConfig(
    format="%(asctime)s  %(levelname)-7s  %(name)s  %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# ── global model state (loaded once, lives for server lifetime) ───────────────
# Assigned in lifespan() during startup and never cleared afterwards (the
# shutdown path and /disconnect leave them alone), so handlers only read them.
_musetalk_bundle: Optional[MuseTalkBundle] = None
_tts: Optional[KokoroTTS] = None

# ── session state (created/destroyed on connect/disconnect) ──────────────────
# One LiveKit session at a time: /connect assigns all three together once the
# session is fully up, /disconnect resets all three to None.
_pipeline:  Optional[StreamingPipeline] = None
_room:      Optional[rtc.Room]          = None
_publisher: Optional[AVPublisher]       = None


# ── lifespan: load models once at startup ────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models once at startup and tear down any live session at shutdown.

    Startup (runs before the server accepts requests):
      1. Load the MuseTalk bundle (VAE + UNet + Whisper + avatar latents).
      2. Load Kokoro TTS.
      3. Warm up the UNet with one batch on silent audio, then TTS with a
         short phrase.

    Shutdown: stop the current pipeline, publisher and LiveKit room, if any.
    Each teardown step is best-effort — a failure is logged and the remaining
    steps still run. The loaded models are not explicitly released.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global _musetalk_bundle, _tts
    # Session globals are only read here (at shutdown); declared up front so
    # every global declaration sits at the top of the function.
    global _pipeline, _room, _publisher

    t_start = time.monotonic()
    log.info("=== Speech-to-Video Server Starting ===")
    log.info("Device: %s  Avatar: %s", DEVICE, DEFAULT_AVATAR)

    # 1. Load MuseTalk (VAE + UNet + Whisper + avatar latents) off the event loop.
    log.info("Loading MuseTalk models...")
    _musetalk_bundle = await asyncio.to_thread(
        load_musetalk_models, DEFAULT_AVATAR, DEVICE
    )
    log.info("MuseTalk loaded  (%.1fs)", time.monotonic() - t_start)

    # 2. Load Kokoro TTS, also off the event loop.
    log.info("Loading Kokoro TTS...")
    _tts = await asyncio.to_thread(KokoroTTS)
    log.info("Kokoro TTS loaded")

    # 3. UNet warm-up — one batch on 0.32 s of silence to prime GPU caches.
    #    24_000 is presumably the TTS output sample rate — TODO confirm.
    worker_tmp = MuseTalkWorker(_musetalk_bundle)
    try:
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await worker_tmp.extract_features(dummy_audio)
        t0 = time.monotonic()
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await worker_tmp.generate_batch(feats, 0, n)
        log.info("UNet warm-up done  (%.1fs)", time.monotonic() - t0)
    finally:
        # Release the temporary worker even if the warm-up fails.
        worker_tmp.shutdown()

    # Synchronous call on the event loop; acceptable because no requests are
    # served until startup completes.
    _tts.synthesize_full("Hello.")
    log.info("TTS warm-up done")

    log.info("=== Server ready in %.1fs — waiting for /connect (port %d) ===",
             time.monotonic() - t_start, PORT)

    yield  # ── server running ────────────────────────────────────────────────

    # ── shutdown ──────────────────────────────────────────────────────────────
    teardown = (
        ("pipeline", _pipeline.stop if _pipeline else None),
        ("publisher", _publisher.stop if _publisher else None),
        ("room", _room.disconnect if _room else None),
    )
    for label, step in teardown:
        if step is None:
            continue
        try:
            await step()
        except Exception:
            # Keep going so the remaining resources are still released.
            log.exception("Failed to stop %s during shutdown", label)
    log.info("=== Server Shutdown ===")


# ── FastAPI app ───────────────────────────────────────────────────────────────

# Application object; models are loaded by `lifespan` before requests are served.
app = FastAPI(
    title="Speech-to-Video (api — 3-queue)",
    description="Text → Kokoro TTS → Whisper → MuseTalk → LiveKit",
    version="2.0.0",
    lifespan=lifespan,
)

# NOTE(review): wide-open CORS (any origin, method and header). Fine for a
# local demo; restrict allow_origins before exposing this server publicly,
# especially since the token endpoints issue LiveKit tokens without auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── request models ────────────────────────────────────────────────────────────

class SpeakRequest(BaseModel):
    """JSON body of POST /speak."""
    # Text to synthesize and lip-sync; the only field /speak hands to the pipeline.
    text: str
    # Optional; presumably a TTS voice override. NOTE(review): /speak does not
    # forward this to the pipeline — confirm whether it should.
    voice: Optional[str] = None
    # Optional; presumably a speaking-rate override. Also not forwarded by /speak.
    speed: Optional[float] = None

class TokenRequest(BaseModel):
    """Parameters for the LiveKit token endpoints (/get-token, /livekit-token)."""
    # Room the issued token is scoped to; defaults to the configured room.
    room_name: str = LIVEKIT_ROOM_NAME
    # Participant identity, also used as the display name in the token.
    identity: str = "user"


# ── /health and /status ───────────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Lightweight liveness probe.

    Reports whether both models (MuseTalk bundle and Kokoro TTS) are loaded
    and whether a session pipeline exists and reports itself as running.
    """
    models_ready = not (_musetalk_bundle is None or _tts is None)
    if _pipeline is None:
        session_active = False
    else:
        session_active = getattr(_pipeline, "_running", False)
    return {
        "status": "ok",
        "models_loaded": models_ready,
        "pipeline_active": session_active,
    }

@app.get("/status")
async def status():
    """Report model/session state plus CUDA memory usage.

    ``models_loaded`` uses the same definition as ``/health`` (MuseTalk bundle
    AND Kokoro TTS both loaded) so the two endpoints never disagree.
    ``vram`` is empty when CUDA is unavailable; otherwise it holds
    ``torch.cuda.memory_allocated()`` / ``memory_reserved()`` for the current
    CUDA device in GiB, rounded to 2 decimals.
    """
    vram = {}
    if torch.cuda.is_available():
        vram = {
            "allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
            "reserved_gb":  round(torch.cuda.memory_reserved()  / 1024**3, 2),
        }
    return {
        "pipeline": "api-3-queue",
        # Previously checked only the MuseTalk bundle; now matches /health.
        "models_loaded": _musetalk_bundle is not None and _tts is not None,
        "pipeline_active": _pipeline is not None and getattr(_pipeline, "_running", False),
        "avatar": DEFAULT_AVATAR,
        "device": DEVICE,
        "vram": vram,
    }


# ── /connect ──────────────────────────────────────────────────────────────────

@app.post("/connect")
async def connect():
    """Start a LiveKit session that reuses the already-loaded models.

    Builds the per-session objects (Room, AVPublisher, MuseTalkWorker,
    StreamingPipeline), joins the room as ``backend-agent``, starts the
    publisher and pipeline, runs a short UNet warm-up, and only then stores
    the objects in the module-level session globals.

    If a previous session is still recorded but its pipeline is no longer
    running, it is torn down first so its room connection and publisher are
    not orphaned. If any step of the new session fails, everything already
    started is rolled back before the error is reported.

    Returns:
        ``{"status": "connected", "room": ..., "url": ...}``.

    Raises:
        HTTPException(503): models are not loaded yet.
        HTTPException(400): a session with a running pipeline already exists.
        HTTPException(500): the session could not be established.
    """
    global _room, _publisher, _pipeline

    if _musetalk_bundle is None or _tts is None:
        raise HTTPException(status_code=503, detail="Server still loading models")

    if _pipeline is not None and getattr(_pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Already connected")
    # NOTE(review): the check above and the global assignment at the end are
    # separated by several awaits, so two concurrent /connect calls can both
    # pass; an asyncio.Lock shared with /disconnect would close that gap.

    async def _best_effort(label, step):
        """Await ``step()``; log (don't raise) failures so later cleanup runs."""
        try:
            await step()
        except Exception:
            log.warning("Cleanup of %s failed", label, exc_info=True)

    # A session is still recorded but its pipeline is no longer running (e.g.
    # it stopped without /disconnect). Tear it down before the globals are
    # overwritten, otherwise its room connection and publisher would leak.
    if _pipeline is not None or _publisher is not None or _room is not None:
        log.warning("Tearing down stale session before reconnecting")
        if _pipeline is not None:
            await _best_effort("stale pipeline", _pipeline.stop)
        if _publisher is not None:
            await _best_effort("stale publisher", _publisher.stop)
        if _room is not None:
            await _best_effort("stale room", _room.disconnect)
        _room = _publisher = _pipeline = None

    log.info("Connecting to LiveKit room...")
    t0 = time.monotonic()

    musetalk_worker = None
    # (label, async teardown) pairs registered *before* each resource is
    # started, so a half-started resource is cleaned up too; unwound in
    # reverse order on failure.
    rollback = []

    try:
        first_frame = _musetalk_bundle.avatar_assets.frame_list[0]
        actual_h, actual_w = first_frame.shape[:2]

        room = rtc.Room()
        token = (
            lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            .with_identity("backend-agent")
            .with_name("Speech-to-Video Agent")
        )
        token.with_grants(lk_api.VideoGrants(
            room_join=True,
            room=LIVEKIT_ROOM_NAME,
            can_publish=True,
            can_subscribe=True,
        ))

        publisher = AVPublisher(
            room,
            video_width=actual_w,
            video_height=actual_h,
            video_fps=VIDEO_FPS,
        )

        # MuseTalkWorker wraps the already-loaded bundle — no model reload.
        musetalk_worker = MuseTalkWorker(_musetalk_bundle)

        pipeline = StreamingPipeline(
            tts=_tts,
            musetalk=musetalk_worker,
            publisher=publisher,
            avatar_assets=_musetalk_bundle.avatar_assets,
        )

        rollback.append(("room", room.disconnect))
        await room.connect(url=LIVEKIT_URL, token=token.to_jwt())
        log.info("Connected to LiveKit: %s", LIVEKIT_ROOM_NAME)

        rollback.append(("publisher", publisher.stop))
        await publisher.start()
        rollback.append(("pipeline", pipeline.stop))
        await pipeline.start()

        # Fast warm-up (models already hot in VRAM): one batch on 0.32 s of
        # silence (24_000 presumably the TTS sample rate — TODO confirm).
        dummy_audio = np.zeros(int(0.32 * 24_000), dtype=np.float32)
        feats, _ = await musetalk_worker.extract_features(dummy_audio)
        n = min(8, len(_musetalk_bundle.avatar_assets.frame_list))
        await musetalk_worker.generate_batch(feats, 0, n)
        log.info("Session warm-up done")

    except Exception as exc:
        log.error("Connection failed: %s", exc, exc_info=True)
        for label, step in reversed(rollback):
            await _best_effort(label, step)
        if musetalk_worker is not None:
            # NOTE(review): if StreamingPipeline.stop() already shuts the
            # worker down this is a second call; failures are only logged.
            try:
                musetalk_worker.shutdown()
            except Exception:
                log.warning("Cleanup of MuseTalk worker failed", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    _room = room
    _publisher = publisher
    _pipeline = pipeline

    log.info("/connect done in %.1fs", time.monotonic() - t0)
    return {"status": "connected", "room": LIVEKIT_ROOM_NAME, "url": LIVEKIT_URL}


# ── /disconnect ───────────────────────────────────────────────────────────────

@app.post("/disconnect")
async def disconnect():
    """Tear down the current LiveKit session but keep the models loaded.

    The session globals are detached *before* anything is awaited, so the
    server always ends up in the "not connected" state even if a teardown
    step fails, and a concurrent /disconnect or /speak gets "Not connected"
    instead of touching a half-stopped pipeline. Each teardown step is
    best-effort: failures are logged and the remaining steps still run.

    Returns:
        ``{"status": "disconnected"}``.

    Raises:
        HTTPException(400): there is no session to disconnect.
    """
    global _room, _publisher, _pipeline

    if _pipeline is None:
        raise HTTPException(status_code=400, detail="Not connected")

    log.info("Disconnecting...")

    pipeline, publisher, room = _pipeline, _publisher, _room
    _room = _publisher = _pipeline = None

    teardown = (
        ("pipeline", pipeline.stop),
        ("publisher", publisher.stop if publisher else None),
        ("room", room.disconnect if room else None),
    )
    for label, step in teardown:
        if step is None:
            continue
        try:
            await step()
        except Exception:
            log.exception("Failed to stop %s during disconnect", label)

    # NOTE: _musetalk_bundle and _tts are intentionally NOT cleared —
    # models stay in VRAM so the next /connect is instant.
    # NOTE(review): the per-session MuseTalkWorker created in /connect is not
    # stored in a global and is not shut down here; whether
    # StreamingPipeline.stop() releases it is not visible in this file.
    log.info("Disconnected — models remain loaded for next session")
    return {"status": "disconnected"}


# ── /speak ────────────────────────────────────────────────────────────────────

@app.post("/speak")
async def speak(request: SpeakRequest):
    """Queue ``request.text`` for TTS → lip-sync → LiveKit publishing.

    Returns as soon as the text has been handed to the pipeline; the reported
    ``latency_ms`` covers only that hand-off, not synthesis or rendering.

    Raises:
        HTTPException(400): no running session, or the text is empty/blank.
    """
    pipeline = _pipeline
    if pipeline is None or not getattr(pipeline, "_running", False):
        raise HTTPException(status_code=400, detail="Not connected")

    # Blank text has nothing to synthesize; if pushed, any failure would
    # surface (if at all) in a background stage the caller never hears about.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")

    # NOTE(review): request.voice and request.speed are accepted by the model
    # but not forwarded; confirm whether push_text should receive them.
    t0 = time.monotonic()
    await pipeline.push_text(request.text)
    return {"status": "processing", "latency_ms": round((time.monotonic() - t0) * 1000, 1)}


# ── /get-token ────────────────────────────────────────────────────────────────

@app.post("/get-token")
@app.get("/livekit-token")
async def get_token(request: TokenRequest = TokenRequest()):
    """Mint a LiveKit access token for a frontend participant.

    Empty ``room_name`` / ``identity`` fall back to ``LIVEKIT_ROOM_NAME`` and
    ``"frontend-user"``. The token allows joining, publishing and subscribing
    in that room; the identity doubles as the display name.

    NOTE(review): this endpoint is unauthenticated and lets the caller choose
    any room and identity (including the agent's own ``backend-agent``);
    restrict it before exposing the server beyond a trusted network.

    Returns:
        ``{"token": <JWT>, "url": LIVEKIT_URL, "room": <room name>}``.
    """
    room_name = request.room_name or LIVEKIT_ROOM_NAME
    participant = request.identity or "frontend-user"

    access_token = lk_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
    access_token = access_token.with_identity(participant)
    access_token = access_token.with_name(participant)

    room_grants = lk_api.VideoGrants(
        room_join=True,
        room=room_name,
        can_publish=True,
        can_subscribe=True,
    )
    access_token.with_grants(room_grants)

    return {"token": access_token.to_jwt(), "url": LIVEKIT_URL, "room": room_name}


# ── entry point ───────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Direct launch (`python api/server.py`), without auto-reload; the module
    # docstring also shows the equivalent `uvicorn api.server:app` command.
    uvicorn.run(
        app,
        host=HOST,
        port=PORT,
        log_level="info",
        reload=False,
    )