Rajhuggingface4253 committed on
Commit
af37689
·
verified ·
1 Parent(s): e94e39e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -31
app.py CHANGED
@@ -355,68 +355,67 @@ async def stream_text_to_speech_cloning(
355
  """
356
  if not hasattr(app.state, 'tts_wrapper'):
357
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
358
-
359
- # This async generator is the final, correct implementation.
360
  async def stream_generator():
361
  loop = asyncio.get_event_loop()
362
  q = asyncio.Queue(maxsize=2)
363
 
364
- # The PRODUCER is now an async task that runs in the background.
365
  async def producer():
366
  try:
367
- # The one-time setup cost: convert and encode the reference voice.
368
- # This is done before the loop to ensure the voice is ready.
369
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
370
  ref_audio_bytes = converted_wav_buffer.getvalue()
371
  audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
372
- ref_s = await loop.run_in_executor(
373
- tts_executor,
374
- app.state.tts_wrapper._get_or_create_reference_encoding,
375
- audio_hash,
376
- ref_audio_bytes
377
- )
 
 
 
 
 
 
 
378
 
379
  sentences = app.state.tts_wrapper._split_text_into_chunks(text)
380
 
381
- for sentence in sentences:
382
- # Define the blocking work for a single chunk
383
- def process_chunk():
384
- with torch.no_grad():
385
- audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence, ref_s, reference_text)
386
- return app.state.tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)
387
 
388
- # Offload the blocking work to the thread pool
389
- mp3_bytes = await loop.run_in_executor(tts_executor, process_chunk)
390
- # Put the finished MP3 chunk into the async queue
391
- await q.put(mp3_bytes)
392
 
393
  except Exception as e:
394
  logger.error(f"Error in producer task: {e}")
395
  await q.put(e)
396
  finally:
397
- # Signal that production is finished
398
- await q.put(None)
399
 
400
- # Start the producer as a background task. It starts working immediately.
401
  producer_task = asyncio.create_task(producer())
402
 
403
- # The main loop now acts as the CONSUMER.
404
  while True:
405
- # Await the next finished MP3 chunk from the queue.
406
  result = await q.get()
407
-
408
  if result is None:
409
  break
410
 
 
411
  if isinstance(result, Exception):
412
  logger.error(f"Terminating stream due to producer error: {result}")
413
  raise result
414
 
415
- # Yield the chunk to the user. While the network sends this,
416
- # the producer is already working on the next chunk in the background.
417
- yield result
418
 
419
- # Ensure the producer task is cleaned up.
420
  await producer_task
421
 
422
  return StreamingResponse(
 
355
  """
356
  if not hasattr(app.state, 'tts_wrapper'):
357
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
358
+
 
359
  async def stream_generator():
360
  loop = asyncio.get_event_loop()
361
  q = asyncio.Queue(maxsize=2)
362
 
363
+ # The PRODUCER's job is to quickly schedule work, not wait for it.
364
  async def producer():
365
  try:
 
 
366
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
367
  ref_audio_bytes = converted_wav_buffer.getvalue()
368
  audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
369
+
370
+ # Check cache for reference encoding
371
+ if audio_hash in app.state.tts_wrapper.encoding_cache:
372
+ logger.info(f"Streaming Cache HIT for hash: {audio_hash[:10]}...")
373
+ ref_s = app.state.tts_wrapper.encoding_cache[audio_hash]
374
+ else:
375
+ logger.info(f"Streaming Cache MISS for hash: {audio_hash[:10]}...")
376
+ ref_s = await loop.run_in_executor(
377
+ tts_executor,
378
+ app.state.tts_wrapper.get_reference_encoding,
379
+ ref_audio_bytes
380
+ )
381
+ app.state.tts_wrapper.encoding_cache[audio_hash] = ref_s
382
 
383
  sentences = app.state.tts_wrapper._split_text_into_chunks(text)
384
 
385
+ # This function does the heavy lifting for one chunk.
386
+ def process_chunk(sentence_text):
387
+ with torch.no_grad():
388
+ audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
389
+ return app.state.tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)
 
390
 
391
+ # Schedule all chunks to be processed in the background.
392
+ for sentence in sentences:
393
+ task = loop.run_in_executor(tts_executor, process_chunk, sentence)
394
+ await q.put(task) # Put the FUTURE, not the result, in the queue.
395
 
396
  except Exception as e:
397
  logger.error(f"Error in producer task: {e}")
398
  await q.put(e)
399
  finally:
400
+ await q.put(None) # Signal that all tasks have been scheduled.
 
401
 
 
402
  producer_task = asyncio.create_task(producer())
403
 
404
+ # The CONSUMER's job is to wait for each result and yield it.
405
  while True:
 
406
  result = await q.get()
 
407
  if result is None:
408
  break
409
 
410
+ # Check if the item in the queue is a task (future) or an exception
411
  if isinstance(result, Exception):
412
  logger.error(f"Terminating stream due to producer error: {result}")
413
  raise result
414
 
415
+ # Await the result of the background task
416
+ chunk_bytes = await result
417
+ yield chunk_bytes
418
 
 
419
  await producer_task
420
 
421
  return StreamingResponse(