Rajhuggingface4253 commited on
Commit
a1eb108
·
verified ·
1 Parent(s): ddd978d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -320
app.py CHANGED
@@ -1,429 +1,243 @@
 
 
1
  import os
2
  import io
3
  import asyncio
4
  import time
5
- import shutil
6
- import numpy as np
7
  import psutil
8
  import soundfile as sf
9
  import subprocess
10
- import tempfile
 
11
  from concurrent.futures import ThreadPoolExecutor
12
- from typing import Optional, Generator
13
  from contextlib import asynccontextmanager
14
  import logging
15
- import aiofiles
 
16
  import torch
17
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
- from pydantic import BaseModel, Field
21
- import re
22
- import hashlib
23
- from functools import lru_cache
24
- # Ensure the cloned neutts-air repository is in the path
25
  import sys
26
  sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
27
  from neuttsair.neutts import NeuTTSAir
28
 
29
- # Configure logging
30
  logging.basicConfig(level=logging.INFO)
31
- logger = logging.getLogger("NeuTTS-API")
32
 
33
- # --- Configuration & Utility Functions ---
 
 
 
34
 
35
- # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
36
- DEVICE = "cpu"
37
- # Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
38
- MAX_WORKERS = 2
39
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
40
  SAMPLE_RATE = 24000
41
- CLEANUP_THRESHOLD = 300 # 1 hour in seconds
42
- TEMP_AUDIO_DIR = "temp_audio"
43
- GENERATED_AUDIO_DIR = "generated_audio"
44
- os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
45
- os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
46
-
47
- class TTSRequestModel(BaseModel):
48
- """Model for non-file inputs to synthesis and streaming."""
49
- text: str = Field(..., min_length=1, max_length=1000)
50
- speed: float = Field(default=1.0, ge=0.5, le=2.0)
51
- output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
52
 
 
53
 
54
  async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
55
- """
56
- Converts uploaded audio to a 24kHz WAV in memory using FFmpeg pipes.
57
- This avoids all intermediate disk I/O for maximum speed.
58
- """
59
  ffmpeg_command = [
60
- "ffmpeg",
61
- "-i", "pipe:0", # Read from stdin
62
- "-f", "wav",
63
- "-ar", str(SAMPLE_RATE),
64
- "-ac", "1",
65
- "-c:a", "pcm_s16le",
66
- "pipe:1" # Write to stdout
67
  ]
68
-
69
- # Start the subprocess with pipes for stdin, stdout, and stderr
70
  proc = await asyncio.create_subprocess_exec(
71
- *ffmpeg_command,
72
- stdin=subprocess.PIPE,
73
- stdout=subprocess.PIPE,
74
- stderr=subprocess.PIPE
75
  )
76
-
77
- # Stream the uploaded file data into ffmpeg's stdin
78
- # and capture the resulting WAV data from its stdout
79
  wav_data, stderr_data = await proc.communicate(input=await upload_file.read())
80
-
81
  if proc.returncode != 0:
82
  error_message = stderr_data.decode()
83
  logger.error(f"In-memory conversion failed: {error_message}")
84
- # Provide the last line of the FFmpeg error to the user
85
- error_detail = error_message.splitlines()[-1] if error_message else "Unknown FFmpeg error."
86
- raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}")
87
-
88
- logger.info("In-memory FFmpeg conversion successful.")
89
- # Return the raw WAV data in a BytesIO buffer, ready for the model
90
  return io.BytesIO(wav_data)
91
- # --- Model Wrapper and Logic ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  class NeuTTSWrapper:
94
- def __init__(self, device: str = "cpu"):
95
- self.tts_model = None
96
- self.device = device
97
  self.load_model()
98
 
99
  def load_model(self):
100
  try:
101
- logger.info(f"Loading NeuTTSAir model on device: {self.device}")
102
- # Ensure we respect the CPU configuration
103
- self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
104
- logger.info("✅ NeuTTSAir model loaded successfully.")
 
 
 
 
 
 
 
 
105
  except Exception as e:
106
- logger.error(f"❌ Model loading failed: {e}")
107
  raise
108
 
109
- def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
110
- """Converts NumPy audio array to streamable bytes in the specified format."""
111
- audio_buffer = io.BytesIO()
112
- try:
113
  sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
