Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -36,7 +36,7 @@ DEVICE = "cpu"
|
|
| 36 |
MAX_WORKERS = 2
|
| 37 |
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
| 38 |
SAMPLE_RATE = 24000
|
| 39 |
-
CLEANUP_THRESHOLD =
|
| 40 |
TEMP_AUDIO_DIR = "temp_audio"
|
| 41 |
GENERATED_AUDIO_DIR = "generated_audio"
|
| 42 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
|
@@ -134,7 +134,12 @@ class NeuTTSWrapper:
|
|
| 134 |
audio_buffer.seek(0)
|
| 135 |
return audio_buffer.read()
|
| 136 |
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
|
| 140 |
"""Blocking synthesis for standard endpoint."""
|
|
@@ -147,60 +152,32 @@ class NeuTTSWrapper:
|
|
| 147 |
audio = self.tts_model.infer(text, ref_s, reference_text)
|
| 148 |
return audio
|
| 149 |
|
| 150 |
-
def
|
| 151 |
-
"""
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
sentences = []
|
| 155 |
-
current_sentence = ""
|
| 156 |
-
for char in text:
|
| 157 |
-
current_sentence += char
|
| 158 |
-
if char in '.!?;:':
|
| 159 |
-
sentences.append(current_sentence.strip())
|
| 160 |
-
current_sentence = ""
|
| 161 |
-
if current_sentence.strip():
|
| 162 |
-
sentences.append(current_sentence.strip())
|
| 163 |
-
if not sentences:
|
| 164 |
-
if ',' in text:
|
| 165 |
-
sentences = [chunk.strip() for chunk in text.split(',') if chunk.strip()]
|
| 166 |
-
else:
|
| 167 |
-
chunk_size = 100
|
| 168 |
-
sentences = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
| 169 |
-
return [s for s in sentences if s]
|
| 170 |
-
|
| 171 |
-
# --- NEW: Parallel Worker (Now a method of the class) ---
|
| 172 |
-
def _synthesize_chunk_blocking(self, sentence: str, ref_s: torch.Tensor, ref_text: str) -> np.ndarray:
|
| 173 |
-
"""Worker function to synthesize a single chunk of text. Runs in a thread pool."""
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
# It now correctly calls the model stored in self.tts_model
|
| 176 |
-
audio_chunk = self.tts_model.infer(sentence, ref_s, ref_text)
|
| 177 |
-
return audio_chunk
|
| 178 |
-
|
| 179 |
-
# --- NEW: Parallel Streaming Generator (Now a method of the class) ---
|
| 180 |
-
async def stream_speech_parallel(self, text: str, ref_audio_path: str, ref_text: str, executor: ThreadPoolExecutor):
|
| 181 |
-
"""
|
| 182 |
-
Performs streaming synthesis using a parallel producer-consumer pattern.
|
| 183 |
-
"""
|
| 184 |
-
loop = asyncio.get_event_loop()
|
| 185 |
-
# It now correctly calls the model's encode_reference method
|
| 186 |
-
ref_s = await loop.run_in_executor(
|
| 187 |
-
executor, self.tts_model.encode_reference, ref_audio_path
|
| 188 |
-
)
|
| 189 |
|
| 190 |
-
# It now correctly calls its own text splitting method
|
| 191 |
-
sentences = self._split_into_streaming_chunks(text)
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
# --- Asynchronous Offloading ---
|
| 206 |
|
|
@@ -391,49 +368,54 @@ async def text_to_speech(
|
|
| 391 |
async def stream_text_to_speech_cloning(
|
| 392 |
text: str = Form(..., min_length=1, max_length=5000),
|
| 393 |
reference_text: str = Form(...),
|
| 394 |
-
speed: float = Form(1.0, ge=0.5, le=2.0),
|
| 395 |
output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
|
| 396 |
-
reference_audio: UploadFile = File(...)
|
| 397 |
-
):
|
| 398 |
"""
|
| 399 |
-
|
|
|
|
| 400 |
"""
|
| 401 |
if not hasattr(app.state, 'tts_wrapper'):
|
| 402 |
raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
|
| 403 |
-
|
|
|
|
| 404 |
temp_ref_path = await save_upload_file_async(reference_audio)
|
| 405 |
-
converted_wav_path = None
|
| 406 |
-
|
| 407 |
try:
|
|
|
|
| 408 |
converted_wav_path = await run_blocking_task_async(
|
| 409 |
-
convert_to_wav_blocking,
|
|
|
|
| 410 |
)
|
| 411 |
-
|
|
|
|
| 412 |
if os.path.exists(temp_ref_path):
|
| 413 |
os.unlink(temp_ref_path)
|
| 414 |
-
|
| 415 |
-
|
|
|
|
| 416 |
try:
|
| 417 |
-
# This
|
| 418 |
-
|
| 419 |
-
text
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
|
|
|
| 423 |
):
|
| 424 |
-
|
| 425 |
-
sf.write(audio_buffer, audio_chunk, SAMPLE_RATE, format=output_format)
|
| 426 |
-
audio_buffer.seek(0)
|
| 427 |
-
yield audio_buffer.read()
|
| 428 |
-
|
| 429 |
except Exception as e:
|
|
|
|
| 430 |
logger.error(f"Streaming generator error: {e}")
|
| 431 |
-
raise
|
| 432 |
finally:
|
|
|
|
| 433 |
if os.path.exists(path_to_delete):
|
| 434 |
os.unlink(path_to_delete)
|
| 435 |
logger.info(f"Cleaned up converted file: {path_to_delete}")
|
| 436 |
|
|
|
|
| 437 |
return StreamingResponse(
|
| 438 |
stream_generator(converted_wav_path),
|
| 439 |
media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
|
|
@@ -444,17 +426,20 @@ async def stream_text_to_speech_cloning(
|
|
| 444 |
"X-Accel-Buffering": "no"
|
| 445 |
}
|
| 446 |
)
|
| 447 |
-
|
| 448 |
except Exception as e:
|
| 449 |
logger.error(f"Streaming setup error: {e}")
|
|
|
|
| 450 |
if os.path.exists(temp_ref_path):
|
| 451 |
os.unlink(temp_ref_path)
|
| 452 |
if converted_wav_path and os.path.exists(converted_wav_path):
|
| 453 |
os.unlink(converted_wav_path)
|
| 454 |
-
|
|
|
|
| 455 |
if isinstance(e, HTTPException):
|
| 456 |
raise
|
| 457 |
raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
|
|
|
|
| 458 |
|
| 459 |
@app.get("/audio/{filename}")
|
| 460 |
async def get_audio(filename: str):
|
|
|
|
| 36 |
MAX_WORKERS = 2
|
| 37 |
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
| 38 |
SAMPLE_RATE = 24000
|
| 39 |
+
CLEANUP_THRESHOLD = 300 # 1 hour in seconds
|
| 40 |
TEMP_AUDIO_DIR = "temp_audio"
|
| 41 |
GENERATED_AUDIO_DIR = "generated_audio"
|
| 42 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
|
|
|
| 134 |
audio_buffer.seek(0)
|
| 135 |
return audio_buffer.read()
|
| 136 |
|
| 137 |
+
def _split_text_into_chunks(self, text: str) -> list[str]:
|
| 138 |
+
"""Simple sentence splitting for streaming (can be enhanced with regex)."""
|
| 139 |
+
sentences = [s.strip() for s in text.split('.') if s.strip()]
|
| 140 |
+
if not sentences:
|
| 141 |
+
sentences = [text.strip()]
|
| 142 |
+
return sentences
|
| 143 |
|
| 144 |
def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
|
| 145 |
"""Blocking synthesis for standard endpoint."""
|
|
|
|
| 152 |
audio = self.tts_model.infer(text, ref_s, reference_text)
|
| 153 |
return audio
|
| 154 |
|
| 155 |
+
def stream_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
|
| 156 |
+
"""Sentence-by-Sentence Streaming (Blocking)."""
|
| 157 |
+
logger.info(f"Starting streaming synthesis for text length: {len(text)}")
|
| 158 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
ref_s = self.tts_model.encode_reference(ref_audio_path)
|
| 162 |
+
|
| 163 |
+
# 3. Split text
|
| 164 |
+
sentences = self._split_text_into_chunks(text)
|
| 165 |
+
|
| 166 |
+
# 4. Stream chunks
|
| 167 |
+
for i, sentence in enumerate(sentences):
|
| 168 |
+
if not sentence.strip():
|
| 169 |
+
continue
|
| 170 |
+
|
| 171 |
+
logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
|
| 172 |
+
|
| 173 |
+
# Infer sentence
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
|
| 176 |
+
|
| 177 |
+
# Convert and yield
|
| 178 |
+
yield self._convert_to_streamable_format(audio_chunk, audio_format)
|
| 179 |
+
|
| 180 |
+
logger.info("Streaming synthesis complete.")
|
| 181 |
|
| 182 |
# --- Asynchronous Offloading ---
|
| 183 |
|
|
|
|
| 368 |
async def stream_text_to_speech_cloning(
|
| 369 |
text: str = Form(..., min_length=1, max_length=5000),
|
| 370 |
reference_text: str = Form(...),
|
| 371 |
+
speed: float = Form(1.0, ge=0.5, le=2.0),
|
| 372 |
output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
|
| 373 |
+
reference_audio: UploadFile = File(...)):
|
|
|
|
| 374 |
"""
|
| 375 |
+
Sentence-by-Sentence Streaming Endpoint.
|
| 376 |
+
Fixes race condition by moving cleanup into the streaming generator.
|
| 377 |
"""
|
| 378 |
if not hasattr(app.state, 'tts_wrapper'):
|
| 379 |
raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
|
| 380 |
+
|
| 381 |
+
# 1. Asynchronously save reference audio (non-blocking)
|
| 382 |
temp_ref_path = await save_upload_file_async(reference_audio)
|
| 383 |
+
converted_wav_path = None # Initialize for cleanup
|
| 384 |
+
|
| 385 |
try:
|
| 386 |
+
# 2. Convert the uploaded file (WebM, etc.) to a 24kHz WAV file
|
| 387 |
converted_wav_path = await run_blocking_task_async(
|
| 388 |
+
convert_to_wav_blocking,
|
| 389 |
+
temp_ref_path
|
| 390 |
)
|
| 391 |
+
|
| 392 |
+
# 2.5. CLEANUP ORIGINAL FILE IMMEDIATELY: It is no longer needed after conversion
|
| 393 |
if os.path.exists(temp_ref_path):
|
| 394 |
os.unlink(temp_ref_path)
|
| 395 |
+
|
| 396 |
+
# 3. Define the generator function, which will run in the thread pool
|
| 397 |
+
def stream_generator(path_to_delete: str):
|
| 398 |
try:
|
| 399 |
+
# This logic uses the path_to_delete parameter, which is guaranteed to exist
|
| 400 |
+
for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
|
| 401 |
+
text,
|
| 402 |
+
path_to_delete, # Pass the CONVERTED WAV path
|
| 403 |
+
reference_text,
|
| 404 |
+
speed,
|
| 405 |
+
output_format
|
| 406 |
):
|
| 407 |
+
yield chunk_bytes
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
except Exception as e:
|
| 409 |
+
# Log the error and raise it to stop the stream
|
| 410 |
logger.error(f"Streaming generator error: {e}")
|
| 411 |
+
raise # Re-raise to ensure the stream terminates
|
| 412 |
finally:
|
| 413 |
+
# 4. **CRUCIAL FIX:** Clean up the converted file ONLY AFTER GENERATION IS DONE
|
| 414 |
if os.path.exists(path_to_delete):
|
| 415 |
os.unlink(path_to_delete)
|
| 416 |
logger.info(f"Cleaned up converted file: {path_to_delete}")
|
| 417 |
|
| 418 |
+
# Return StreamingResponse, passing the path to the generator
|
| 419 |
return StreamingResponse(
|
| 420 |
stream_generator(converted_wav_path),
|
| 421 |
media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
|
|
|
|
| 426 |
"X-Accel-Buffering": "no"
|
| 427 |
}
|
| 428 |
)
|
| 429 |
+
|
| 430 |
except Exception as e:
|
| 431 |
logger.error(f"Streaming setup error: {e}")
|
| 432 |
+
# Clean up files only if the setup failed *before* starting the generator
|
| 433 |
if os.path.exists(temp_ref_path):
|
| 434 |
os.unlink(temp_ref_path)
|
| 435 |
if converted_wav_path and os.path.exists(converted_wav_path):
|
| 436 |
os.unlink(converted_wav_path)
|
| 437 |
+
|
| 438 |
+
# Reraise HTTPExceptions that may have come from the conversion step
|
| 439 |
if isinstance(e, HTTPException):
|
| 440 |
raise
|
| 441 |
raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
|
| 442 |
+
# Note: The outer 'finally' block is now removed as its logic is handled in 2.5 and 4.
|
| 443 |
|
| 444 |
@app.get("/audio/{filename}")
|
| 445 |
async def get_audio(filename: str):
|