auralodyssey committed on
Commit
83977c6
·
verified ·
1 Parent(s): 7576e85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -37
app.py CHANGED
@@ -317,13 +317,19 @@ except Exception:
317
 
318
  print("🚀 BOOTING KOKORO (OFFICIAL PIPELINE)")
319
 
 
 
 
 
 
 
 
320
  # ------------------------------------------------------------
321
- # OFFICIAL PIPELINES (per docs you pasted)
322
- # 🇺🇸 'a' => American English, 🇬🇧 'b' => British English
323
  # ------------------------------------------------------------
324
  PIPELINES = {
325
- "a": KPipeline(lang_code="a"),
326
- "b": KPipeline(lang_code="b"),
327
  }
328
 
329
  VOICE_CHOICES = {
@@ -340,57 +346,84 @@ VOICE_CHOICES = {
340
  }
341
 
342
  def voice_to_lang_code(voice_code: str) -> str:
343
- # bf_ / bm_ are British
344
  if voice_code.startswith("bf_") or voice_code.startswith("bm_"):
345
  return "b"
346
  return "a"
347
 
348
  # ------------------------------------------------------------
349
- # TEXT HELPERS (sticking to your pasted docs format)
350
- # Use IPA markup like: [Kokoro](/kˈOkəɹO/)
351
  # ------------------------------------------------------------
352
  def normalize_text(text: str) -> str:
353
  if not text:
354
  return text
355
- # Your docs show this exact IPA form for Kokoro
356
- text = text.replace("Kokoro", "[Kokoro](/kˈOkəɹO/)")
357
- return text
358
 
359
  # ------------------------------------------------------------
360
- # CHUNKING
361
- # Main goal: avoid tiny chunks that cause audible discontinuity.
 
362
  # ------------------------------------------------------------
363
- _SENT_SPLIT = re.compile(r"(?<=[.!?])\s+|\n+")
364
 
365
  def tuned_splitter(text: str):
366
  text = (text or "").strip()
367
  if not text:
368
  return
369
- parts = [p.strip() for p in _SENT_SPLIT.split(text) if p and p.strip()]
370
 
371
- buf = ""
372
- for p in parts:
373
- if not buf:
374
- buf = p
375
- continue
376
-
377
- # Grow chunks to reduce boundary artifacts
378
- if len(buf) < 220:
379
- buf = f"{buf} {p}"
380
- continue
381
-
382
- yield buf
383
- buf = p
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- if buf:
386
- yield buf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  # ------------------------------------------------------------
389
  # AUDIO CONVERSION FIX
390
  # Fixes: "'Tensor' object has no attribute 'astype'"
391
  # ------------------------------------------------------------
392
  def audio_to_int16_np(audio):
393
- # audio can be torch.Tensor or np.ndarray
394
  if isinstance(audio, torch.Tensor):
395
  audio = audio.detach().cpu()
396
  audio = torch.clamp(audio, -1.0, 1.0)
@@ -405,25 +438,34 @@ def audio_to_pcm_bytes(audio) -> bytes:
405
  return audio_to_int16_np(audio).tobytes()
406
 
407
  # ------------------------------------------------------------
408
- # OFFICIAL GENERATION (per your docs)
409
  # generator = pipeline(text, voice='af_heart', speed=1, split_pattern=r'\n+')
410
  # ------------------------------------------------------------
411
  def kokoro_generate(chunk: str, voice_code: str, speed: float):
412
  lang_code = voice_to_lang_code(voice_code)
413
  pipeline = PIPELINES[lang_code]
414
 
415
- # Keep split_pattern exactly in the spirit of your docs
416
- # Our own splitter already splits on sentence/newlines, so this stays light.
417
  generator = pipeline(
418
  chunk,
419
  voice=voice_code,
420
  speed=float(speed),
421
  split_pattern=r"\n+",
422
  )
423
-
424
  for _, _, audio in generator:
425
  yield audio
426
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  # ------------------------------------------------------------
428
  # GRADIO STREAM
429
  # ------------------------------------------------------------
@@ -442,7 +484,6 @@ def gradio_stream_generator(text, voice_name, speed):
442
 
443
  # ------------------------------------------------------------
444
  # FASTAPI + WEBSOCKET QUEUE
445
- # Keep it single-file on CPU to stay stable under load.
446
  # ------------------------------------------------------------
447
  api = FastAPI()
448
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
@@ -454,7 +495,6 @@ async def audio_engine_loop():
454
 
455
  while True:
456
  ws, voice_code, speed, chunk = await INFERENCE_QUEUE.get()
457
-
458
  try:
459
  if ws.client_state.value > 1:
460
  continue
@@ -478,6 +518,9 @@ async def audio_engine_loop():
478
 
479
  @api.on_event("startup")
480
  async def startup():
 
 
 
481
  asyncio.create_task(audio_engine_loop())
482
 
483
  @api.websocket("/ws/audio")
@@ -517,6 +560,7 @@ async def websocket_endpoint(ws: WebSocket):
517
 
518
  if "text" in data:
519
  text = normalize_text(data["text"])
 
520
  for chunk in tuned_splitter(text):
521
  if chunk.strip():
522
  await INFERENCE_QUEUE.put((ws, voice_code, speed, chunk))
@@ -533,7 +577,7 @@ async def websocket_endpoint(ws: WebSocket):
533
  # GRADIO UI
534
  # ------------------------------------------------------------
535
  with gr.Blocks(title="Kokoro TTS") as app:
536
- gr.Markdown("## ⚡ Kokoro-82M (Official Pipeline)")
537
  with gr.Row():
538
  with gr.Column():
539
  text_in = gr.Textbox(
 
317
 
318
  print("🚀 BOOTING KOKORO (OFFICIAL PIPELINE)")
319
 
320
+ # Keep CPU threads predictable
321
+ try:
322
+ torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "2")))
323
+ torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "1")))
324
+ except Exception:
325
+ pass
326
+
327
  # ------------------------------------------------------------
