Rajhuggingface4253 committed on
Commit
5d68bda
·
verified ·
1 Parent(s): 01b3f2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -152
app.py CHANGED
@@ -18,7 +18,10 @@ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from pydantic import BaseModel, Field
21
-
 
 
 
22
  # Ensure the cloned neutts-air repository is in the path
23
  import sys
24
  sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
@@ -33,7 +36,7 @@ logger = logging.getLogger("NeuTTS-API")
33
  # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
34
  DEVICE = "cpu"
35
  # Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
36
- MAX_WORKERS = 2
37
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
38
  SAMPLE_RATE = 24000
39
  CLEANUP_THRESHOLD = 300 # seconds (5 minutes; the earlier "1 hour" note did not match this value)
@@ -49,62 +52,43 @@ class TTSRequestModel(BaseModel):
49
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
50
 
51
 
52
- def convert_to_wav_blocking(input_path: str) -> str:
53
  """
54
- NEW FUNCTION: Uses FFmpeg to convert any uploaded audio format (WebM, MP4, etc.)
55
- to a 24kHz, 16-bit PCM WAV file, which is required by soundfile/libsndfile.
56
- This function must run in the ThreadPoolExecutor.
57
  """
58
- # Create a unique temporary filename for the converted WAV file
59
- # We use tempfile.NamedTemporaryFile to safely create a path
60
- # and then delete the file handle so ffmpeg can write to it.
61
- with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
62
- output_path = tmp.name
63
-
64
- logger.info(f"Converting '{os.path.basename(input_path)}' to WAV (24kHz, mono) at {os.path.basename(output_path)}")
65
-
66
- # FFmpeg command details:
67
- # -y: overwrite output file if it exists
68
- # -i: input file path
69
- # -f wav: output format is WAV
70
- # -ar 24000: set sample rate to 24000 (required by NeuTTS)
71
- # -ac 1: set audio channels to 1 (mono)
72
- # -c:a pcm_s16le: set codec to uncompressed 16-bit PCM (standard WAV)
73
- command = [
74
- "ffmpeg",
75
- "-y",
76
- "-i", input_path,
77
  "-f", "wav",
78
- "-ar", str(SAMPLE_RATE),
79
- "-ac", "1",
80
- "-c:a", "pcm_s16le",
81
- output_path
82
  ]
 
 
 
 
 
 
 
 
83
 
84
- try:
85
- # Run the FFmpeg command
86
- # Use a short timeout to prevent runaway processes
87
- result = subprocess.run(command, check=True, capture_output=True, text=True, timeout=30)
88
- logger.info(f"FFmpeg conversion successful.")
89
- return output_path
90
- except subprocess.CalledProcessError as e:
91
- logger.error(f"FFmpeg conversion failed: {e.stderr}")
92
- # Clean up the output path if FFmpeg failed to write it
93
- if os.path.exists(output_path):
94
- os.unlink(output_path)
95
  # Provide the last line of the FFmpeg error to the user
96
- error_detail = e.stderr.splitlines()[-1] if e.stderr else "Unknown FFmpeg error."
97
  raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}")
98
- except subprocess.TimeoutExpired:
99
- logger.error("FFmpeg conversion timed out.")
100
- if os.path.exists(output_path):
101
- os.unlink(output_path)
102
- raise HTTPException(status_code=504, detail="Audio conversion timed out after 30 seconds.")
103
- except Exception as e:
104
- logger.error(f"General conversion error: {e}")
105
- if os.path.exists(output_path):
106
- os.unlink(output_path)
107
- raise HTTPException(status_code=500, detail="An unexpected error occurred during audio conversion.")
108
  # --- Model Wrapper and Logic ---
109
 
110
  class NeuTTSWrapper:
@@ -135,32 +119,50 @@ class NeuTTSWrapper:
135
  return audio_buffer.read()
136
 
137
  def _split_text_into_chunks(self, text: str) -> list[str]:
138
- """Simple sentence splitting for streaming (can be enhanced with regex)."""
139
- sentences = [s.strip() for s in text.split('.') if s.strip()]
140
- if not sentences:
141
- sentences = [text.strip()]
142
- return sentences
143
-
144
- def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
145
- """Blocking synthesis for standard endpoint."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
-
148
- ref_s = self.tts_model.encode_reference(ref_audio_path)
149
 
150
  # 3. Infer full text
151
  with torch.no_grad():
152
  audio = self.tts_model.infer(text, ref_s, reference_text)
153
  return audio
154
 
155
- def stream_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
156
- """Sentence-by-Sentence Streaming (Blocking)."""
157
  logger.info(f"Starting streaming synthesis for text length: {len(text)}")
158
 
 
 
 
 
 
159
 
160
-
161
- ref_s = self.tts_model.encode_reference(ref_audio_path)
162
-
163
- # 3. Split text
164
  sentences = self._split_text_into_chunks(text)
165
 
166
  # 4. Stream chunks
@@ -170,11 +172,9 @@ class NeuTTSWrapper:
170
 
171
  logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
172
 
173
- # Infer sentence
174
  with torch.no_grad():
175
  audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
176
 
177
- # Convert and yield
178
  yield self._convert_to_streamable_format(audio_chunk, audio_format)
179
 
180
  logger.info("Streaming synthesis complete.")
@@ -300,69 +300,48 @@ async def text_to_speech(
300
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
301
  reference_audio: UploadFile = File(...)):
302
  """
303
- Standard blocking TTS endpoint with Multi-Format Output (Kokoro Feature).
304
- Includes FFmpeg conversion for uploaded audio format compatibility.
305
  """
306
  if not hasattr(app.state, 'tts_wrapper'):
307
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
308
 
309
- # 1. Asynchronously save reference audio (original upload)
310
- temp_ref_path = await save_upload_file_async(reference_audio)
311
- converted_wav_path = None # NEW: Initialize for cleanup
312
  start_time = time.time()
313
-
314
  try:
315
- # 2. **NEW STEP**: Convert the uploaded file (WebM, etc.) to a 24kHz WAV file using FFmpeg
316
- converted_wav_path = await run_blocking_task_async(
317
- convert_to_wav_blocking,
318
- temp_ref_path
319
- )
320
 
321
- # 3. Offload the ENTIRE blocking process (encode + infer) to a thread
322
  audio_data = await run_blocking_task_async(
323
  app.state.tts_wrapper.generate_speech_blocking,
324
  text,
325
- converted_wav_path, # IMPORTANT: Pass the CONVERTED WAV path
326
  reference_text
327
  )
328
 
329
- # 4. Convert to requested format (Blocking, but usually fast)
330
  audio_bytes = await run_blocking_task_async(
331
  app.state.tts_wrapper._convert_to_streamable_format,
332
  audio_data,
333
  output_format
334
  )
335
 
336
- # 5. Save to disk (Original NeuTTS requirement)
337
- audio_filename = f"tts_{time.time()}.{output_format}"
338
- final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
339
- await run_blocking_task_async(
340
- lambda: open(final_path, 'wb').write(audio_bytes)
341
- )
342
-
343
  processing_time = time.time() - start_time
344
  audio_duration = len(audio_data) / SAMPLE_RATE
345
  return Response(
346
  content=audio_bytes,
347
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
348
  headers={
349
- "Content-Disposition": f"attachment; filename={audio_filename}",
350
  "X-Processing-Time": f"{processing_time:.2f}s",
351
  "X-Audio-Duration": f"{audio_duration:.2f}s"
352
  }
353
  )
354
  except Exception as e:
355
  logger.error(f"Synthesis error: {e}")
356
- # Reraise HTTPExceptions that may have come from the conversion step
357
  if isinstance(e, HTTPException):
358
  raise
359
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
360
- finally:
361
- # 6. Clean up BOTH the original file AND the converted WAV file
362
- if os.path.exists(temp_ref_path):
363
- os.unlink(temp_ref_path)
364
- if converted_wav_path and os.path.exists(converted_wav_path):
365
- os.unlink(converted_wav_path)
366
 
367
  @app.post("/synthesize/stream")
368
  async def stream_text_to_speech_cloning(
@@ -372,74 +351,79 @@ async def stream_text_to_speech_cloning(
372
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
373
  reference_audio: UploadFile = File(...)):
374
  """