114
- except Exception as e:
115
- logger.error(f"Failed to write audio data to format {audio_format}: {e}")
116
- raise
117
- audio_buffer.seek(0)
118
- return audio_buffer.read()
119
-
120
- def _split_text_into_chunks(self, text: str) -> list[str]:
121
- """
122
- Splits text into sentences OR clauses using a robust regex.
123
- This is fast, library-free, and now handles commas.
124
- """
125
- # This regex now finds all sequences of characters that are not a sentence-ending
126
- # or clause-ending punctuation mark, followed by that punctuation.
127
- # The only change is adding ',' to the character sets.
128
- chunks = re.findall(r'[^.,!?]+[.,!?]*', text)
129
- return [c.strip() for c in chunks if c.strip()]
130
-
131
- @lru_cache(maxsize=32)
132
- def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
133
- """
134
- Caches the expensive reference encoding operation using an in-memory LRU cache.
135
- The hash of the audio content is the key.
136
- """
137
- logger.info(f"Cache miss for hash: {audio_content_hash[:10]}... Encoding new reference.")
138
- # The model's encode_reference can take a file-like object (BytesIO)
139
- return self.tts_model.encode_reference(io.BytesIO(audio_bytes))
140
 
141
- def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
142
- """Blocking synthesis using cached reference encoding."""
143
- # 1. Hash the audio bytes to get a cache key
144
- audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
145
-
146
- # 2. Get the encoding from the cache (or create it if new)
147
- ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
148
-
149
- # 3. Infer full text
150
- with torch.no_grad():
151
- audio = self.tts_model.infer(text, ref_s, reference_text)
152
- return audio
153
-
154
- def stream_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
155
- """Sentence-by-Sentence Streaming using cached reference encoding."""
156
- logger.info(f"Starting streaming synthesis for text length: {len(text)}")
157
-
158
- # 1. Hash the audio bytes once
159
- audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
160
-
161
- # 2. Get the reference encoding from cache, once for the whole stream
162
- ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
163
-
164
- # 3. Split text using the new regex method
165
- sentences = self._split_text_into_chunks(text)
166
-
167
- # 4. Stream chunks
168
- for i, sentence in enumerate(sentences):
169
- if not sentence.strip():
170
- continue
171
-
172
- logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
173
-
174
- with torch.no_grad():
175
- audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
176
-
177
- yield self._convert_to_streamable_format(audio_chunk, audio_format)
178
-
179
- logger.info("Streaming synthesis complete.")
180
-
181
- # --- Asynchronous Offloading ---
182
-
183
- async def run_blocking_task_async(func, *args, **kwargs):
184
- """Offloads a blocking function call to the ThreadPoolExecutor."""
185
- loop = asyncio.get_event_loop()
186
- return await loop.run_in_executor(
187
- tts_executor,
188
- lambda: func(*args, **kwargs)
189
- )
190
-
191
- async def save_upload_file_async(upload_file: UploadFile) -> str:
192
- """Asynchronously saves the UploadFile to disk."""
193
- temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{upload_file.filename}")
194
- try:
195
- # Use asyncio to read the file chunks in a non-blocking manner
196
- async with aiofiles.open(temp_filename, 'wb') as out_file:
197
- while content := await upload_file.read(1024 * 1024):
198
- await out_file.write(content)
199
- return temp_filename
200
- except Exception as e:
201
- logger.error(f"Error saving file: {e}")
202
- raise HTTPException(status_code=500, detail="Could not save reference audio file")
203
-
204
- # --- FastAPI Lifespan Manager (Kokoro Feature) ---
205
 
206
  @asynccontextmanager
207
  async def lifespan(app: FastAPI):
208
- """Modern lifespan management: initialize model on startup, shutdown executor."""
209
  try:
210
- app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
211
  except Exception as e:
212
- logger.error(f"Fatal startup error: {e}")
213
- # Terminate the application if the model can't load
214
- tts_executor.shutdown(wait=False)
215
- raise RuntimeError("Model initialization failed.")
216
-
217
- yield # Application serves requests
218
-
219
- # Shutdown
220
  logger.info("Shutting down ThreadPoolExecutor.")
221
- tts_executor.shutdown(wait=False)
222
 
223
- # --- FastAPI Application Setup ---
224
  app = FastAPI(
225
- title="NeuTTS Air Instant Cloning API",
226
- version="2.0.0-PROD-ENHANCED",
227
- docs_url="/docs",
228
  lifespan=lifespan
229
  )
