Rajhuggingface4253 commited on
Commit
01b3f2d
·
verified ·
1 Parent(s): c0df123

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -79
app.py CHANGED
@@ -36,7 +36,7 @@ DEVICE = "cpu"
36
  MAX_WORKERS = 2
37
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
38
  SAMPLE_RATE = 24000
39
- CLEANUP_THRESHOLD = 3600 # 1 hour in seconds
40
  TEMP_AUDIO_DIR = "temp_audio"
41
  GENERATED_AUDIO_DIR = "generated_audio"
42
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
@@ -134,7 +134,12 @@ class NeuTTSWrapper:
134
  audio_buffer.seek(0)
135
  return audio_buffer.read()
136
 
137
-
 
 
 
 
 
138
 
139
  def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
140
  """Blocking synthesis for standard endpoint."""
@@ -147,60 +152,32 @@ class NeuTTSWrapper:
147
  audio = self.tts_model.infer(text, ref_s, reference_text)
148
  return audio
149
 
150
- def _split_into_streaming_chunks(self, text: str) -> list[str]:
151
- """
152
- Splits text into smaller, more manageable chunks for streaming.
153
- """
154
- sentences = []
155
- current_sentence = ""
156
- for char in text:
157
- current_sentence += char
158
- if char in '.!?;:':
159
- sentences.append(current_sentence.strip())
160
- current_sentence = ""
161
- if current_sentence.strip():
162
- sentences.append(current_sentence.strip())
163
- if not sentences:
164
- if ',' in text:
165
- sentences = [chunk.strip() for chunk in text.split(',') if chunk.strip()]
166
- else:
167
- chunk_size = 100
168
- sentences = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
169
- return [s for s in sentences if s]
170
-
171
- # --- NEW: Parallel Worker (Now a method of the class) ---
172
- def _synthesize_chunk_blocking(self, sentence: str, ref_s: torch.Tensor, ref_text: str) -> np.ndarray:
173
- """Worker function to synthesize a single chunk of text. Runs in a thread pool."""
174
- with torch.no_grad():
175
- # It now correctly calls the model stored in self.tts_model
176
- audio_chunk = self.tts_model.infer(sentence, ref_s, ref_text)
177
- return audio_chunk
178
-
179
- # --- NEW: Parallel Streaming Generator (Now a method of the class) ---
180
- async def stream_speech_parallel(self, text: str, ref_audio_path: str, ref_text: str, executor: ThreadPoolExecutor):
181
- """
182
- Performs streaming synthesis using a parallel producer-consumer pattern.
183
- """
184
- loop = asyncio.get_event_loop()
185
- # It now correctly calls the model's encode_reference method
186
- ref_s = await loop.run_in_executor(
187
- executor, self.tts_model.encode_reference, ref_audio_path
188
- )
189
 
190
- # It now correctly calls its own text splitting method
191
- sentences = self._split_into_streaming_chunks(text)
192
 
193
- tasks = [
194
- loop.run_in_executor(
195
- # It now correctly calls its own worker method
196
- executor, self._synthesize_chunk_blocking, sentence, ref_s, ref_text
197
- )
198
- for sentence in sentences
199
- ]
200
-
201
- for task in tasks:
202
- audio_chunk = await task
203
- yield audio_chunk
 
 
 
 
 
 
 
 
 
204
 
205
  # --- Asynchronous Offloading ---
206
 
