Rajhuggingface4253 commited on
Commit
24bb5f8
·
verified ·
1 Parent(s): b9c3cb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -108
app.py CHANGED
@@ -2,19 +2,16 @@ import os
2
  import io
3
  import asyncio
4
  import time
5
- import shutil
6
  import numpy as np
7
  import psutil
8
  import soundfile as sf
9
  import subprocess
10
- import tempfile
11
  from concurrent.futures import ThreadPoolExecutor
12
- from typing import Optional, Generator
13
  from contextlib import asynccontextmanager
14
  import logging
15
- import aiofiles
16
  import torch
17
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from pydantic import BaseModel, Field
@@ -38,16 +35,10 @@ DEVICE = "cpu"
38
  MAX_WORKERS = 2
39
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
40
  SAMPLE_RATE = 24000
41
- CLEANUP_THRESHOLD = 300 # 1 hour in seconds
42
- TEMP_AUDIO_DIR = "temp_audio"
43
- GENERATED_AUDIO_DIR = "generated_audio"
44
- os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
45
- os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
46
 
47
  class TTSRequestModel(BaseModel):
48
  """Model for non-file inputs to synthesis and streaming."""
49
  text: str = Field(..., min_length=1, max_length=1000)
50
- speed: float = Field(default=1.0, ge=0.5, le=2.0)
51
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
52
 
53
 
@@ -151,32 +142,6 @@ class NeuTTSWrapper:
151
  audio = self.tts_model.infer(text, ref_s, reference_text)
152
  return audio
153
 
154
- def stream_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
155
- """Sentence-by-Sentence Streaming using cached reference encoding."""
156
- logger.info(f"Starting streaming synthesis for text length: {len(text)}")
157
-
158
- # 1. Hash the audio bytes once
159
- audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
160
-
161
- # 2. Get the reference encoding from cache, once for the whole stream
162
- ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
163
-
164
- # 3. Split text using the new regex method
165
- sentences = self._split_text_into_chunks(text)
166
-
167
- # 4. Stream chunks
168
- for i, sentence in enumerate(sentences):
169
- if not sentence.strip():
170
- continue
171
-
172
- logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
173
-
174
- with torch.no_grad():
175
- audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
176
-
177
- yield self._convert_to_streamable_format(audio_chunk, audio_format)
178
-
179
- logger.info("Streaming synthesis complete.")
180
 
181
  # --- Asynchronous Offloading ---
182
 
@@ -188,18 +153,6 @@ async def run_blocking_task_async(func, *args, **kwargs):
188
  lambda: func(*args, **kwargs)
189
  )
190
 
191
- async def save_upload_file_async(upload_file: UploadFile) -> str:
192
- """Asynchronously saves the UploadFile to disk."""
193
- temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{upload_file.filename}")
194
- try:
195
- # Use asyncio to read the file chunks in a non-blocking manner
196
- async with aiofiles.open(temp_filename, 'wb') as out_file:
197
- while content := await upload_file.read(1024 * 1024):
198
- await out_file.write(content)
199
- return temp_filename
200
- except Exception as e:
201
- logger.error(f"Error saving file: {e}")
202
- raise HTTPException(status_code=500, detail="Could not save reference audio file")
203
 
204
  # --- FastAPI Lifespan Manager (Kokoro Feature) ---
205
 
@@ -262,31 +215,6 @@ async def health_check():
262
  }
263
  }
264
 