328
+ # OFFICIAL PIPELINES (per your pasted docs)
 
329
  # ------------------------------------------------------------
330
  PIPELINES = {
331
+ "a": KPipeline(lang_code="a"), # 🇺🇸 American English
332
+ "b": KPipeline(lang_code="b"), # 🇬🇧 British English
333
  }
334
 
335
  VOICE_CHOICES = {
 
346
  }
347
 
348
def voice_to_lang_code(voice_code: str) -> str:
    """Map a voice id to its pipeline lang code.

    bf_/bm_ prefixed voices are the British set ('b'); everything else
    falls back to American English ('a').
    """
    return "b" if voice_code.startswith(("bf_", "bm_")) else "a"
352
 
353
  # ------------------------------------------------------------
354
+ # TEXT NORMALIZATION (stays within the docs you pasted)
355
+ # Docs show: [Kokoro](/kˈOkəɹO/)
356
  # ------------------------------------------------------------
357
def normalize_text(text: str) -> str:
    """Apply the docs' IPA markup for the word Kokoro: [Kokoro](/kˈOkəɹO/).

    Idempotent: text that already carries the markup is returned unchanged,
    so a second normalization pass cannot nest markup inside markup
    (the replacement string itself contains "Kokoro").
    """
    if not text:
        # Falsy input (empty string / None) passes through untouched.
        return text
    # Guard against re-normalizing already-marked-up text.
    if "[Kokoro](" in text:
        return text
    return text.replace("Kokoro", "[Kokoro](/kˈOkəɹO/)")
 
 
361
 
362
  # ------------------------------------------------------------
363
+ # FAST-FIRST-AUDIO SPLITTER (your old technique)
364
+ # Progressive thresholds so first chunk is quick.
365
+ # Also includes a fallback to cut long text even without punctuation.
366
  # ------------------------------------------------------------
367
_PUNCT_END = re.compile(r"[.,!?;:\n]$")

# Progressive size targets: a tiny first chunk for fast first audio,
# progressively larger chunks afterwards to reduce boundary artifacts.
_CHUNK_THRESHOLDS = (60, 120, 180)
_LATER_THRESHOLD = 260


