dynamite-072 committed on
Commit
d909bb3
·
1 Parent(s): 8eee24d

Fix model initialization with lifespan handler

Browse files
Files changed (1) hide show
  1. pocket_tts/main.py +19 -9
pocket_tts/main.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import os
4
  import tempfile
5
  import threading
 
6
  from pathlib import Path
7
  from queue import Queue
8
 
@@ -42,8 +43,24 @@ cli_app = typer.Typer(
42
  tts_model = None
43
  global_model_state = None
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  web_app = FastAPI(
46
- title="Kyutai Pocket TTS API", description="Text-to-Speech generation API", version="1.0.0"
 
 
 
47
  )
48
  web_app.add_middleware(
49
  CORSMiddleware,
@@ -232,14 +249,7 @@ def serve(
232
  reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
233
  ):
234
  """Start the FastAPI server."""
235
-
236
- global tts_model, global_model_state
237
- tts_model = TTSModel.load_model(DEFAULT_VARIANT)
238
-
239
- # Pre-load the voice prompt
240
- global_model_state = tts_model.get_state_for_audio_prompt(voice)
241
- logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
242
-
243
  uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
244
 
245
 
 
3
  import os
4
  import tempfile
5
  import threading
6
+ from contextlib import asynccontextmanager
7
  from pathlib import Path
8
  from queue import Queue
9
 
 
43
  tts_model = None
44
  global_model_state = None
45
 
46
+
47
+ @asynccontextmanager
48
+ async def lifespan(app: FastAPI):
49
+ global tts_model, global_model_state
50
+ logger.info("Loading TTS Model in lifespan...")
51
+ tts_model = TTSModel.load_model(DEFAULT_VARIANT)
52
+ global_model_state = tts_model.get_state_for_audio_prompt(DEFAULT_AUDIO_PROMPT)
53
+ logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
54
+ yield
55
+ # Clean up if needed (optional)
56
+ logger.info("Shutting down TTS Model...")
57
+
58
+
59
  web_app = FastAPI(
60
+ title="Kyutai Pocket TTS API",
61
+ description="Text-to-Speech generation API",
62
+ version="1.0.0",
63
+ lifespan=lifespan,
64
  )
65
  web_app.add_middleware(
66
  CORSMiddleware,
 
249
  reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
250
  ):
251
  """Start the FastAPI server."""
252
+ # Model loading is now handled by the lifespan context manager
 
 
 
 
 
 
 
253
  uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
254
 
255