230
-
231
  app.add_middleware(
232
- CORSMiddleware,
233
- allow_origins=["*"],
234
- allow_methods=["*"],
235
- allow_headers=["*"],
236
  )
237
 
238
- # --- New Endpoints and Enhancements ---
239
 
240
  @app.get("/")
241
  async def root():
242
- return {"message": "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"}
243
 
244
  @app.get("/health")
245
  async def health_check():
246
- """Enhanced health check (Kokoro Feature + Original Metrics)"""
247
  mem = psutil.virtual_memory()
248
- disk = psutil.disk_usage('/')
249
-
250
  return {
251
  "status": "healthy",
252
  "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
253
- "device": DEVICE,
254
- "concurrency_limit": MAX_WORKERS,
255
- "memory_usage": {
256
- "total_gb": round(mem.total / (1024**3), 2),
257
- "used_percent": mem.percent
258
- },
259
- "disk_usage": {
260
- "total_gb": round(disk.total / (1024**3), 2),
261
- "used_percent": disk.percent
262
- }
263
  }
264
 
265
- @app.delete("/cleanup")
266
- async def cleanup_files():
267
- """Maintenance endpoint to remove old generated and temporary files."""
268
- await run_blocking_task_async(cleanup_files_blocking)
269
- return {"message": "Cleanup initiated successfully."}
270
-
271
- def cleanup_files_blocking():
272
- """Blocking file cleanup logic (original NeuTTS feature)."""
273
- now = time.time()
274
- deleted_count = 0
275
-
276
- for directory in [GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR]:
277
- for filename in os.listdir(directory):
278
- filepath = os.path.join(directory, filename)
279
- if os.path.isfile(filepath):
280
- try:
281
- # Original cleanup logic: delete if older than CLEANUP_THRESHOLD
282
- if now - os.path.getctime(filepath) > CLEANUP_THRESHOLD:
283
- os.remove(filepath)
284
- deleted_count += 1
285
- except Exception as e:
286
- logger.warning(f"Failed to delete {filepath}: {e}")
287
-
288
- logger.info(f"Cleanup completed: {deleted_count} files removed.")
289
- return deleted_count
290
-
291
-
292
- # --- Core Synthesis Endpoints ---
293
-
294
  @app.post("/synthesize", response_class=Response)
295
  async def text_to_speech(
296
  text: str = Form(...),
297
  reference_text: str = Form(...),
298
- speed: float = Form(1.0, ge=0.5, le=2.0),
299
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
300
- reference_audio: UploadFile = File(...)):
301
- """
302
- Standard blocking TTS endpoint with in-memory processing and caching.
303
- """
304
- if not hasattr(app.state, 'tts_wrapper'):
305
- raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
306
-
307
  start_time = time.time()
308
  try:
309
- # 1. Convert the uploaded file to WAV directly in memory
310
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
311
- ref_audio_bytes = converted_wav_buffer.getvalue()
312
 
313
- # 2. Offload the blocking AI process (now faster with caching)
 
 
 
 
314
  audio_data = await run_blocking_task_async(
315
- app.state.tts_wrapper.generate_speech_blocking,
316
- text,
317
- ref_audio_bytes, # Pass bytes, not a path
318
- reference_text
319
  )
320
-
321
- # 3. Convert to requested output format
322
  audio_bytes = await run_blocking_task_async(
323
- app.state.tts_wrapper._convert_to_streamable_format,
324
- audio_data,
325
- output_format
326
  )
327
-
328
  processing_time = time.time() - start_time
329
- audio_duration = len(audio_data) / SAMPLE_RATE
330
  return Response(
331
  content=audio_bytes,
332
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
333
- headers={
334
- "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
335
- "X-Processing-Time": f"{processing_time:.2f}s",
336
- "X-Audio-Duration": f"{audio_duration:.2f}s"
337
- }
338
  )
339
  except Exception as e:
340
- logger.error(f"Synthesis error: {e}")
341
- if isinstance(e, HTTPException):
342
- raise
343
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
344
 
345
  @app.post("/synthesize/stream")
