Rajhuggingface4253 committed on
Commit
e3fd3e2
·
verified ·
1 Parent(s): 829b125

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -257
app.py CHANGED
@@ -9,15 +9,19 @@ import soundfile as sf
9
  import subprocess
10
  import tempfile
11
  from concurrent.futures import ThreadPoolExecutor
12
- from typing import Optional, Generator
13
  from contextlib import asynccontextmanager
14
  import logging
15
- import aiofiles
16
  import torch
17
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from pydantic import BaseModel, Field
 
 
 
 
21
 
22
  # Ensure the cloned neutts-air repository is in the path
23
  import sys
@@ -25,106 +29,102 @@ sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
25
  from neuttsair.neutts import NeuTTSAir
26
 
27
  # Configure logging
28
- logging.basicConfig(level=logging.INFO)
 
 
 
29
  logger = logging.getLogger("NeuTTS-API")
30
 
31
- # --- Configuration & Utility Functions ---
32
-
33
- # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
34
- DEVICE = "cpu"
35
- # Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
36
  MAX_WORKERS = 2
37
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
38
  SAMPLE_RATE = 24000
39
- CLEANUP_THRESHOLD = 3600 # 1 hour in seconds
40
  TEMP_AUDIO_DIR = "temp_audio"
41
  GENERATED_AUDIO_DIR = "generated_audio"
42
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
43
  os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
44
 
 
45
  class TTSRequestModel(BaseModel):
46
- """Model for non-file inputs to synthesis and streaming."""
47
  text: str = Field(..., min_length=1, max_length=1000)
48
  speed: float = Field(default=1.0, ge=0.5, le=2.0)
49
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
50
 
51
-
52
- def convert_to_wav_blocking(input_path: str) -> str:
53
- """
54
- NEW FUNCTION: Uses FFmpeg to convert any uploaded audio format (WebM, MP4, etc.)
55
- to a 24kHz, 16-bit PCM WAV file, which is required by soundfile/libsndfile.
56
- This function must run in the ThreadPoolExecutor.
57
- """
58
- # Create a unique temporary filename for the converted WAV file
59
- # We use tempfile.NamedTemporaryFile to safely create a path
60
- # and then delete the file handle so ffmpeg can write to it.
 
 
61
  with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
62
  output_path = tmp.name
63
-
64
- logger.info(f"Converting '{os.path.basename(input_path)}' to WAV (24kHz, mono) at {os.path.basename(output_path)}")
65
-
66
- # FFmpeg command details:
67
- # -y: overwrite output file if it exists
68
- # -i: input file path
69
- # -f wav: output format is WAV
70
- # -ar 24000: set sample rate to 24000 (required by NeuTTS)
71
- # -ac 1: set audio channels to 1 (mono)
72
- # -c:a pcm_s16le: set codec to uncompressed 16-bit PCM (standard WAV)
73
  command = [
74
- "ffmpeg",
75
- "-y",
76
- "-i", input_path,
77
- "-f", "wav",
78
- "-ar", str(SAMPLE_RATE),
79
- "-ac", "1",
80
- "-c:a", "pcm_s16le",
81
- output_path
82
  ]
83
-
84
  try:
85
- # Run the FFmpeg command
86
- # Use a short timeout to prevent runaway processes
87
- result = subprocess.run(command, check=True, capture_output=True, text=True, timeout=30)
88
- logger.info(f"FFmpeg conversion successful.")
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  return output_path
90
- except subprocess.CalledProcessError as e:
91
- logger.error(f"FFmpeg conversion failed: {e.stderr}")
92
- # Clean up the output path if FFmpeg failed to write it
93
- if os.path.exists(output_path):
94
- os.unlink(output_path)
95
- # Provide the last line of the FFmpeg error to the user
96
- error_detail = e.stderr.splitlines()[-1] if e.stderr else "Unknown FFmpeg error."
97
- raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}")
98
- except subprocess.TimeoutExpired:
99
- logger.error("FFmpeg conversion timed out.")
100
  if os.path.exists(output_path):
101
  os.unlink(output_path)
102
- raise HTTPException(status_code=504, detail="Audio conversion timed out after 30 seconds.")
103
  except Exception as e:
104
- logger.error(f"General conversion error: {e}")
105
  if os.path.exists(output_path):
106
  os.unlink(output_path)
107
- raise HTTPException(status_code=500, detail="An unexpected error occurred during audio conversion.")
108
- # --- Model Wrapper and Logic ---
109
 
 
110
  class NeuTTSWrapper:
111
  def __init__(self, device: str = "cpu"):
112
  self.tts_model = None
113
  self.device = device
 
114
  self.load_model()
115
 
116
  def load_model(self):
117
  try:
118
  logger.info(f"Loading NeuTTSAir model on device: {self.device}")
119
- # Ensure we respect the CPU configuration
120
  self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