375
- Sentence-by-Sentence Streaming Endpoint.
376
- Fixes race condition by moving cleanup into the streaming generator.
377
  """
378
  if not hasattr(app.state, 'tts_wrapper'):
379
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
380
 
381
- # 1. Asynchronously save reference audio (non-blocking)
382
- temp_ref_path = await save_upload_file_async(reference_audio)
383
- converted_wav_path = None # Initialize for cleanup
384
-
385
  try:
386
- # 2. Convert the uploaded file (WebM, etc.) to a 24kHz WAV file
387
- converted_wav_path = await run_blocking_task_async(
388
- convert_to_wav_blocking,
389
- temp_ref_path
390
- )
391
 
392
- # 2.5. CLEANUP ORIGINAL FILE IMMEDIATELY: It is no longer needed after conversion
393
- if os.path.exists(temp_ref_path):
394
- os.unlink(temp_ref_path)
395
-
396
- # 3. Define the generator function, which will run in the thread pool
397
- def stream_generator(path_to_delete: str):
398
- try:
399
- # This logic uses the path_to_delete parameter, which is guaranteed to exist
400
- for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
401
- text,
402
- path_to_delete, # Pass the CONVERTED WAV path
403
- reference_text,
404
- speed,
405
- output_format
406
- ):
407
- yield chunk_bytes
408
- except Exception as e:
409
- # Log the error and raise it to stop the stream
410
- logger.error(f"Streaming generator error: {e}")
411
- raise # Re-raise to ensure the stream terminates
412
- finally:
413
- # 4. **CRUCIAL FIX:** Clean up the converted file ONLY AFTER GENERATION IS DONE
414
- if os.path.exists(path_to_delete):
415
- os.unlink(path_to_delete)
416
- logger.info(f"Cleaned up converted file: {path_to_delete}")
417
-
418
- # Return StreamingResponse, passing the path to the generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  return StreamingResponse(
420
- stream_generator(converted_wav_path),
421
- media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
422
- headers={
423
- "Content-Disposition": "attachment; filename=tts_live_stream.mp3",
424
- "Transfer-Encoding": "chunked",
425
- "Cache-Control": "no-cache",
426
- "X-Accel-Buffering": "no"
427
- }
428
  )
429
 
430
  except Exception as e:
431
  logger.error(f"Streaming setup error: {e}")
432
- # Clean up files only if the setup failed *before* starting the generator
433
- if os.path.exists(temp_ref_path):
434
- os.unlink(temp_ref_path)
435
- if converted_wav_path and os.path.exists(converted_wav_path):
436
- os.unlink(converted_wav_path)
437
-
438
- # Reraise HTTPExceptions that may have come from the conversion step
439
  if isinstance(e, HTTPException):
440
  raise
441
  raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
442
- # Note: The outer 'finally' block is now removed as its logic is handled in 2.5 and 4.
443
 
444
  @app.get("/audio/{filename}")
445
  async def get_audio(filename: str):
 
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from pydantic import BaseModel, Field
21
+ import re
22
+ import hashlib
23
+ from functools import lru_cache
24
+ import queue
25
  # Ensure the cloned neutts-air repository is in the path
26
  import sys
27
  sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
 
36
  # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
37
  DEVICE = "cpu"
38
  # Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
39
+ MAX_WORKERS = 1
40
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
41
  SAMPLE_RATE = 24000
42
  CLEANUP_THRESHOLD = 300 # seconds (5 minutes; the earlier "1 hour" note did not match this value)
 
52
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
53
 
54
 
55
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """
    Convert an uploaded audio file to mono 16-bit PCM WAV at SAMPLE_RATE, in memory.

    The upload is piped through an FFmpeg subprocess (stdin -> stdout), so no
    intermediate files are written to disk.

    Args:
        upload_file: The uploaded reference audio. Its whole content is read
            into memory before being handed to FFmpeg.

    Returns:
        A BytesIO (positioned at offset 0) holding the converted WAV data.

    Raises:
        HTTPException: 400 if the upload is empty or FFmpeg exits with a
            non-zero status; 504 if FFmpeg does not finish within 30 seconds.
    """
    conversion_timeout_s = 30

    payload = await upload_file.read()
    if not payload:
        # Fail fast with a clear message instead of surfacing FFmpeg's
        # "invalid data" error for an empty stdin.
        raise HTTPException(status_code=400, detail="Audio format conversion failed: uploaded file is empty.")

    ffmpeg_command = [
        "ffmpeg",
        "-i", "pipe:0",  # Read from stdin
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),
        "-ac", "1",
        "-c:a", "pcm_s16le",
        "pipe:1"  # Write to stdout
    ]

    # Start the subprocess with pipes for stdin, stdout, and stderr
    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    try:
        # communicate() feeds stdin and drains stdout/stderr concurrently;
        # the timeout keeps a stalled FFmpeg from hanging the request forever.
        wav_data, stderr_data = await asyncio.wait_for(
            proc.communicate(input=payload),
            timeout=conversion_timeout_s,
        )
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and kill().
            pass
        await proc.wait()  # Reap the child so it does not linger as a zombie
        logger.error("In-memory FFmpeg conversion timed out.")
        raise HTTPException(status_code=504, detail="Audio conversion timed out after 30 seconds.")

    if proc.returncode != 0:
        # FFmpeg may echo raw metadata bytes; never let decoding mask the real error.
        error_message = stderr_data.decode(errors="replace").strip()
        logger.error(f"In-memory conversion failed: {error_message}")
        # Provide the last line of the FFmpeg error to the user
        error_detail = error_message.splitlines()[-1] if error_message else "Unknown FFmpeg error."
        raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}")

    logger.info("In-memory FFmpeg conversion successful.")
    # Return the raw WAV data in a BytesIO buffer, ready for the model
    return io.BytesIO(wav_data)
 
 
 
 
 
 
92
  # --- Model Wrapper and Logic ---
93
 
94
  class NeuTTSWrapper:
 
119
  return audio_buffer.read()
120
 
121
  def _split_text_into_chunks(self, text: str) -> list[str]:
122
+ """
123
+ Splits text into sentences OR clauses using a robust regex.
124
+ This is fast, library-free, and now handles commas.
125
+ """
126
+ # This regex now finds all sequences of characters that are not a sentence-ending
127
+ # or clause-ending punctuation mark, followed by that punctuation.
128
+ # The only change is adding ',' to the character sets.
129
+ chunks = re.findall(r'[^.,!?]+[.,!?]*', text)
130
+ return [c.strip() for c in chunks if c.strip()]
131
+
132
def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
    """
    Return the reference encoding for the given audio, computing it at most once per hash.

    Results are kept in a small per-instance LRU cache keyed only by
    ``audio_content_hash``. Unlike ``functools.lru_cache`` on a method, this
    neither pins the raw audio bytes in the cache key nor keeps the instance
    alive through a class-level cache.

    Args:
        audio_content_hash: Hex digest identifying ``audio_bytes``; the cache key.
        audio_bytes: WAV data of the reference audio; only read on a cache miss.

    Returns:
        Whatever ``self.tts_model.encode_reference`` returns for this audio.
    """
    # Local import: OrderedDict is not among the imports visible for this file.
    from collections import OrderedDict

    max_entries = 32
    # Created lazily so this method does not depend on __init__ setting it up.
    cache = self.__dict__.setdefault("_reference_encoding_cache", OrderedDict())

    try:
        cached = cache[audio_content_hash]
        # Mark as most recently used; a concurrent eviction surfaces as KeyError.
        cache.move_to_end(audio_content_hash)
        return cached
    except KeyError:
        pass

    logger.info(f"Cache miss for hash: {audio_content_hash[:10]}... Encoding new reference.")
    # NOTE(review): assumes encode_reference accepts a file-like object (BytesIO) - confirm.
    encoding = self.tts_model.encode_reference(io.BytesIO(audio_bytes))

    cache[audio_content_hash] = encoding
    while len(cache) > max_entries:
        cache.popitem(last=False)  # Evict the least recently used entry
    return encoding
141
+
142
def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
    """Synthesize ``text`` in one blocking call, reusing a cached reference encoding."""
    # The SHA-256 hex digest of the reference audio serves as the cache key.
    digest = hashlib.sha256(ref_audio_bytes).hexdigest()

    # Look up the reference encoding, computing it on first use.
    reference_encoding = self._get_or_create_reference_encoding(digest, ref_audio_bytes)

    # Run inference on the full text without gradient tracking.
    with torch.no_grad():
        return self.tts_model.infer(text, reference_encoding, reference_text)
154
 
155
+ def stream_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
156
+ """Sentence-by-Sentence Streaming using cached reference encoding."""
157
  logger.info(f"Starting streaming synthesis for text length: {len(text)}")
158
 
159
+ # 1. Hash the audio bytes once
160
+ audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
161
+
162
+ # 2. Get the reference encoding from cache, once for the whole stream
163
+ ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
164
 
165
+ # 3. Split text using the new regex method
 
 
 
166
  sentences = self._split_text_into_chunks(text)
167
 
168
  # 4. Stream chunks
 
172
 
173
  logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
174
 
 
175
  with torch.no_grad():
176
  audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
177
 
 
178
  yield self._convert_to_streamable_format(audio_chunk, audio_format)
179
 
180
  logger.info("Streaming synthesis complete.")
 
300
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
301
  reference_audio: UploadFile = File(...)):
302
  """