265
- @app.delete("/cleanup")
266
- async def cleanup_files():
267
- """Maintenance endpoint to remove old generated and temporary files."""
268
- await run_blocking_task_async(cleanup_files_blocking)
269
- return {"message": "Cleanup initiated successfully."}
270
-
271
- def cleanup_files_blocking():
272
- """Blocking file cleanup logic (original NeuTTS feature)."""
273
- now = time.time()
274
- deleted_count = 0
275
-
276
- for directory in [GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR]:
277
- for filename in os.listdir(directory):
278
- filepath = os.path.join(directory, filename)
279
- if os.path.isfile(filepath):
280
- try:
281
- # Original cleanup logic: delete if older than CLEANUP_THRESHOLD
282
- if now - os.path.getctime(filepath) > CLEANUP_THRESHOLD:
283
- os.remove(filepath)
284
- deleted_count += 1
285
- except Exception as e:
286
- logger.warning(f"Failed to delete {filepath}: {e}")
287
-
288
- logger.info(f"Cleanup completed: {deleted_count} files removed.")
289
- return deleted_count
290
 
291
 
292
  # --- Core Synthesis Endpoints ---
@@ -295,7 +223,6 @@ def cleanup_files_blocking():
295
  async def text_to_speech(
296
  text: str = Form(...),
297
  reference_text: str = Form(...),
298
- speed: float = Form(1.0, ge=0.5, le=2.0),
299
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
300
  reference_audio: UploadFile = File(...)):
301
  """
@@ -346,31 +273,30 @@ async def text_to_speech(
346
  async def stream_text_to_speech_cloning(
347
  text: str = Form(..., min_length=1, max_length=5000),
348
  reference_text: str = Form(...),
349
- speed: float = Form(1.0, ge=0.5, le=2.0),
350
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
351
  reference_audio: UploadFile = File(...)):
352
  """
353
  Sentence-by-Sentence Streaming using a high-performance, asyncio-native
354
- producer-consumer pipeline.
355
  """
356
  if not hasattr(app.state, 'tts_wrapper'):
357
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
358
 
359
  async def stream_generator():
360
  loop = asyncio.get_event_loop()
361
- q = asyncio.Queue(maxsize=2)
362
 
363
  async def producer():
364
  try:
365
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
366
  ref_audio_bytes = converted_wav_buffer.getvalue()
367
- audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
368
 
369
- # Use LRU cache like blocking endpoint
 
370
  ref_s = await loop.run_in_executor(
371
  tts_executor,
372
  app.state.tts_wrapper._get_or_create_reference_encoding,
373
- audio_hash,
374
  ref_audio_bytes
375
  )
376
 
@@ -381,32 +307,37 @@ async def stream_text_to_speech_cloning(
381
  audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
382
  return app.state.tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)
383
 
384
- # Schedule all chunks to be processed in the background.
385
  for sentence in sentences:
386
  task = loop.run_in_executor(tts_executor, process_chunk, sentence)
387
- await q.put(task) # Put the FUTURE, not the result, in the queue.
388
 
389
  except Exception as e:
390
  logger.error(f"Error in producer task: {e}")
391
  await q.put(e)
392
  finally:
393
- await q.put(None) # Signal that all tasks have been scheduled.
394
 
395
  producer_task = asyncio.create_task(producer())
396
 
397
- # The CONSUMER's job is to wait for each result and yield it.
398
- while True:
399
- result = await q.get()
400
- if result is None:
401
- break
402
-
403
- if isinstance(result, Exception):
404
- logger.error(f"Terminating stream due to producer error: {result}")
405
- raise result
 
 
 
406
 
407
- # Await the result of the background task
408
- chunk_bytes = await result
409
  yield chunk_bytes
 
 
 
410
 
411
  await producer_task
412
 
@@ -414,16 +345,3 @@ async def stream_text_to_speech_cloning(
414
  stream_generator(),
415
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
416
  )
417
-
418
- @app.get("/audio/{filename}")
419
- async def get_audio(filename: str):
420
- """Original NeuTTS feature to serve generated audio files."""
421
- file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
422
- if not os.path.exists(file_path):
423
- raise HTTPException(status_code=404, detail="Audio file not found")
424
-
425
- return Response(
426
- content=open(file_path, "rb").read(),
427
- media_type=f"audio/{filename.split('.')[-1]}", # Simple media type detection
428
- headers={"Content-Disposition": f"attachment; filename={filename}"}
429
- )
 
2
  import io
3
  import asyncio
4
  import time
 
5
  import numpy as np
6
  import psutil
7
  import soundfile as sf
8
  import subprocess
 
9
  from concurrent.futures import ThreadPoolExecutor
10
+ from typing import Generator
11
  from contextlib import asynccontextmanager
12
  import logging
 
13
  import torch
14
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
15
  from fastapi.responses import Response, StreamingResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel, Field
 
35
  MAX_WORKERS = 2
36
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
37
  SAMPLE_RATE = 24000
 
 
 
 
 
38
 
39
  class TTSRequestModel(BaseModel):
40
  """Model for non-file inputs to synthesis and streaming."""
41
  text: str = Field(..., min_length=1, max_length=1000)
 
42
  output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
43
 
44
 
 
142
  audio = self.tts_model.infer(text, ref_s, reference_text)
143
  return audio
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # --- Asynchronous Offloading ---
147
 
 
153
  lambda: func(*args, **kwargs)
154
  )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # --- FastAPI Lifespan Manager (Kokoro Feature) ---
158
 
 
215
  }
216
  }
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  # --- Core Synthesis Endpoints ---
 
223
  async def text_to_speech(
224
  text: str = Form(...),
225
  reference_text: str = Form(...),
 
226
  output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
227
  reference_audio: UploadFile = File(...)):
228
  """
 