121
- logger.info("✅ NeuTTSAir model loaded successfully.")
122
  except Exception as e:
123
  logger.error(f"❌ Model loading failed: {e}")
124
  raise
125
 
126
  def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
127
- """Converts NumPy audio array to streamable bytes in the specified format."""
128
  audio_buffer = io.BytesIO()
129
  try:
130
  sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
@@ -134,90 +134,100 @@ class NeuTTSWrapper:
134
  audio_buffer.seek(0)
135
  return audio_buffer.read()
136
 
137
- def _split_text_into_chunks(self, text: str) -> list[str]:
138
- """Simple sentence splitting for streaming (can be enhanced with regex)."""
139
- sentences = [s.strip() for s in text.split('.') if s.strip()]
140
- if not sentences:
141
- sentences = [text.strip()]
142
- return sentences
143
-
144
- def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
145
- """Blocking synthesis for standard endpoint."""
146
 
 
 
 
 
 
 
 
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  ref_s = self.tts_model.encode_reference(ref_audio_path)
149
-
150
- # 3. Infer full text
151
  with torch.no_grad():
152
  audio = self.tts_model.infer(text, ref_s, reference_text)
153
  return audio
154
 
155
- def stream_producer(self, queue: asyncio.Queue, text: str, ref_audio_path: str, reference_text: str):
156
- """
157
- [PRODUCER] Runs in a thread, generates audio chunks, and puts them into a queue.
158
- """
159
- try:
160
- logger.info("Starting audio production thread...")
161
- ref_s = self.tts_model.encode_reference(ref_audio_path)
162
- sentences = self._split_text_into_chunks(text)
163
-
164
- for i, sentence in enumerate(sentences):
165
- if not sentence.strip():
166
- continue
167
-
168
- logger.debug(f"Producing chunk {i+1}/{len(sentences)}: '{sentence[:30]}...'")
169
-
170
- with torch.no_grad():
171
- audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
172
-
173
- queue.put_nowait(audio_chunk)
174
-
175
- except Exception as e:
176
- logger.error(f"Error in producer thread: {e}")
177
- queue.put_nowait(e)
178
- finally:
179
- queue.put_nowait(None)
180
-
181
- async def stream_consumer(queue: asyncio.Queue, output_format: str):
182
- """
183
- [CONSUMER] Asynchronously gets items from the queue and yields them to the client.
184
- """
185
- logger.info("Starting audio consumption...")
186
- while True:
187
- # Wait for an item to appear in the queue
188
- item = await queue.get()
189
 
190
- if isinstance(item, Exception):
191
- logger.error(f"Consumer received an error from the producer: {item}")
192
- break
 
 
 
 
 
 
 
 
 
 
193
 
194
- if item is None:
195
- # Sentinel value received, meaning the stream is finished
196
- logger.info("Consumer received end-of-stream signal.")
197
- break
 
 
 
 
 
198
 
199
- # We have a valid audio chunk, convert it to the desired format
200
- audio_bytes = await run_blocking_task_async(
201
- app.state.tts_wrapper._convert_to_streamable_format,
202
- item, # The NumPy array from the queue
203
- output_format
204
- )
205
- yield audio_bytes
206
- # --- Asynchronous Offloading ---
207
 
208
- async def run_blocking_task_async(func, *args, **kwargs):
209
- """Offloads a blocking function call to the ThreadPoolExecutor."""
210
- loop = asyncio.get_event_loop()
211
- return await loop.run_in_executor(
212
- tts_executor,
213
- lambda: func(*args, **kwargs)
214
- )
215
 
 
216
  async def save_upload_file_async(upload_file: UploadFile) -> str:
217
  """Asynchronously saves the UploadFile to disk."""
218
  temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{upload_file.filename}")
219
  try:
220
- # Use asyncio to read the file chunks in a non-blocking manner
221
  async with aiofiles.open(temp_filename, 'wb') as out_file:
222
  while content := await upload_file.read(1024 * 1024):
223
  await out_file.write(content)
@@ -226,30 +236,48 @@ async def save_upload_file_async(upload_file: UploadFile) -> str:
226
  logger.error(f"Error saving file: {e}")
227
  raise HTTPException(status_code=500, detail="Could not save reference audio file")
228
 
229
- # --- FastAPI Lifespan Manager (Kokoro Feature) ---
 
 
 
 
 
 
 
230
 
 
 
 
 
 
 
 
 
 
 
231
  @asynccontextmanager
232
  async def lifespan(app: FastAPI):
233
- """Modern lifespan management: initialize model on startup, shutdown executor."""
234
  try:
235
  app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
 
 
 
236
  except Exception as e:
237
  logger.error(f"Fatal startup error: {e}")
238
- # Terminate the application if the model can't load
239
- tts_executor.shutdown(wait=False)
240
- raise RuntimeError("Model initialization failed.")
241
 
242
- yield # Application serves requests
243
 