346
  async def stream_text_to_speech_cloning(
347
- text: str = Form(..., min_length=1, max_length=5000),
348
  reference_text: str = Form(...),
349
- speed: float = Form(1.0, ge=0.5, le=2.0),
350
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
351
- reference_audio: UploadFile = File(...)):
352
- """
353
- Sentence-by-Sentence Streaming using a high-performance, asyncio-native
354
- producer-consumer pipeline.
355
- """
356
- if not hasattr(app.state, 'tts_wrapper'):
357
- raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
 
 
 
 
 
358
 
359
  async def stream_generator():
 
360
  loop = asyncio.get_event_loop()
361
- q = asyncio.Queue(maxsize=2)
362
 
363
- async def producer():
364
  try:
365
- converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
366
- ref_audio_bytes = converted_wav_buffer.getvalue()
367
- audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
368
-
369
- # Use LRU cache like blocking endpoint
370
- ref_s = await loop.run_in_executor(
371
- tts_executor,
372
- app.state.tts_wrapper._get_or_create_reference_encoding,
373
- audio_hash,
374
- ref_audio_bytes
375
- )
376
-
377
- sentences = app.state.tts_wrapper._split_text_into_chunks(text)
378
-
379
- def process_chunk(sentence_text):
380
- with torch.no_grad():
381
- audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
382
- return app.state.tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)
383
-
384
- # Schedule all chunks to be processed in the background.
385
- for sentence in sentences:
386
- task = loop.run_in_executor(tts_executor, process_chunk, sentence)
387
- await q.put(task) # Put the FUTURE, not the result, in the queue.
388
-
389
  except Exception as e:
390
- logger.error(f"Error in producer task: {e}")
391
- await q.put(e)
392
  finally:
393
- await q.put(None) # Signal that all tasks have been scheduled.
394
 
395
- producer_task = asyncio.create_task(producer())
 
396
 
397
- # The CONSUMER's job is to wait for each result and yield it.
398
  while True:
399
- result = await q.get()
400
- if result is None:
401
  break
402
-
403
- if isinstance(result, Exception):
404
- logger.error(f"Terminating stream due to producer error: {result}")
405
- raise result
406
-
407
- # Await the result of the background task
408
- chunk_bytes = await result
409
- yield chunk_bytes
410
-
411
- await producer_task
412
 
413
  return StreamingResponse(
414
  stream_generator(),
415
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
416
- )
417
-
418
- @app.get("/audio/{filename}")
419
- async def get_audio(filename: str):
420
- """Original NeuTTS feature to serve generated audio files."""
421
- file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
422
- if not os.path.exists(file_path):
423
- raise HTTPException(status_code=404, detail="Audio file not found")
424
-
425
- return Response(
426
- content=open(file_path, "rb").read(),
427
- media_type=f"audio/{filename.split('.')[-1]}", # Simple media type detection
428
- headers={"Content-Disposition": f"attachment; filename={filename}"}
429
- )
 
1
+ # app.py
2
+
3
  import os
4
  import io
5
  import asyncio
6
  import time
 
 
7
  import psutil
8
  import soundfile as sf
9
  import subprocess
10
+ import numpy as np
11
+ import librosa # Needed for monkey-patching
12
  from concurrent.futures import ThreadPoolExecutor
 
13
  from contextlib import asynccontextmanager
14
  import logging
15
+ from types import MethodType
16
+
17
  import torch
18
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
19
  from fastapi.responses import Response, StreamingResponse
20
  from fastapi.middleware.cors import CORSMiddleware
21
+
22
+ # This will now work because the Dockerfile clones the repo
23
+ # and we add it to the path
 
 
24
  import sys
25
  sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
26
  from neuttsair.neutts import NeuTTSAir
27
 
28
+ # --- Configuration & Logging ---
29
  logging.basicConfig(level=logging.INFO)
30
+ logger = logging.getLogger("NeuTTS-GGUF-API")
31
 
32
+ # Production-ready configuration via Environment Variables
33
+ BACKBONE_MODEL_PATH = os.getenv("BACKBONE_MODEL_PATH", "/app/models/neutts-air.gguf")
34
+ CODEC_REPO = os.getenv("CODEC_REPO", "neuphonic/neucodec-onnx-decoder") # Using ONNX for performance
35
+ DEVICE = "cpu" # llama-cpp handles its own device (CPU/GPU) management
36
 
37
+ MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))
 
 
 
