Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -152,33 +152,60 @@ class NeuTTSWrapper:
|
|
| 152 |
audio = self.tts_model.infer(text, ref_s, reference_text)
|
| 153 |
return audio
|
| 154 |
|
| 155 |
-
def
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
ref_s = self.tts_model.encode_reference(ref_audio_path)
|
| 162 |
-
|
| 163 |
-
# 3. Split text
|
| 164 |
sentences = self._split_text_into_chunks(text)
|
| 165 |
|
| 166 |
-
# 4. Stream chunks
|
| 167 |
for i, sentence in enumerate(sentences):
|
| 168 |
if not sentence.strip():
|
| 169 |
continue
|
| 170 |
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
-
# Infer sentence
|
| 174 |
with torch.no_grad():
|
| 175 |
audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
|
| 176 |
|
| 177 |
-
#
|
| 178 |
-
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
# --- Asynchronous Offloading ---
|
| 183 |
|
| 184 |
async def run_blocking_task_async(func, *args, **kwargs):
|
|
@@ -368,78 +395,55 @@ async def text_to_speech(
|
|
| 368 |
async def stream_text_to_speech_cloning(
|
| 369 |
text: str = Form(..., min_length=1, max_length=5000),
|
| 370 |
reference_text: str = Form(...),
|
| 371 |
-
speed: float = Form(1.0, ge=0.5, le=2.0),
|
| 372 |
output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
|
| 373 |
-
reference_audio: UploadFile = File(...)
|
|
|
|
| 374 |
"""
|
| 375 |
-
|
| 376 |
-
Fixes race condition by moving cleanup into the streaming generator.
|
| 377 |
"""
|
| 378 |
if not hasattr(app.state, 'tts_wrapper'):
|
| 379 |
raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
|
| 380 |
-
|
| 381 |
-
# 1. Asynchronously save reference audio (non-blocking)
|
| 382 |
temp_ref_path = await save_upload_file_async(reference_audio)
|
| 383 |
-
converted_wav_path = None # Initialize for cleanup
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
converted_wav_path =
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
if os.path.exists(temp_ref_path):
|
| 394 |
-
os.unlink(temp_ref_path)
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
reference_text,
|
| 404 |
-
speed,
|
| 405 |
-
output_format
|
| 406 |
-
):
|
| 407 |
-
yield chunk_bytes
|
| 408 |
-
except Exception as e:
|
| 409 |
-
# Log the error and raise it to stop the stream
|
| 410 |
-
logger.error(f"Streaming generator error: {e}")
|
| 411 |
-
raise # Re-raise to ensure the stream terminates
|
| 412 |
-
finally:
|
| 413 |
-
# 4. **CRUCIAL FIX:** Clean up the converted file ONLY AFTER GENERATION IS DONE
|
| 414 |
-
if os.path.exists(path_to_delete):
|
| 415 |
-
os.unlink(path_to_delete)
|
| 416 |
-
logger.info(f"Cleaned up converted file: {path_to_delete}")
|
| 417 |
-
|
| 418 |
-
# Return StreamingResponse, passing the path to the generator
|
| 419 |
-
return StreamingResponse(
|
| 420 |
-
stream_generator(converted_wav_path),
|
| 421 |
-
media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
|
| 422 |
-
headers={
|
| 423 |
-
"Content-Disposition": "attachment; filename=tts_live_stream.mp3",
|
| 424 |
-
"Transfer-Encoding": "chunked",
|
| 425 |
-
"Cache-Control": "no-cache",
|
| 426 |
-
"X-Accel-Buffering": "no"
|
| 427 |
-
}
|
| 428 |
-
)
|
| 429 |
-
|
| 430 |
-
except Exception as e:
|
| 431 |
-
logger.error(f"Streaming setup error: {e}")
|
| 432 |
-
# Clean up files only if the setup failed *before* starting the generator
|
| 433 |
-
if os.path.exists(temp_ref_path):
|
| 434 |
-
os.unlink(temp_ref_path)
|
| 435 |
-
if converted_wav_path and os.path.exists(converted_wav_path):
|
| 436 |
-
os.unlink(converted_wav_path)
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
@app.get("/audio/{filename}")
|
| 445 |
async def get_audio(filename: str):
|
|
|
|
| 152 |
audio = self.tts_model.infer(text, ref_s, reference_text)
|
| 153 |
return audio
|
| 154 |
|
| 155 |
+
def stream_producer(self, queue: asyncio.Queue, text: str, ref_audio_path: str, reference_text: str):
    """
    [PRODUCER] Synthesize ``text`` sentence by sentence and push each result onto ``queue``.

    Meant to be run in a worker thread. The items placed on ``queue`` are, in order:

    * one item per non-blank sentence: whatever ``self.tts_model.infer`` returned;
    * the exception instance, if any step raised;
    * ``None``, always last, as the end-of-stream sentinel.

    Args:
        queue: Destination for produced items; only ``put_nowait`` is called on it.
        text: Full input text; split with ``self._split_text_into_chunks``.
        ref_audio_path: Path handed to ``self.tts_model.encode_reference``.
        reference_text: Passed through unchanged to every ``infer`` call.

    NOTE(review): ``asyncio.Queue`` is not thread-safe. When this method runs in an
    executor thread, the caller should hand in an object whose ``put_nowait`` marshals
    the call onto the event loop (e.g. via ``loop.call_soon_threadsafe``); otherwise a
    consumer awaiting ``queue.get()`` is not guaranteed to be woken up.
    """
    try:
        logger.info("Starting audio production thread...")
        # The reference is encoded once and reused for every sentence.
        ref_s = self.tts_model.encode_reference(ref_audio_path)
        sentences = self._split_text_into_chunks(text)

        for i, sentence in enumerate(sentences):
            # Whitespace-only chunks produce no audio; skip them.
            if not sentence.strip():
                continue

            logger.debug(f"Producing chunk {i+1}/{len(sentences)}: '{sentence[:30]}...'")

            # Inference only: no autograd bookkeeping needed.
            with torch.no_grad():
                audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)

            # Hand the raw chunk to the consumer.
            queue.put_nowait(audio_chunk)

    except Exception as e:
        # logger.exception keeps the traceback, which is otherwise lost because
        # this runs in a background thread and only str(e) was being logged.
        logger.exception(f"Error in producer thread: {e}")
        # Forward the failure so the consumer can stop instead of waiting forever.
        queue.put_nowait(e)
    finally:
        # Always signal end-of-stream, whether production succeeded or failed.
        queue.put_nowait(None)
