Rajhuggingface4253 committed on
Commit
dacdc6d
·
verified ·
1 Parent(s): c36a042

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -77
app.py CHANGED
@@ -152,33 +152,60 @@ class NeuTTSWrapper:
152
  audio = self.tts_model.infer(text, ref_s, reference_text)
153
  return audio
154
 
155
def stream_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
    """Synthesize ``text`` one chunk at a time and yield each chunk's encoded bytes.

    The reference audio is encoded once up front; every non-blank chunk
    produced by ``_split_text_into_chunks`` is then inferred and converted
    with ``_convert_to_streamable_format`` before being yielded. This is a
    plain (blocking) generator.

    Args:
        text: Full text to synthesize.
        ref_audio_path: Path of the reference audio passed to
            ``tts_model.encode_reference``.
        reference_text: Reference text passed to ``tts_model.infer``.
        speed: Accepted for interface compatibility; not used by this method.
        audio_format: Target format handed to ``_convert_to_streamable_format``.

    Yields:
        The converted audio for each non-blank chunk, in text order.
    """
    logger.info(f"Starting streaming synthesis for text length: {len(text)}")

    # Encode the reference once; it is reused for every chunk below.
    reference_codes = self.tts_model.encode_reference(ref_audio_path)
    chunks = self._split_text_into_chunks(text)

    for i, sentence in enumerate(chunks):
        if sentence.strip():
            logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")

            # Inference only: gradients are not needed.
            with torch.no_grad():
                waveform = self.tts_model.infer(sentence, reference_codes, reference_text)

            encoded = self._convert_to_streamable_format(waveform, audio_format)
            yield encoded

    logger.info("Streaming synthesis complete.")
 
 
 
 
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # --- Asynchronous Offloading ---
183
 
184
  async def run_blocking_task_async(func, *args, **kwargs):