38
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
39
  SAMPLE_RATE = 24000
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # --- Core Utility Functions ---
42
 
43
  async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
44
+ """Converts uploaded audio to a 16kHz WAV for the encoder, in memory."""
 
 
 
45
  ffmpeg_command = [
46
+ "ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000",
47
+ "-ac", "1", "-c:a", "pcm_s16le", "pipe:1"
 
 
 
 
 
48
  ]
 
 
49
  proc = await asyncio.create_subprocess_exec(
50
+ *ffmpeg_command, stdin=subprocess.PIPE,
51
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE
 
 
52
  )
 
 
 
53
  wav_data, stderr_data = await proc.communicate(input=await upload_file.read())
 
54
  if proc.returncode != 0:
55
  error_message = stderr_data.decode()
56
  logger.error(f"In-memory conversion failed: {error_message}")
57
+ error_detail = error_message.strip().splitlines()[-1]
58
+ raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")
 
 
 
 
59
  return io.BytesIO(wav_data)
60
+
61
+ async def run_blocking_task_async(func, *args, **kwargs):
62
+ """Offloads a blocking function call to the ThreadPoolExecutor."""
63
+ loop = asyncio.get_event_loop()
64
+ return await loop.run_in_executor(tts_executor, lambda: func(*args, **kwargs))
65
+
66
+ # --- Model Wrapper and Professional Integration ---
67
+
68
+ def _encode_reference_from_memory(self, ref_audio: io.BytesIO):
69
+ """
70
+ A replacement for the original encode_reference.
71
+ This version reads from an in-memory BytesIO object instead of a file path,
72
+ which is much faster for our API.
73
+ """
74
+ wav, _ = librosa.load(ref_audio, sr=16000, mono=True)
75
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
76
+ with torch.no_grad():
77
+ ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
78
+ return ref_codes
79
 
80
  class NeuTTSWrapper:
81
+ def __init__(self):
82
+ self.tts_model: NeuTTSAir | None = None
 
83
  self.load_model()
84
 
85
  def load_model(self):
86
  try:
87
+ logger.info(f"Loading NeuTTSAir GGUF model from: {BACKBONE_MODEL_PATH}")
88
+ self.tts_model = NeuTTSAir(
89
+ backbone_repo=BACKBONE_MODEL_PATH,
90
+ codec_repo=CODEC_REPO,
91
+ backbone_device=DEVICE,
92
+ codec_device=DEVICE
93
+ )
94
+ # ** MONKEY-PATCHING **: This is the professional way to adapt the library
95
+ # without changing its source code. We replace its file-based function
96
+ # with our memory-based one.
97
+ self.tts_model.encode_reference = MethodType(_encode_reference_from_memory, self.tts_model)
98
+ logger.info("✅ NeuTTSAir GGUF model loaded and patched successfully.")
99
  except Exception as e:
100
+ logger.error(f"❌ Model loading failed: {e}", exc_info=True)
101
  raise
102
 
103
+ def convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
104
+ """Converts NumPy audio array to bytes in the specified format."""
105
+ with io.BytesIO() as audio_buffer:
 
106
  sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
107
+ return audio_buffer.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # --- FastAPI Application Setup ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  @asynccontextmanager
112
  async def lifespan(app: FastAPI):
113
+ """Initializes the model on startup and shuts down the executor."""
114
  try:
115
+ app.state.tts_wrapper = NeuTTSWrapper()
116
  except Exception as e:
117
+ logger.error(f"Fatal startup error: Model could not be loaded. {e}")
118
+ # Properly handle shutdown if model loading fails
119
+ tts_executor.shutdown(wait=False, cancel_futures=True)
120
+ raise RuntimeError("Model initialization failed. Application cannot start.") from e
121
+ yield
 
 
 
122
  logger.info("Shutting down ThreadPoolExecutor.")
123
+ tts_executor.shutdown(wait=True)
124
 
 
125
  app = FastAPI(
126
+ title="NeuTTS Air GGUF Cloning API",
127
+ version="3.0.0-PROD-GGUF",
 
128
  lifespan=lifespan
129
  )
 
130
  app.add_middleware(
131
+ CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
 
 
 
132
  )
133
 
134
+ # --- API Endpoints ---
135
 
136
  @app.get("/")
137
  async def root():