303
+ Standard blocking TTS endpoint with in-memory processing and caching.
 
304
  """
305
  if not hasattr(app.state, 'tts_wrapper'):
306
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
307
 
 
 
 
308
  start_time = time.time()
 
309
  try:
310
+ # 1. Convert the uploaded file to WAV directly in memory
311
+ converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
312
+ ref_audio_bytes = converted_wav_buffer.getvalue()
 
 
313
 
314
+ # 2. Offload the blocking AI process (now faster with caching)
315
  audio_data = await run_blocking_task_async(
316
  app.state.tts_wrapper.generate_speech_blocking,
317
  text,
318
+ ref_audio_bytes, # Pass bytes, not a path
319
  reference_text
320
  )
321
 
322
+ # 3. Convert to requested output format
323
  audio_bytes = await run_blocking_task_async(
324
  app.state.tts_wrapper._convert_to_streamable_format,
325
  audio_data,
326
  output_format
327
  )
328
 
 
 
 
 
 
 
 
329
  processing_time = time.time() - start_time
330
  audio_duration = len(audio_data) / SAMPLE_RATE
331
  return Response(
332
  content=audio_bytes,
333
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
334
  headers={
335
+ "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
336
  "X-Processing-Time": f"{processing_time:.2f}s",
337
  "X-Audio-Duration": f"{audio_duration:.2f}s"
338
  }
339
  )
340
  except Exception as e:
341
  logger.error(f"Synthesis error: {e}")
 
342
  if isinstance(e, HTTPException):
343
  raise
344
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
 
 
 
 
 
 
345
 
346
  @app.post("/synthesize/stream")
347
  async def stream_text_to_speech_cloning(
 
351
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
352
  reference_audio: UploadFile = File(...)):
353
  """
