Spaces:
Sleeping
Sleeping
Commit
·
d909bb3
1
Parent(s):
8eee24d
Fix model initialization with lifespan handler
Browse files- pocket_tts/main.py +19 -9
pocket_tts/main.py
CHANGED
|
@@ -3,6 +3,7 @@ import logging
|
|
| 3 |
import os
|
| 4 |
import tempfile
|
| 5 |
import threading
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from queue import Queue
|
| 8 |
|
|
@@ -42,8 +43,24 @@ cli_app = typer.Typer(
|
|
| 42 |
tts_model = None
|
| 43 |
global_model_state = None
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
web_app = FastAPI(
|
| 46 |
-
title="Kyutai Pocket TTS API",
|
|
|
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
web_app.add_middleware(
|
| 49 |
CORSMiddleware,
|
|
@@ -232,14 +249,7 @@ def serve(
|
|
| 232 |
reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
|
| 233 |
):
|
| 234 |
"""Start the FastAPI server."""
|
| 235 |
-
|
| 236 |
-
global tts_model, global_model_state
|
| 237 |
-
tts_model = TTSModel.load_model(DEFAULT_VARIANT)
|
| 238 |
-
|
| 239 |
-
# Pre-load the voice prompt
|
| 240 |
-
global_model_state = tts_model.get_state_for_audio_prompt(voice)
|
| 241 |
-
logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
|
| 242 |
-
|
| 243 |
uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
|
| 244 |
|
| 245 |
|
|
|
|
| 3 |
import os
|
| 4 |
import tempfile
|
| 5 |
import threading
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
from pathlib import Path
|
| 8 |
from queue import Queue
|
| 9 |
|
|
|
|
| 43 |
tts_model = None
|
| 44 |
global_model_state = None
|
| 45 |
|
| 46 |
+
|
| 47 |
+
@asynccontextmanager
|
| 48 |
+
async def lifespan(app: FastAPI):
|
| 49 |
+
global tts_model, global_model_state
|
| 50 |
+
logger.info("Loading TTS Model in lifespan...")
|
| 51 |
+
tts_model = TTSModel.load_model(DEFAULT_VARIANT)
|
| 52 |
+
global_model_state = tts_model.get_state_for_audio_prompt(DEFAULT_AUDIO_PROMPT)
|
| 53 |
+
logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
|
| 54 |
+
yield
|
| 55 |
+
# Clean up if needed (optional)
|
| 56 |
+
logger.info("Shutting down TTS Model...")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
web_app = FastAPI(
|
| 60 |
+
title="Kyutai Pocket TTS API",
|
| 61 |
+
description="Text-to-Speech generation API",
|
| 62 |
+
version="1.0.0",
|
| 63 |
+
lifespan=lifespan,
|
| 64 |
)
|
| 65 |
web_app.add_middleware(
|
| 66 |
CORSMiddleware,
|
|
|
|
| 249 |
reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
|
| 250 |
):
|
| 251 |
"""Start the FastAPI server."""
|
| 252 |
+
# Model loading is now handled by the lifespan context manager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
|
| 254 |
|
| 255 |
|