244
- # Shutdown
245
- logger.info("Shutting down ThreadPoolExecutor.")
246
- tts_executor.shutdown(wait=False)
247
 
248
  # --- FastAPI Application Setup ---
249
  app = FastAPI(
250
- title="NeuTTS Air Instant Cloning API",
251
- version="2.0.0-PROD-ENHANCED",
252
- docs_url="/docs",
253
  lifespan=lifespan
254
  )
255
 
@@ -260,23 +288,25 @@ app.add_middleware(
260
  allow_headers=["*"],
261
  )
262
 
263
- # --- New Endpoints and Enhancements ---
264
-
265
  @app.get("/")
266
  async def root():
267
- return {"message": "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"}
268
 
269
  @app.get("/health")
270
  async def health_check():
271
- """Enhanced health check (Kokoro Feature + Original Metrics)"""
272
  mem = psutil.virtual_memory()
273
  disk = psutil.disk_usage('/')
274
 
 
 
275
  return {
276
  "status": "healthy",
277
  "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
278
  "device": DEVICE,
279
  "concurrency_limit": MAX_WORKERS,
 
280
  "memory_usage": {
281
  "total_gb": round(mem.total / (1024**3), 2),
282
  "used_percent": mem.percent
@@ -287,170 +317,225 @@ async def health_check():
287
  }
288
  }
289
 
290
- @app.delete("/cleanup")
291
- async def cleanup_files():
292
- """Maintenance endpoint to remove old generated and temporary files."""
293
- await run_blocking_task_async(cleanup_files_blocking)
294
- return {"message": "Cleanup initiated successfully."}
295
-
296
- def cleanup_files_blocking():
297
- """Blocking file cleanup logic (original NeuTTS feature)."""
298
- now = time.time()
299
- deleted_count = 0
300
-
301
- for directory in [GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR]:
302
- for filename in os.listdir(directory):
303
- filepath = os.path.join(directory, filename)
304
- if os.path.isfile(filepath):
305
- try:
306
- # Original cleanup logic: delete if older than CLEANUP_THRESHOLD
307
- if now - os.path.getctime(filepath) > CLEANUP_THRESHOLD:
308
- os.remove(filepath)
309
- deleted_count += 1
310
- except Exception as e:
311
- logger.warning(f"Failed to delete {filepath}: {e}")
312
-
313
- logger.info(f"Cleanup completed: {deleted_count} files removed.")
314
- return deleted_count
315
-
316
-
317
- # --- Core Synthesis Endpoints ---
318
-
319
  @app.post("/synthesize", response_class=Response)
320
  async def text_to_speech(
321
  text: str = Form(...),
322
  reference_text: str = Form(...),
323
  speed: float = Form(1.0, ge=0.5, le=2.0),
324
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
325
- reference_audio: UploadFile = File(...)):
 
 
326
  """
327
- Standard blocking TTS endpoint with Multi-Format Output (Kokoro Feature).
328
- Includes FFmpeg conversion for uploaded audio format compatibility.
329
  """
330
  if not hasattr(app.state, 'tts_wrapper'):
331
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
332
 
333
- # 1. Asynchronously save reference audio (original upload)
334
- temp_ref_path = await save_upload_file_async(reference_audio)
335
- converted_wav_path = None # NEW: Initialize for cleanup
336
  start_time = time.time()
337
-
 
 
338
  try:
339
- # 2. **NEW STEP**: Convert the uploaded file (WebM, etc.) to a 24kHz WAV file using FFmpeg
340
- converted_wav_path = await run_blocking_task_async(
341
- convert_to_wav_blocking,
342
- temp_ref_path
343
- )
344
-
345
- # 3. Offload the ENTIRE blocking process (encode + infer) to a thread
346
- audio_data = await run_blocking_task_async(
347
- app.state.tts_wrapper.generate_speech_blocking,
348
- text,
349
- converted_wav_path, # IMPORTANT: Pass the CONVERTED WAV path
350
- reference_text
351
  )
352
-
353
- # 4. Convert to requested format (Blocking, but usually fast)
354
- audio_bytes = await run_blocking_task_async(
 
355
  app.state.tts_wrapper._convert_to_streamable_format,
356
- audio_data,
357
- output_format
358
  )
359
-
360
- # 5. Save to disk (Original NeuTTS requirement)
361
- audio_filename = f"tts_{time.time()}.{output_format}"
362
  final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
363
- await run_blocking_task_async(
364
- lambda: open(final_path, 'wb').write(audio_bytes)
365
- )
366
-
367
  processing_time = time.time() - start_time
368
  audio_duration = len(audio_data) / SAMPLE_RATE
 
369
  return Response(
370
  content=audio_bytes,
371
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
372
  headers={
373
  "Content-Disposition": f"attachment; filename={audio_filename}",
374
  "X-Processing-Time": f"{processing_time:.2f}s",
375
- "X-Audio-Duration": f"{audio_duration:.2f}s"
 
376
  }
377
  )
 