138
+ return {"message": "NeuTTS Air GGUF API - Ready for High-Speed Voice Cloning"}
139
 
140
  @app.get("/health")
141
  async def health_check():
 
142
  mem = psutil.virtual_memory()
 
 
143
  return {
144
  "status": "healthy",
145
  "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
146
+ "model_type": "GGUF",
147
+ "backbone_path": BACKBONE_MODEL_PATH,
148
+ "codec_repo": CODEC_REPO,
149
+ "memory_usage_percent": mem.percent
 
 
 
 
 
 
150
  }
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  @app.post("/synthesize", response_class=Response)
153
  async def text_to_speech(
154
  text: str = Form(...),
155
  reference_text: str = Form(...),
 
156
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
157
+ reference_audio: UploadFile = File(...)
158
+ ):
159
+ """Standard blocking TTS endpoint optimized for GGUF."""
 
 
 
 
160
  start_time = time.time()
161
  try:
 
162
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
 
163
 
164
+ ref_codes = await run_blocking_task_async(
165
+ app.state.tts_wrapper.tts_model.encode_reference,
166
+ converted_wav_buffer
167
+ )
168
+
169
  audio_data = await run_blocking_task_async(
170
+ app.state.tts_wrapper.tts_model.infer,
171
+ text, ref_codes, reference_text
 
 
172
  )
173
+
 
174
  audio_bytes = await run_blocking_task_async(
175
+ app.state.tts_wrapper.convert_to_streamable_format,
176
+ audio_data, output_format
 
177
  )
178
+
179
  processing_time = time.time() - start_time
 
180
  return Response(
181
  content=audio_bytes,
182
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
183
+ headers={"X-Processing-Time": f"{processing_time:.2f}s"}
 
 
 
 
184
  )
185
  except Exception as e:
186
+ logger.error(f"Synthesis error: {e}", exc_info=True)
187
+ detail = str(e) if isinstance(e, HTTPException) else "An internal error occurred during synthesis."
188
+ raise HTTPException(status_code=500, detail=detail)
 
189
 
190
  @app.post("/synthesize/stream")
191
  async def stream_text_to_speech_cloning(
192
+ text: str = Form(..., min_length=1),
193
  reference_text: str = Form(...),
 
194
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
195
+ reference_audio: UploadFile = File(...)
196
+ ):
197
+ """High-performance, sentence-by-sentence streaming using the GGUF backend."""
198
+ try:
199
+ converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
200
+ ref_codes = await run_blocking_task_async(
201
+ app.state.tts_wrapper.tts_model.encode_reference,
202
+ converted_wav_buffer
203
+ )
204
+ except Exception as e:
205
+ logger.error(f"Error during pre-processing for stream: {e}", exc_info=True)
206
+ raise HTTPException(status_code=500, detail="Failed to prepare reference audio for streaming.")
207
 
208
  async def stream_generator():
209
+ # The model's infer_stream is a blocking generator. We must run it in the executor.
210
  loop = asyncio.get_event_loop()
211
+ queue = asyncio.Queue()
212
 
213
+ def producer():
214
  try:
215
+ # This loop will block in the thread, but not the main event loop
216
+ for audio_chunk in app.state.tts_wrapper.tts_model.infer_stream(text, ref_codes, reference_text):
217
+ # Convert chunk to the desired output format in the same thread
218
+ chunk_bytes = app.state.tts_wrapper.convert_to_streamable_format(audio_chunk, output_format)
219
+ # Put the result into the thread-safe asyncio queue
220
+ loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  except Exception as e:
222
+ logger.error(f"Error in streaming producer thread: {e}", exc_info=True)
223
+ loop.call_soon_threadsafe(queue.put_nowait, e)
224
  finally:
225
+ loop.call_soon_threadsafe(queue.put_nowait, None) # Signal end of stream
226
 
227
+ # Start the blocking producer in the thread pool
228
+ producer_task = loop.run_in_executor(tts_executor, producer)
229
 
230
+ # The consumer runs in the main async event loop
231
  while True:
232
+ item = await queue.get()
233
+ if item is None:
234
  break
235
+ if isinstance(item, Exception):
236
+ raise item
237
+ yield item
238
+ await producer_task # Ensure the producer finishes cleanly
 
 
 
 
 
 
239
 
240
  return StreamingResponse(
241
  stream_generator(),
242
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
243
+ )