|
| 183 |
|
| 184 |
+
async def stream_consumer(queue: asyncio.Queue, output_format: str):
    """
    [CONSUMER] Drain ``queue`` and yield each audio chunk encoded as ``output_format``.

    The generator stops when it receives either the ``None`` sentinel (normal end of
    stream) or an ``Exception`` instance (failure reported by the producer). Every
    other item is passed to ``app.state.tts_wrapper._convert_to_streamable_format``
    through ``run_blocking_task_async`` and the result is yielded to the caller.
    """
    logger.info("Starting audio consumption...")

    # Keep pulling until the end-of-stream sentinel shows up.
    while (item := await queue.get()) is not None:
        if isinstance(item, Exception):
            # The producer failed: report it and end the stream here.
            logger.error(f"Consumer received an error from the producer: {item}")
            return

        # A real audio chunk: encode it off the event loop, then emit it.
        yield await run_blocking_task_async(
            app.state.tts_wrapper._convert_to_streamable_format,
            item,
            output_format,
        )

    logger.info("Consumer received end-of-stream signal.")
|
| 209 |
# --- Asynchronous Offloading ---
|
| 210 |
|
| 211 |
async def run_blocking_task_async(func, *args, **kwargs):
|
|
|
|
| 395 |
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """
    Streaming voice-cloning endpoint built on a producer/consumer pair.

    The uploaded reference audio is saved, converted with ``convert_to_wav_blocking``,
    and handed to ``tts_wrapper.stream_producer`` running on ``tts_executor``. Chunks
    are relayed to the client through ``stream_consumer`` as they become available.

    Args:
        text: Text to synthesize (1-5000 characters).
        reference_text: Transcript that accompanies the reference audio.
        output_format: One of ``wav``, ``mp3`` or ``flac``.
        reference_audio: Uploaded reference recording.

    Returns:
        A ``StreamingResponse`` whose media type is ``audio/mpeg`` for mp3 and
        ``audio/<output_format>`` otherwise.

    Raises:
        HTTPException: 503 when ``app.state`` has no ``tts_wrapper``.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    temp_ref_path = await save_upload_file_async(reference_audio)

    async def cleanup_and_run_stream():
        """Run the producer/consumer pipeline and remove the temp files afterwards."""
        converted_wav_path = None
        producer_future = None
        queue = asyncio.Queue()
        # We are inside a coroutine, so the running loop is the one that owns `queue`.
        loop = asyncio.get_running_loop()

        class _LoopBoundQueue:
            """Write-only view of `queue` that is safe to use from the producer thread.

            asyncio.Queue is not thread-safe: calling put_nowait straight from a
            worker thread may never wake the awaiting consumer. Every put is
            therefore scheduled onto the event loop thread instead.
            """

            @staticmethod
            def put_nowait(item):
                loop.call_soon_threadsafe(queue.put_nowait, item)

        def remove_temp_files(_finished=None):
            # Accepts (and ignores) the future so it can double as a done-callback.
            for path in (temp_ref_path, converted_wav_path):
                if not path:
                    continue
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    # Already gone (or both names point at the same file).
                    pass
            logger.info("Cleaned up temporary stream files.")

        try:
            # Convert the uploaded file to the required WAV format.
            converted_wav_path = await run_blocking_task_async(convert_to_wav_blocking, temp_ref_path)

            # Start the producer (the model) in a background thread. It feeds the
            # queue through the thread-safe adapter; keep the future so cleanup can
            # wait for the thread to finish.
            producer_future = loop.run_in_executor(
                tts_executor,
                app.state.tts_wrapper.stream_producer,
                _LoopBoundQueue(), text, converted_wav_path, reference_text
            )

            # Relay encoded chunks to the client as they arrive.
            async for chunk in stream_consumer(queue, output_format):
                yield chunk

        finally:
            # If the client disconnects early the producer thread may still be using
            # the reference file, so only delete once that thread has finished.
            if producer_future is None or producer_future.done():
                remove_temp_files()
            else:
                producer_future.add_done_callback(remove_temp_files)

    return StreamingResponse(
        cleanup_and_run_stream(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
        headers={
            # The extension follows the requested format (it was always ".mp3" before).
            "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no"  # Header to prevent proxy buffering
        }
    )
|
| 447 |
|
| 448 |
@app.get("/audio/{filename}")
|
| 449 |
async def get_audio(filename: str):
|