378
  except Exception as e:
379
  logger.error(f"Synthesis error: {e}")
380
- # Reraise HTTPExceptions that may have come from the conversion step
381
  if isinstance(e, HTTPException):
382
- raise
383
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
384
  finally:
385
- # 6. Clean up BOTH the original file AND the converted WAV file
386
- if os.path.exists(temp_ref_path):
387
- os.unlink(temp_ref_path)
388
- if converted_wav_path and os.path.exists(converted_wav_path):
389
- os.unlink(converted_wav_path)
 
 
 
 
 
 
 
390
 
391
  @app.post("/synthesize/stream")
392
- async def stream_text_to_speech_cloning(
393
  text: str = Form(..., min_length=1, max_length=5000),
394
  reference_text: str = Form(...),
 
395
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
396
  reference_audio: UploadFile = File(...)
397
  ):
398
  """
399
- TRUE streaming endpoint using the definitive producer-consumer pattern.
400
  """
401
  if not hasattr(app.state, 'tts_wrapper'):
402
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
403
-
404
- temp_ref_path = await save_upload_file_async(reference_audio)
405
 
406
- async def cleanup_and_run_stream():
407
- """A nested async generator to handle the entire producer-consumer lifecycle and cleanup."""
408
- converted_wav_path = None
409
- queue = asyncio.Queue()
410
- loop = asyncio.get_event_loop()
411
- try:
412
- # Convert the uploaded file to the required WAV format
413
- converted_wav_path = await run_blocking_task_async(convert_to_wav_blocking, temp_ref_path)
414
-
415
- # Start the producer (the model) in a background thread.
416
- # It will start putting audio chunks into the queue.
417
- loop.run_in_executor(
418
- tts_executor,
419
- app.state.tts_wrapper.stream_producer,
420
- queue, text, converted_wav_path, reference_text
421
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
- # Start the consumer, which gets chunks from the queue and yields them to the client.
424
- async for chunk in stream_consumer(queue, output_format):
425
- yield chunk
426
-
427
- finally:
428
- # This block guarantees cleanup after the stream is finished or fails
429
- if os.path.exists(temp_ref_path):
430
- os.unlink(temp_ref_path)
431
- if converted_wav_path and os.path.exists(converted_wav_path):
432
- os.unlink(converted_wav_path)
433
- logger.info("Cleaned up temporary stream files.")
434
-
435
- return StreamingResponse(
436
- cleanup_and_run_stream(),
437
- media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
438
- headers={
439
- "Content-Disposition": "attachment; filename=tts_live_stream.mp3",
440
- "Cache-Control": "no-cache",
441
- "X-Accel-Buffering": "no" # Header to prevent proxy buffering
442
- }
443
- )
444
 
445
  @app.get("/audio/{filename}")
446
  async def get_audio(filename: str):
447
- """Original NeuTTS feature to serve generated audio files."""
448
  file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
449
  if not os.path.exists(file_path):
450
  raise HTTPException(status_code=404, detail="Audio file not found")
451
 
 
 
 
 
452
  return Response(
453
- content=open(file_path, "rb").read(),
454
- media_type=f"audio/{filename.split('.')[-1]}", # Simple media type detection
455
  headers={"Content-Disposition": f"attachment; filename={filename}"}
456
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import subprocess
10
  import tempfile
11
  from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Optional, Generator, AsyncGenerator
13
  from contextlib import asynccontextmanager
14
  import logging
15
+ import aiofiles
16
  import torch
17
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from pydantic import BaseModel, Field
21
+ import uuid
22
+ from dataclasses import dataclass
23
+ from queue import Queue, Empty
24
+ import threading
25
 
26
  # Ensure the cloned neutts-air repository is in the path
27
  import sys
 
29
  from neuttsair.neutts import NeuTTSAir
30
 
31
  # Configure logging
32
+ logging.basicConfig(
33
+ level=logging.INFO,
34
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
35
+ )
36
  logger = logging.getLogger("NeuTTS-API")
37
 
38
+ # --- Configuration & Constants ---
39
+ DEVICE = "cpu"
 
 
 
40
  MAX_WORKERS = 2
41
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
42
  SAMPLE_RATE = 24000
43
+ CLEANUP_THRESHOLD = 300
44
  TEMP_AUDIO_DIR = "temp_audio"
45
  GENERATED_AUDIO_DIR = "generated_audio"
46
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
47
  os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
48
 
49
+ # --- Data Models ---
50
  class TTSRequestModel(BaseModel):
 
51
  text: str = Field(..., min_length=1, max_length=1000)
52
  speed: float = Field(default=1.0, ge=0.5, le=2.0)
53
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
54
 
55
+ @dataclass
56
+ class SynthesisTask:
57
+ task_id: str
58
+ text: str
59
+ reference_audio_path: str
60
+ reference_text: str
61
+ output_format: str
62
+ created_at: float
63
+
64
+ # --- Enhanced Audio Conversion with Async Support ---
65
+ async def convert_to_wav_async(input_path: str) -> str:
66
+ """Asynchronous audio conversion using subprocess with async wrapper."""
67
  with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
68
  output_path = tmp.name
69
+
70
+ logger.info(f"Converting '{os.path.basename(input_path)}' to WAV")
71
+
 
 
 
 
 
 
 
72
  command = [
73
+ "ffmpeg", "-y", "-i", input_path,
74
+ "-f", "wav", "-ar", str(SAMPLE_RATE),
75
+ "-ac", "1", "-c:a", "pcm_s16le", output_path
 
 
 
 
 
76
  ]
77
+
78
  try:
79
+ # Run FFmpeg asynchronously
80
+ process = await asyncio.create_subprocess_exec(
81
+ *command,
82
+ stdout=asyncio.subprocess.PIPE,
83
+ stderr=asyncio.subprocess.PIPE
84
+ )
85
+
86
+ stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=30)
87
+
88
+ if process.returncode != 0:
89
+ error_detail = stderr.decode().splitlines()[-1] if stderr else "Unknown FFmpeg error"
90
+ logger.error(f"FFmpeg conversion failed: {error_detail}")
91
+ if os.path.exists(output_path):
92
+ os.unlink(output_path)
93
+ raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")
94
+
95
+ logger.info("FFmpeg conversion successful")
96
  return output_path
97
+
98
+ except asyncio.TimeoutError:
99
+ logger.error("FFmpeg conversion timed out")
 
 
 
 
 
 
 
100
  if os.path.exists(output_path):
101
  os.unlink(output_path)
102
+ raise HTTPException(status_code=504, detail="Audio conversion timed out")
103
  except Exception as e:
104
+ logger.error(f"Conversion error: {e}")
105
  if os.path.exists(output_path):
106
  os.unlink(output_path)
107
+ raise HTTPException(status_code=500, detail="Unexpected conversion error")
 
108
 
109
+ # --- Enhanced Model Wrapper with Async Streaming ---
110
  class NeuTTSWrapper:
111
  def __init__(self, device: str = "cpu"):
112
  self.tts_model = None
113
  self.device = device
114
+ self._model_lock = asyncio.Lock() # Serializes model access between coroutines (asyncio.Lock is not thread-safe)
115
  self.load_model()
116
 
117
  def load_model(self):
118
  try:
119
  logger.info(f"Loading NeuTTSAir model on device: {self.device}")
 
120
  self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
121
+ logger.info("✅ NeuTTSAir model loaded successfully")
122
  except Exception as e:
123
  logger.error(f"❌ Model loading failed: {e}")
124
  raise
125
 
126
  def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
127
+ """Convert NumPy audio array to streamable bytes."""
128
  audio_buffer = io.BytesIO()
129
  try:
130
  sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
 
134
  audio_buffer.seek(0)
135
  return audio_buffer.read()
136
 
137
+ def _split_text_into_chunks(self, text: str, max_chunk_length: int = 100) -> list[str]:
138
+ """Enhanced text splitting for better streaming chunks."""
139
+ # Greedy word-based splitting: whitespace-separated words are packed into chunks of at most max_chunk_length characters
140
+ sentences = []
141
+ current_sentence = ""
 
 
 
 
142
 
143
+ for word in text.split():
144
+ test_sentence = f"{current_sentence} {word}".strip()
145
+ if len(test_sentence) <= max_chunk_length:
146
+ current_sentence = test_sentence
147
+ else:
148
+ if current_sentence:
149
+ sentences.append(current_sentence)
150
+ current_sentence = word
151
 
152
+ if current_sentence:
153
+ sentences.append(current_sentence)
154
+
155
+ return sentences or [text]
156
+
157
+ async def generate_speech_async(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
158
+ """Asynchronous speech generation with proper locking."""
159
+ async with self._model_lock:
160
+ return await asyncio.get_event_loop().run_in_executor(
161
+ tts_executor,
162
+ self._generate_speech_blocking,
163
+ text, ref_audio_path, reference_text
164
+ )
165
+
166
+ def _generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
167
+ """Blocking speech generation (runs in thread pool)."""
168
  ref_s = self.tts_model.encode_reference(ref_audio_path)
 
 
169
  with torch.no_grad():
170
  audio = self.tts_model.infer(text, ref_s, reference_text)
171
  return audio
172
 
173
+ async def stream_speech_async(
174
+ self,
175
+ text: str,
176
+ ref_audio_path: str,
177
+ reference_text: str,
178
+ audio_format: str
179
+ ) -> AsyncGenerator[bytes, None]:
180
+ """True asynchronous streaming with immediate chunk delivery."""
181
+ logger.info(f"Starting true streaming synthesis for text length: {len(text)}")
182
+
183
+ # Encode reference once (this is the only blocking part we need to do first)
184
+ async with self._model_lock:
185
+ ref_s = await asyncio.get_event_loop().run_in_executor(
186
+ tts_executor,
187
+ self.tts_model.encode_reference,
188
+ ref_audio_path
189
+ )
190
+
191
+ # Split text into chunks for streaming
192
+ sentences = self._split_text_into_chunks(text)
193
+ logger.info(f"Split text into {len(sentences)} chunks for streaming")
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ # Stream each chunk asynchronously
196
+ for i, sentence in enumerate(sentences):
197
+ if not sentence.strip():
198
+ continue
199
+
200
+ logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
201
+
202
+ # Generate this chunk asynchronously
203
+ audio_chunk = await asyncio.get_event_loop().run_in_executor(
204
+ tts_executor,
205
+ self._infer_chunk,
206
+ sentence, ref_s, reference_text
207
+ )
208
 
209
+ # Convert and yield immediately
210
+ chunk_bytes = await asyncio.get_event_loop().run_in_executor(
211
+ tts_executor,
212
+ self._convert_to_streamable_format,
213
+ audio_chunk, audio_format
214
+ )
215
+
216
+ yield chunk_bytes
217
+ logger.debug(f"Yielded chunk {i+1} ({len(chunk_bytes)} bytes)")
218
 
219
+ logger.info("Streaming synthesis complete")
 
 
 
 
 
 
 
220
 
221
+ def _infer_chunk(self, sentence: str, ref_s, reference_text: str) -> np.ndarray:
222
+ """Infer a single chunk (runs in thread pool)."""
223
+ with torch.no_grad():
224
+ return self.tts_model.infer(sentence, ref_s, reference_text)
 
 
 
225
 
226
+ # --- Async Utility Functions ---
227
  async def save_upload_file_async(upload_file: UploadFile) -> str:
228
  """Asynchronously saves the UploadFile to disk."""
229
  temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{upload_file.filename}")
230
  try:
 
231
  async with aiofiles.open(temp_filename, 'wb') as out_file:
232
  while content := await upload_file.read(1024 * 1024):
233
  await out_file.write(content)
 
236
  logger.error(f"Error saving file: {e}")
237
  raise HTTPException(status_code=500, detail="Could not save reference audio file")
238
 
239
+ async def cleanup_file_async(file_path: str):
240
+ """Asynchronously clean up a file."""
241
+ try:
242
+ if os.path.exists(file_path):
243
+ os.unlink(file_path)
244
+ logger.debug(f"Cleaned up file: {file_path}")
245
+ except Exception as e:
246
+ logger.warning(f"Failed to cleanup file {file_path}: {e}")
247
 
248
+ async def scheduled_cleanup_task():
249
+ """Runs the cleanup task periodically in the background."""
250
+ while True:
251
+ await asyncio.sleep(CLEANUP_THRESHOLD) # Wait CLEANUP_THRESHOLD seconds (300 s = 5 minutes) between cleanup runs
252
+ logger.info("Running scheduled cleanup of old audio files...")
253
+ try:
254
+ await cleanup_files_async()
255
+ except Exception as e:
256
+ logger.error(f"Scheduled cleanup task failed: {e}")
257
+ # --- FastAPI Lifespan Manager ---
258
  @asynccontextmanager
259
  async def lifespan(app: FastAPI):
260
+ """Modern lifespan management."""
261
  try:
262
  app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
263
+ app.state.synthesis_tasks = {} # Track active tasks
264
+ asyncio.create_task(scheduled_cleanup_task())
265
+ logger.info("✅ Application startup complete")
266
  except Exception as e:
267
  logger.error(f"Fatal startup error: {e}")
268
+ tts_executor.shutdown(wait=False)
269
+ raise RuntimeError("Model initialization failed")
 
270
 
271
+ yield
272
 
273
+ logger.info("Shutting down ThreadPoolExecutor")
274
+ tts_executor.shutdown(wait=True)
 
275
 
276
  # --- FastAPI Application Setup ---
277
  app = FastAPI(
278
+ title="NeuTTS Air Instant Cloning API - Enhanced",
279
+ version="3.0.0-PROD-STREAMING",
280
+ docs_url="/docs",
281
  lifespan=lifespan
282
  )
283
 
 
288
  allow_headers=["*"],
289
  )
290
 
291
+ # --- Enhanced Endpoints ---
 
292
  @app.get("/")
293
  async def root():
294
+ return {"message": "NeuTTS Air API v3.0 - True Streaming Ready"}
295
 
296
  @app.get("/health")
297
  async def health_check():
298
+ """Enhanced health check with streaming metrics."""
299
  mem = psutil.virtual_memory()
300
  disk = psutil.disk_usage('/')
301
 
302
+ active_tasks = len(getattr(app.state, 'synthesis_tasks', {}))
303
+
304
  return {
305
  "status": "healthy",
306
  "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
307
  "device": DEVICE,
308
  "concurrency_limit": MAX_WORKERS,
309
+ "active_synthesis_tasks": active_tasks,
310
  "memory_usage": {
311
  "total_gb": round(mem.total / (1024**3), 2),
312
  "used_percent": mem.percent
 
317
  }
318
  }
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  @app.post("/synthesize", response_class=Response)
321
  async def text_to_speech(
322
  text: str = Form(...),
323
  reference_text: str = Form(...),
324
  speed: float = Form(1.0, ge=0.5, le=2.0),
325
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
326
+ reference_audio: UploadFile = File(...),
327
+ background_tasks: BackgroundTasks = None
328
+ ):
329
  """
330
+ Enhanced standard TTS endpoint with better async handling.
 
331
  """
332
  if not hasattr(app.state, 'tts_wrapper'):
333
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
334
 
 
 
 
335
  start_time = time.time()
336
+ temp_ref_path = None
337
+ converted_wav_path = None
338
+
339
  try:
340
+ # 1. Save uploaded file
341
+ temp_ref_path = await save_upload_file_async(reference_audio)
342
+
343
+ # 2. Convert to WAV
344
+ converted_wav_path = await convert_to_wav_async(temp_ref_path)
345
+
346
+ # 3. Generate speech asynchronously
347
+ audio_data = await app.state.tts_wrapper.generate_speech_async(
348
+ text, converted_wav_path, reference_text
 
 
 
349
  )
350
+
351
+ # 4. Convert to requested format
352
+ audio_bytes = await asyncio.get_event_loop().run_in_executor(
353
+ tts_executor,
354
  app.state.tts_wrapper._convert_to_streamable_format,
355
+ audio_data, output_format
 
356
  )
357
+
358
+ # 5. Save to disk (optional - can be disabled in production)
359
+ audio_filename = f"tts_{int(time.time())}.{output_format}"
360
  final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
361
+
362
+ async with aiofiles.open(final_path, 'wb') as f:
363
+ await f.write(audio_bytes)
364
+
365
  processing_time = time.time() - start_time
366
  audio_duration = len(audio_data) / SAMPLE_RATE
367
+
368
  return Response(
369
  content=audio_bytes,
370
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
371
  headers={
372
  "Content-Disposition": f"attachment; filename={audio_filename}",
373
  "X-Processing-Time": f"{processing_time:.2f}s",
374
+ "X-Audio-Duration": f"{audio_duration:.2f}s",
375
+ "X-First-Chunk-Time": f"{processing_time:.2f}s" # For comparison
376
  }
377
  )
378
+
379
  except Exception as e:
380
  logger.error(f"Synthesis error: {e}")
 
381
  if isinstance(e, HTTPException):
382
+ raise
383
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
384
  finally:
385
+ # Schedule cleanup in background
386
+ if background_tasks:
387
+ if temp_ref_path:
388
+ background_tasks.add_task(cleanup_file_async, temp_ref_path)
389
+ if converted_wav_path:
390
+ background_tasks.add_task(cleanup_file_async, converted_wav_path)
391
+ else:
392
+ # Fallback synchronous cleanup
393
+ if temp_ref_path and os.path.exists(temp_ref_path):
394
+ os.unlink(temp_ref_path)
395
+ if converted_wav_path and os.path.exists(converted_wav_path):
396
+ os.unlink(converted_wav_path)
397
 
398
@app.post("/synthesize/stream")
async def stream_text_to_speech(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """
    TRUE Streaming Endpoint - delivers audio chunks as they're generated.

    Args:
        text: Text to synthesize (1-5000 characters).
        reference_text: Text handed to the TTS wrapper together with the
            converted reference audio.
        speed: Validated to the 0.5-2.0 range, but not forwarded to the
            streaming synthesis call in this handler.
        output_format: One of ``wav``, ``mp3`` or ``flac``; selects the
            response media type and the download file extension.
        reference_audio: Uploaded reference clip. It is saved to disk and
            converted to WAV before synthesis starts.

    Returns:
        A ``StreamingResponse`` whose body is produced chunk by chunk from
        ``app.state.tts_wrapper.stream_speech_async``.

    Raises:
        HTTPException: 503 when ``app.state`` has no ``tts_wrapper``; 500 when
            saving or converting the reference audio fails.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    temp_ref_path = None
    converted_wav_path = None

    try:
        # 1. Save and convert reference audio
        temp_ref_path = await save_upload_file_async(reference_audio)
        converted_wav_path = await convert_to_wav_async(temp_ref_path)

        # 2. Clean up original file immediately; only the converted WAV is
        #    needed from here on.
        if temp_ref_path and os.path.exists(temp_ref_path):
            await cleanup_file_async(temp_ref_path)
        temp_ref_path = None

        # 3. Create async generator for streaming
        async def generate_audio_stream():
            """Async generator that yields audio chunks as they're produced."""
            try:
                stream_started = time.time()
                chunk_count = 0

                async for chunk_bytes in app.state.tts_wrapper.stream_speech_async(
                    text, converted_wav_path, reference_text, output_format
                ):
                    chunk_count += 1

                    # Log timing for first chunk
                    if chunk_count == 1:
                        first_chunk_latency = time.time() - stream_started
                        logger.info(f"First audio chunk delivered in {first_chunk_latency:.2f}s")

                    yield chunk_bytes

            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                raise
            finally:
                # Clean up converted file when streaming is complete. This
                # runs inside the generator because the handler has already
                # returned by the time the body is being consumed.
                if converted_wav_path and os.path.exists(converted_wav_path):
                    await cleanup_file_async(converted_wav_path)

        # 4. Return streaming response
        return StreamingResponse(
            generate_audio_stream(),
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                # The extension follows the requested format so that it
                # agrees with the media type (previously always ".mp3").
                "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
                "X-Streaming": "true"
            }
        )

    except Exception as e:
        logger.error(f"Streaming setup error: {e}")
        # Cleanup on error
        if temp_ref_path and os.path.exists(temp_ref_path):
            await cleanup_file_async(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            await cleanup_file_async(converted_wav_path)

        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Streaming setup failed: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve generated audio files.

    Args:
        filename: Bare name of a file inside ``GENERATED_AUDIO_DIR``.

    Returns:
        A ``Response`` carrying the file bytes as an attachment.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no
            regular file with that name exists in ``GENERATED_AUDIO_DIR``.
    """
    # The name comes straight from the URL, so refuse anything that could
    # resolve outside GENERATED_AUDIO_DIR (path separators, "." and "..").
    has_separator = "/" in filename or "\\" in filename
    if has_separator or filename in ("", ".", ".."):
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
    # isfile (not exists) so a directory yields 404 instead of failing on open.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    # Use async file reading for better performance
    async with aiofiles.open(file_path, "rb") as f:
        content = await f.read()

    # Map mp3 to "audio/mpeg", matching the synthesis endpoints in this file.
    extension = filename.rsplit(".", 1)[-1].lower()
    media_subtype = "mpeg" if extension == "mp3" else extension

    return Response(
        content=content,
        media_type=f"audio/{media_subtype}",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )
493
+
494
@app.delete("/cleanup")
async def cleanup_files():
    """Run the cleanup sweep and report how many files it removed."""
    removed = await cleanup_files_async()
    summary = f"Cleanup completed: {removed} files removed"
    return {"message": summary}
499
+
500
async def cleanup_files_async():
    """Delete old files from the generated and temporary audio directories.

    A file is removed when the time since its ``os.path.getctime`` value
    exceeds ``CLEANUP_THRESHOLD`` seconds. Failures on individual files are
    logged as warnings and do not stop the sweep.

    Returns:
        The number of files passed to ``cleanup_file_async``.
    """
    sweep_time = time.time()
    removed = 0

    for folder in (GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR):
        if not os.path.exists(folder):
            continue

        for entry_name in os.listdir(folder):
            candidate = os.path.join(folder, entry_name)
            # Only regular files are eligible; skip everything else.
            if not os.path.isfile(candidate):
                continue

            try:
                age_seconds = sweep_time - os.path.getctime(candidate)
                if age_seconds > CLEANUP_THRESHOLD:
                    await cleanup_file_async(candidate)
                    removed += 1
            except Exception as e:
                logger.warning(f"Failed to delete {candidate}: {e}")

    logger.info(f"Cleanup completed: {removed} files removed")
    return removed
521
+
522
+ # Performance monitoring endpoint
523
@app.get("/metrics")
async def get_metrics():
    """Report thread count, pending executor work and resident memory.

    Returns:
        A dict with ``active_threads``, ``executor_queue_size`` (0 when the
        executor has no ``_work_queue`` attribute) and ``memory_usage_mb``.
    """
    thread_total = threading.active_count()

    # _work_queue is a private attribute of the executor, hence the guard.
    if hasattr(tts_executor, '_work_queue'):
        pending_jobs = tts_executor._work_queue.qsize()
    else:
        pending_jobs = 0

    rss_bytes = psutil.Process().memory_info().rss

    return {
        "active_threads": thread_total,
        "executor_queue_size": pending_jobs,
        "memory_usage_mb": rss_bytes / 1024 / 1024,
    }
531
+
532
if __name__ == "__main__":
    import uvicorn

    # Script entry point: serve the app on all interfaces, port 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 1,  # Multiple workers not supported with in-memory model
        "loop": "asyncio",
        "access_log": True,
    }
    uvicorn.run("app:app", **server_options)