273
  async def stream_text_to_speech_cloning(
274
  text: str = Form(..., min_length=1, max_length=5000),
275
  reference_text: str = Form(...),
 
276
  output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
277
  reference_audio: UploadFile = File(...)):
278
  """
279
  Sentence-by-Sentence Streaming using a high-performance, asyncio-native
280
+ look-ahead pipeline. This ensures true overlap of CPU work and network I/O.
281
  """
282
  if not hasattr(app.state, 'tts_wrapper'):
283
  raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
284
 
285
  async def stream_generator():
286
  loop = asyncio.get_event_loop()
287
+ q = asyncio.Queue(maxsize=MAX_WORKERS + 1) # Queue size based on workers
288
 
289
  async def producer():
290
  try:
291
  converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
292
  ref_audio_bytes = converted_wav_buffer.getvalue()
 
293
 
294
+ # Perform the one-time voice encoding
295
+ audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
296
  ref_s = await loop.run_in_executor(
297
  tts_executor,
298
  app.state.tts_wrapper._get_or_create_reference_encoding,
299
+ audio_hash,
300
  ref_audio_bytes
301
  )
302
 
 
307
  audio_chunk = app.state.tts_wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
308
  return app.state.tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)
309
 
310
+ # Schedule all chunks for background processing
311
  for sentence in sentences:
312
  task = loop.run_in_executor(tts_executor, process_chunk, sentence)
313
+ await q.put(task)
314
 
315
  except Exception as e:
316
  logger.error(f"Error in producer task: {e}")
317
  await q.put(e)
318
  finally:
319
+ await q.put(None)
320
 
321
  producer_task = asyncio.create_task(producer())
322
 
323
+ # --- High-Performance Consumer with Look-Ahead ---
324
+ # Get the first task from the queue to start the process.
325
+ current_task = await q.get()
326
+
327
+ while current_task is not None:
328
+ # Simultaneously, get the NEXT task from the queue.
329
+ # This allows the next chunk to start processing while we wait for the current one.
330
+ next_task = await q.get()
331
+
332
+ # Now, wait for the CURRENT task to finish.
333
+ if isinstance(current_task, Exception):
334
+ raise current_task
335
 
336
+ chunk_bytes = await current_task
 
337
  yield chunk_bytes
338
+
339
+ # The next task becomes the current task for the next iteration.
340
+ current_task = next_task
341
 
342
  await producer_task
343
 
 
345
  stream_generator(),
346
  media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
347
  )