def tuned_splitter(text: str):
    """Yield speakable chunks of *text* tuned for fast first audio.

    Chunks are cut at punctuation boundaries once a progressive size
    threshold is met; a hard-cut-at-last-space fallback handles long
    runs with no punctuation at all.
    """
    text = (text or "").strip()
    if not text:
        return

    pending = ""
    emitted = 0

    for piece in re.split(r"([.,!?;:\n]+)", text):
        pending += piece

        if emitted < len(_CHUNK_THRESHOLDS):
            min_len = _CHUNK_THRESHOLDS[emitted]
        else:
            min_len = _LATER_THRESHOLD

        # Emit on a punctuation boundary once the buffer is big enough.
        if len(pending) >= min_len and _PUNCT_END.search(pending):
            segment = pending.strip()
            pending = ""
            if segment:
                yield segment
                emitted += 1
            continue

        # Fallback: no punctuation for too long — hard-cut at the last space.
        hard_limit = 320 if emitted == 0 else 520
        if len(pending) >= hard_limit:
            space_at = pending.rfind(" ")
            if space_at > 40:
                segment = pending[:space_at].strip()
                pending = pending[space_at:].strip()
            else:
                segment = pending.strip()
                pending = ""
            if segment:
                yield segment
                emitted += 1

    tail = pending.strip()
    if tail:
        yield tail
421
 
422
  # ------------------------------------------------------------
423
  # AUDIO CONVERSION FIX
424
  # Fixes: "'Tensor' object has no attribute 'astype'"
425
  # ------------------------------------------------------------
426
  def audio_to_int16_np(audio):
 
427
  if isinstance(audio, torch.Tensor):
428
  audio = audio.detach().cpu()
429
  audio = torch.clamp(audio, -1.0, 1.0)
 
438
  return audio_to_int16_np(audio).tobytes()
439
 
440
  # ------------------------------------------------------------
441
+ # OFFICIAL GENERATION (exact pattern from your docs)
442
  # generator = pipeline(text, voice='af_heart', speed=1, split_pattern=r'\n+')
443
  # ------------------------------------------------------------
444
def kokoro_generate(chunk: str, voice_code: str, speed: float):
    """Stream raw audio segments for one text chunk via the official pipeline.

    Pipeline selection follows the voice prefix (British vs American).
    Call shape matches the docs exactly:
    generator = pipeline(text, voice='af_heart', speed=1, split_pattern=r'\\n+')
    """
    pipe = PIPELINES[voice_to_lang_code(voice_code)]
    stream = pipe(
        chunk,
        voice=voice_code,
        speed=float(speed),
        split_pattern=r"\n+",
    )
    # The pipeline yields (graphemes, phonemes, audio); only audio is needed.
    for _gs, _ps, audio in stream:
        yield audio
456
 
457
+ # ------------------------------------------------------------
458
+ # WARMUP
459
+ # Moves the first-call latency to startup instead of first user request.
460
+ # ------------------------------------------------------------
461
def warmup():
    """Run one tiny synthesis so first-call latency is paid at startup.

    Best-effort: any failure is logged and swallowed so boot continues.
    """
    try:
        # Pulling a single item is enough to load the model and voice.
        next(iter(kokoro_generate("Hello.", "af_bella", 1.0)), None)
        print("✅ WARMUP DONE")
    except Exception as e:
        print(f"⚠️ WARMUP FAILED: {e}")
468
+
469
  # ------------------------------------------------------------
470
  # GRADIO STREAM
471
  # ------------------------------------------------------------
 
484
 
485
  # ------------------------------------------------------------
486
  # FASTAPI + WEBSOCKET QUEUE
 
487
  # ------------------------------------------------------------
488
  api = FastAPI()
489
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
 
495
 
496
  while True:
497
  ws, voice_code, speed, chunk = await INFERENCE_QUEUE.get()
 
498
  try:
499
  if ws.client_state.value > 1:
500
  continue
 
518
 
519
  @api.on_event("startup")
520
  async def startup():
521
+ # Warmup in executor so startup does not block event loop
522
+ loop = asyncio.get_running_loop()
523
+ await loop.run_in_executor(INFERENCE_EXECUTOR, warmup)
524
  asyncio.create_task(audio_engine_loop())
525
 
526
  @api.websocket("/ws/audio")
 
560
 
561
  if "text" in data:
562
  text = normalize_text(data["text"])
563
+ # Enqueue fast first chunk first
564
  for chunk in tuned_splitter(text):
565
  if chunk.strip():
566
  await INFERENCE_QUEUE.put((ws, voice_code, speed, chunk))
 
577
  # GRADIO UI
578
  # ------------------------------------------------------------
579
  with gr.Blocks(title="Kokoro TTS") as app:
580
+ gr.Markdown("## ⚡ Kokoro-82M (Official Pipeline, Fast First Audio)")
581
  with gr.Row():
582
  with gr.Column():
583
  text_in = gr.Textbox(