@@ -391,49 +368,54 @@ async def text_to_speech(
391
  async def stream_text_to_speech_cloning(
392
  text: str = Form(..., min_length=1, max_length=5000),
393
  reference_text: str = Form(...),
394
- speed: float = Form(1.0, ge=0.5, le=2.0), # Kept for API compatibility, not used in this logic
395
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
396
- reference_audio: UploadFile = File(...)
397
- ):
398
  """
399
- High-performance parallel streaming endpoint using the local wrapper.
 
400
  """
401
  if not hasattr(app.state, 'tts_wrapper'):
402
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
403
-
 
404
  temp_ref_path = await save_upload_file_async(reference_audio)
405
- converted_wav_path = None
406
-
407
  try:
 
408
  converted_wav_path = await run_blocking_task_async(
409
- convert_to_wav_blocking, temp_ref_path
 
410
  )
411
-
 
412
  if os.path.exists(temp_ref_path):
413
  os.unlink(temp_ref_path)
414
-
415
- async def stream_generator(path_to_delete: str):
 
416
  try:
417
- # This now calls our new wrapper's parallel streaming method
418
- async for audio_chunk in app.state.tts_wrapper.stream_speech_parallel(
419
- text=text,
420
- ref_audio_path=path_to_delete,
421
- ref_text=reference_text,
422
- executor=tts_executor
 
423
  ):
424
- audio_buffer = io.BytesIO()
425
- sf.write(audio_buffer, audio_chunk, SAMPLE_RATE, format=output_format)
426
- audio_buffer.seek(0)
427
- yield audio_buffer.read()
428
-
429
  except Exception as e:
 
430
  logger.error(f"Streaming generator error: {e}")
431
- raise
432
  finally:
 
433
  if os.path.exists(path_to_delete):
434
  os.unlink(path_to_delete)
435
  logger.info(f"Cleaned up converted file: {path_to_delete}")
436
 
 
437
  return StreamingResponse(
438
  stream_generator(converted_wav_path),
439
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
@@ -444,17 +426,20 @@ async def stream_text_to_speech_cloning(
444
  "X-Accel-Buffering": "no"
445
  }
446
  )
447
-
448
  except Exception as e:
449
  logger.error(f"Streaming setup error: {e}")
 
450
  if os.path.exists(temp_ref_path):
451
  os.unlink(temp_ref_path)
452
  if converted_wav_path and os.path.exists(converted_wav_path):
453
  os.unlink(converted_wav_path)
454
-
 
455
  if isinstance(e, HTTPException):
456
  raise
457
  raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
 
458
 
459
  @app.get("/audio/{filename}")
460
  async def get_audio(filename: str):
 
36
  MAX_WORKERS = 2
37
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
38
  SAMPLE_RATE = 24000
39
+ CLEANUP_THRESHOLD = 300 # 1 hour in seconds
40
  TEMP_AUDIO_DIR = "temp_audio"
41
  GENERATED_AUDIO_DIR = "generated_audio"
42
  os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
 
134
  audio_buffer.seek(0)
135
  return audio_buffer.read()
136
 
137
+ def _split_text_into_chunks(self, text: str) -> list[str]:
138
+ """Simple sentence splitting for streaming (can be enhanced with regex)."""
139
+ sentences = [s.strip() for s in text.split('.') if s.strip()]
140
+ if not sentences:
141
+ sentences = [text.strip()]
142
+ return sentences
143
 
144
  def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
145
  """Blocking synthesis for standard endpoint."""
 
152
  audio = self.tts_model.infer(text, ref_s, reference_text)
153
  return audio
154
 
155
+ def stream_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
156
+ """Sentence-by-Sentence Streaming (Blocking)."""
157
+ logger.info(f"Starting streaming synthesis for text length: {len(text)}")
158
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
 
 
160
 
161
+ ref_s = self.tts_model.encode_reference(ref_audio_path)
162
+
163
+ # 3. Split text
164
+ sentences = self._split_text_into_chunks(text)
165
+
166
+ # 4. Stream chunks
167
+ for i, sentence in enumerate(sentences):
168
+ if not sentence.strip():
169
+ continue
170
+
171
+ logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
172
+
173
+ # Infer sentence
174
+ with torch.no_grad():
175
+ audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
176
+
177
+ # Convert and yield
178
+ yield self._convert_to_streamable_format(audio_chunk, audio_format)
179
+
180
+ logger.info("Streaming synthesis complete.")
181
 
182
  # --- Asynchronous Offloading ---
183
 
 
368
  async def stream_text_to_speech_cloning(
369
  text: str = Form(..., min_length=1, max_length=5000),
370
  reference_text: str = Form(...),
371
+ speed: float = Form(1.0, ge=0.5, le=2.0),
372
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
373
+ reference_audio: UploadFile = File(...)):
 
374
  """
375
+ Sentence-by-Sentence Streaming Endpoint.
376
+ Fixes race condition by moving cleanup into the streaming generator.
377
  """
378
  if not hasattr(app.state, 'tts_wrapper'):
379
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
380
+
381
+ # 1. Asynchronously save reference audio (non-blocking)
382
  temp_ref_path = await save_upload_file_async(reference_audio)
383
+ converted_wav_path = None # Initialize for cleanup
384
+
385
  try:
386
+ # 2. Convert the uploaded file (WebM, etc.) to a 24kHz WAV file
387
  converted_wav_path = await run_blocking_task_async(
388
+ convert_to_wav_blocking,
389
+ temp_ref_path
390
  )
391
+
392
+ # 2.5. CLEANUP ORIGINAL FILE IMMEDIATELY: It is no longer needed after conversion
393
  if os.path.exists(temp_ref_path):
394
  os.unlink(temp_ref_path)
395
+
396
+ # 3. Define the generator function, which will run in the thread pool
397
+ def stream_generator(path_to_delete: str):
398
  try:
399
+ # This logic uses the path_to_delete parameter, which is guaranteed to exist
400
+ for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
401
+ text,
402
+ path_to_delete, # Pass the CONVERTED WAV path
403
+ reference_text,
404
+ speed,
405
+ output_format
406
  ):
407
+ yield chunk_bytes
 
 
 
 
408
  except Exception as e:
409
+ # Log the error and raise it to stop the stream
410
  logger.error(f"Streaming generator error: {e}")
411
+ raise # Re-raise to ensure the stream terminates
412
  finally:
413
+ # 4. **CRUCIAL FIX:** Clean up the converted file ONLY AFTER GENERATION IS DONE
414
  if os.path.exists(path_to_delete):
415
  os.unlink(path_to_delete)
416
  logger.info(f"Cleaned up converted file: {path_to_delete}")
417
 
418
+ # Return StreamingResponse, passing the path to the generator
419
  return StreamingResponse(
420
  stream_generator(converted_wav_path),
421
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
 
426
  "X-Accel-Buffering": "no"
427
  }
428
  )
429
+
430
  except Exception as e:
431
  logger.error(f"Streaming setup error: {e}")
432
+ # Clean up files only if the setup failed *before* starting the generator
433
  if os.path.exists(temp_ref_path):
434
  os.unlink(temp_ref_path)
435
  if converted_wav_path and os.path.exists(converted_wav_path):
436
  os.unlink(converted_wav_path)
437
+
438
+ # Reraise HTTPExceptions that may have come from the conversion step
439
  if isinstance(e, HTTPException):
440
  raise
441
  raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
442
+ # Note: The outer 'finally' block is now removed as its logic is handled in 2.5 and 4.
443
 
444
  @app.get("/audio/{filename}")
445
  async def get_audio(filename: str):