@@ -368,78 +395,55 @@ async def text_to_speech(
368
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)):
    """Sentence-by-sentence streaming endpoint with voice cloning.

    The uploaded reference audio is saved, converted with
    ``convert_to_wav_blocking``, and the original upload is removed right
    after conversion. The converted file is deleted by the streaming
    generator itself once generation ends, so it cannot disappear while the
    model still needs it.

    Args:
        text: Text to synthesize (1-5000 characters).
        reference_text: Reference text forwarded to the TTS wrapper.
        speed: Forwarded to ``stream_speech_blocking``.
        output_format: One of ``wav``, ``mp3`` or ``flac``.
        reference_audio: Uploaded reference recording.

    Returns:
        A ``StreamingResponse`` producing the encoded audio chunks.

    Raises:
        HTTPException: 503 when the model is not loaded; 500 (or the original
            ``HTTPException``) when setting up the stream fails.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # Save the upload without blocking the event loop.
    temp_ref_path = await save_upload_file_async(reference_audio)
    converted_wav_path = None  # Known to the except-branch cleanup below.

    try:
        # Convert the upload to the WAV file the model consumes.
        converted_wav_path = await run_blocking_task_async(
            convert_to_wav_blocking,
            temp_ref_path
        )

        # The original upload is no longer needed once it has been converted.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)

        def stream_generator(path_to_delete: str):
            """Yield audio chunks, then delete ``path_to_delete`` when done."""
            try:
                yield from app.state.tts_wrapper.stream_speech_blocking(
                    text,
                    path_to_delete,  # the CONVERTED WAV path
                    reference_text,
                    speed,
                    output_format
                )
            except Exception as e:
                logger.error(f"Streaming generator error: {e}")
                raise  # Re-raise so the stream terminates.
            finally:
                # Delete the converted file only after generation has ended.
                if os.path.exists(path_to_delete):
                    os.unlink(path_to_delete)
                    logger.info(f"Cleaned up converted file: {path_to_delete}")

        return StreamingResponse(
            stream_generator(converted_wav_path),
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                # The extension follows the requested format; it used to be
                # hard-coded to ".mp3" even for wav/flac output.
                "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no"
            }
        )

    except Exception as e:
        logger.error(f"Streaming setup error: {e}")
        # Setup failed before the generator took ownership of the files.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            os.unlink(converted_wav_path)

        # Pass through HTTPExceptions raised by the setup steps unchanged.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  @app.get("/audio/{filename}")
445
  async def get_audio(filename: str):
 
152
  audio = self.tts_model.infer(text, ref_s, reference_text)
153
  return audio
154
 
155
def stream_producer(self, queue: asyncio.Queue, text: str, ref_audio_path: str, reference_text: str, loop: "asyncio.AbstractEventLoop | None" = None):
    """[PRODUCER] Generate audio chunk by chunk and hand each one to ``queue``.

    Intended to run in a worker thread. The reference audio is encoded once,
    then every non-blank chunk of ``text`` is inferred and enqueued as raw
    model output. If anything fails, the exception object itself is enqueued.
    A ``None`` sentinel is always enqueued last to mark the end of the stream.

    Args:
        queue: Queue the consumer reads from.
        text: Full text to synthesize.
        ref_audio_path: Path of the reference audio to encode.
        reference_text: Reference text passed to ``tts_model.infer``.
        loop: Event loop that owns ``queue``. asyncio queues are not
            thread-safe, so when this is given every item is delivered via
            ``loop.call_soon_threadsafe``. When omitted, items are put
            directly, which is only safe if ``queue`` is itself thread-safe.
    """
    def _emit(item):
        # Deliver one item to the consumer, crossing the thread boundary safely.
        if loop is None:
            queue.put_nowait(item)
        else:
            loop.call_soon_threadsafe(queue.put_nowait, item)

    try:
        logger.info("Starting audio production thread...")
        # Encode the reference once; it is reused for every chunk.
        ref_s = self.tts_model.encode_reference(ref_audio_path)
        sentences = self._split_text_into_chunks(text)

        for i, sentence in enumerate(sentences):
            if not sentence.strip():
                continue

            logger.debug(f"Producing chunk {i+1}/{len(sentences)}: '{sentence[:30]}...'")

            # Inference only: gradients are not needed.
            with torch.no_grad():
                audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)

            # Hand the raw audio chunk to the consumer.
            _emit(audio_chunk)

    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception(f"Error in producer thread: {e}")
        _emit(e)
    finally:
        # Sentinel: tells the consumer that production is finished.
        _emit(None)
183
 
184
async def stream_consumer(queue: asyncio.Queue, output_format: str):
    """[CONSUMER] Drain ``queue`` and yield each audio chunk as encoded bytes.

    Items are awaited one at a time. An ``Exception`` instance from the
    producer is logged and ends the stream; ``None`` is the end-of-stream
    sentinel. Any other item is treated as a raw audio chunk and converted to
    ``output_format`` through ``run_blocking_task_async`` before being yielded.

    Args:
        queue: Queue filled by the producer.
        output_format: Target format passed to ``_convert_to_streamable_format``.

    Yields:
        The converted audio for each chunk, in arrival order.
    """
    logger.info("Starting audio consumption...")

    stream_open = True
    while stream_open:
        # Suspend until the producer delivers the next item.
        item = await queue.get()

        if isinstance(item, Exception):
            logger.error(f"Consumer received an error from the producer: {item}")
            stream_open = False
        elif item is None:
            # End-of-stream sentinel.
            logger.info("Consumer received end-of-stream signal.")
            stream_open = False
        else:
            # A real audio chunk: encode it via the blocking-task helper.
            encode = app.state.tts_wrapper._convert_to_streamable_format
            encoded = await run_blocking_task_async(encode, item, output_format)
            yield encoded
209
  # --- Asynchronous Offloading ---
210
 
211
  async def run_blocking_task_async(func, *args, **kwargs):
 
395
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """Streaming voice-cloning endpoint built on a producer/consumer queue.

    The model runs in ``tts_executor`` and pushes raw audio chunks into a
    queue; ``stream_consumer`` encodes them and the response streams them to
    the client as they become available. Temporary files are removed when the
    stream finishes or fails.

    Args:
        text: Text to synthesize (1-5000 characters).
        reference_text: Reference text forwarded to the producer.
        output_format: One of ``wav``, ``mp3`` or ``flac``.
        reference_audio: Uploaded reference recording.

    Returns:
        A ``StreamingResponse`` producing the encoded audio chunks.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    temp_ref_path = await save_upload_file_async(reference_audio)

    async def cleanup_and_run_stream():
        """Run the whole producer/consumer lifecycle and clean up afterwards."""
        converted_wav_path = None
        queue = asyncio.Queue()
        # We are inside a coroutine, so the running loop is the one that owns `queue`.
        loop = asyncio.get_running_loop()

        class _ThreadSafeQueue:
            """Producer-side handle that forwards puts onto the event loop.

            asyncio.Queue is not thread-safe, and the producer runs in a
            worker thread, so each put is scheduled on the loop thread.
            """

            def put_nowait(self, item):
                loop.call_soon_threadsafe(queue.put_nowait, item)

        try:
            # Convert the uploaded file to the WAV format the model needs.
            converted_wav_path = await run_blocking_task_async(convert_to_wav_blocking, temp_ref_path)

            # Start the producer (the model) in a background thread; it
            # feeds audio chunks into the queue through the proxy.
            loop.run_in_executor(
                tts_executor,
                app.state.tts_wrapper.stream_producer,
                _ThreadSafeQueue(), text, converted_wav_path, reference_text
            )

            # Relay chunks to the client as the consumer encodes them.
            async for chunk in stream_consumer(queue, output_format):
                yield chunk

        finally:
            # Runs whether the stream completed, failed, or was closed early.
            if os.path.exists(temp_ref_path):
                os.unlink(temp_ref_path)
            if converted_wav_path and os.path.exists(converted_wav_path):
                os.unlink(converted_wav_path)
            logger.info("Cleaned up temporary stream files.")

    return StreamingResponse(
        cleanup_and_run_stream(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
        headers={
            # The extension follows the requested format; it used to be
            # hard-coded to ".mp3" even for wav/flac output.
            "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no"  # Header to prevent proxy buffering
        }
    )
447
 
448
  @app.get("/audio/{filename}")
449
  async def get_audio(filename: str):