354
+ Sentence-by-Sentence Streaming using a parallel producer-consumer pipeline
355
+ to ensure continuous, low-latency audio flow.
356
  """
357
  if not hasattr(app.state, 'tts_wrapper'):
358
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
359
 
 
 
 
 
360
  try:
361
+ # Initial audio conversion is still done once, in memory.
362
+ converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
363
+ ref_audio_bytes = converted_wav_buffer.getvalue()
 
 
364
 
365
def stream_generator():
    """
    Yield encoded audio chunks while a background producer synthesizes ahead.

    A producer task running on ``tts_executor`` infers one text chunk at a
    time and hands the raw audio over through a small bounded queue. This
    generator is the consumer: it converts each chunk to ``output_format``
    and yields the bytes.

    Yields:
        Encoded audio bytes for each synthesized chunk, in text order.

    Raises:
        Exception: Re-raises whatever the producer failed with, which
            terminates the stream.
    """
    # Local import: threading is not among the imports visible for this file.
    import threading

    # A small maxsize acts as a "look-ahead" buffer and bounds memory use.
    q = queue.Queue(maxsize=2)
    # Set once the consumer stops, so the producer never blocks on a full queue
    # and keeps the (single) executor worker occupied forever.
    consumer_gone = threading.Event()

    def offer(item) -> bool:
        """Enqueue ``item`` unless the consumer has stopped; return False if it has."""
        while not consumer_gone.is_set():
            try:
                q.put(item, timeout=0.5)
                return True
            except queue.Full:
                continue
        return False

    def producer():
        """Synthesize every chunk and queue it; always finish with a None sentinel."""
        try:
            # Get reference encoding once for the whole stream
            audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
            ref_s = app.state.tts_wrapper._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)

            sentences = app.state.tts_wrapper._split_text_into_chunks(text)

            for sentence in sentences:
                if consumer_gone.is_set():
                    return  # Nobody is listening; skip the remaining inference work
                # Generate the raw audio (CPU-heavy part)
                with torch.no_grad():
                    audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence, ref_s, reference_text)
                if not offer(audio_chunk):
                    return
        except Exception as e:
            logger.error(f"Error in producer thread: {e}")
            # Hand the exception to the consumer so the stream fails visibly
            offer(e)
        finally:
            # Signal that production is finished
            offer(None)

    # Submit straight to the executor. A synchronous generator handed to
    # StreamingResponse is iterated off the event-loop thread (Starlette uses a
    # thread pool for sync iterables), where asyncio.get_event_loop() has no
    # loop to return, so loop.run_in_executor() cannot be used here.
    tts_executor.submit(producer)

    try:
        while True:
            # Blocks until the producer delivers the next item
            result = q.get()

            # End-of-stream sentinel
            if result is None:
                break

            # The producer reported a failure
            if isinstance(result, Exception):
                logger.error(f"Terminating stream due to producer error: {result}")
                raise result

            # Convert the raw audio to the desired format and yield it to the user
            yield app.state.tts_wrapper._convert_to_streamable_format(result, output_format)
    finally:
        # Runs on normal completion, on error, and when the generator is closed early.
        consumer_gone.set()
415
+
416
+ # Return the StreamingResponse with our new high-performance generator
417
  return StreamingResponse(
418
+ stream_generator(),
419
+ media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
 
 
 
 
 
 
420
  )
421
 
422
  except Exception as e:
423
  logger.error(f"Streaming setup error: {e}")
 
 
 
 
 
 
 
424
  if isinstance(e, HTTPException):
425
  raise
426
  raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
 
427
 
428
  @app.get("/audio/{filename}")
429
  async def get_audio(filename: str):