auralodyssey committed on
Commit
f78ae4b
Β·
verified Β·
1 Parent(s): 03d1b02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -366
app.py CHANGED
@@ -295,43 +295,44 @@
295
  # if __name__ == "__main__":
296
  # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
298
- import json
299
- import time
300
  import re
 
301
  import asyncio
302
- import threading
303
  from concurrent.futures import ThreadPoolExecutor
304
 
305
  import numpy as np
306
- import onnxruntime as ort
307
- from huggingface_hub import hf_hub_download
308
- from misaki import en
309
- from functools import lru_cache
310
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
311
  import uvicorn
312
 
313
- # =========================================================
314
- # HF CPU BOX TUNING (2 vCPU)
315
- # =========================================================
 
 
 
316
  os.environ.setdefault("OMP_NUM_THREADS", "2")
317
  os.environ.setdefault("MKL_NUM_THREADS", "2")
318
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
319
 
 
 
 
 
 
 
 
320
  try:
321
  import uvloop # type: ignore
322
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
323
  except Exception:
324
  pass
325
 
326
- SAMPLE_RATE = 24000
327
-
328
- # =========================================================
329
- # ONNX KOKORO CONFIG (YOUR ONNX STYLE)
330
- # =========================================================
331
- MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
332
- MODEL_FILE = "onnx/model.onnx"
333
- TOKENIZER_FILE = "tokenizer.json"
334
 
 
 
 
335
  VOICE_CHOICES = {
336
  "πŸ‡ΊπŸ‡Έ 🚺 Heart": "af_heart", "πŸ‡ΊπŸ‡Έ 🚺 Bella": "af_bella", "πŸ‡ΊπŸ‡Έ 🚺 Nicole": "af_nicole",
337
  "πŸ‡ΊπŸ‡Έ 🚺 Aoede": "af_aoede", "πŸ‡ΊπŸ‡Έ 🚺 Kore": "af_kore", "πŸ‡ΊπŸ‡Έ 🚺 Sarah": "af_sarah",
@@ -344,333 +345,157 @@ VOICE_CHOICES = {
344
  "πŸ‡¬πŸ‡§ 🚹 George": "bm_george", "πŸ‡¬πŸ‡§ 🚹 Fable": "bm_fable", "πŸ‡¬πŸ‡§ 🚹 Lewis": "bm_lewis",
345
  "πŸ‡¬πŸ‡§ 🚹 Daniel": "bm_daniel",
346
  }
347
- ALLOWED_VOICE_IDS = set(VOICE_CHOICES.values())
348
-
349
- # βœ… DEFAULT VOICE = ONYX
350
- DEFAULT_VOICE_ID = "am_onyx"
351
- DEFAULT_SPEED = 1.0
352
-
353
- print("πŸš€ BOOTING ONNX KOKORO API (LOW LATENCY, API ONLY)")
354
-
355
- # =========================================================
356
- # 1) G2P
357
- # =========================================================
358
- G2P = en.G2P(trf=False, british=False, fallback=None)
359
-
360
- # =========================================================
361
- # 2) TOKENIZER
362
- # =========================================================
363
- vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
364
- with open(vocab_path, "r", encoding="utf-8") as f:
365
- data = json.load(f)
366
- TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
367
-
368
- # =========================================================
369
- # 3) VOICES (LAZY LOAD, CACHE)
370
- # =========================================================
371
- VOICE_CACHE = {} # voice_id -> np.ndarray (T,1,256)
372
-
373
- def _load_voice_bin(voice_id: str) -> np.ndarray:
374
- path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{voice_id}.bin")
375
- return np.fromfile(path, dtype=np.float32).reshape(-1, 1, 256)
376
-
377
- def get_voice(voice_id_or_label: str) -> np.ndarray:
378
- vid = VOICE_CHOICES.get(voice_id_or_label, voice_id_or_label).strip()
379
- if vid not in ALLOWED_VOICE_IDS:
380
- vid = DEFAULT_VOICE_ID
381
-
382
- if vid not in VOICE_CACHE:
383
- try:
384
- print(f"⬇️ Loading Voice: {vid}")
385
- VOICE_CACHE[vid] = _load_voice_bin(vid)
386
- except Exception:
387
- if "af_bella" not in VOICE_CACHE:
388
- print("⚠️ Voice load failed, falling back to af_bella")
389
- VOICE_CACHE["af_bella"] = _load_voice_bin("af_bella")
390
- return VOICE_CACHE["af_bella"]
391
-
392
- return VOICE_CACHE[vid]
393
-
394
- # =========================================================
395
- # 4) ONNX SESSION (TUNED FOR 2 vCPU)
396
- # =========================================================
397
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
398
-
399
- sess_options = ort.SessionOptions()
400
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
401
- sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
402
-
403
- # On 2 vCPU, keep it tight
404
- sess_options.intra_op_num_threads = int(os.environ.get("ORT_INTRA_OP_THREADS", "2"))
405
- sess_options.inter_op_num_threads = int(os.environ.get("ORT_INTER_OP_THREADS", "1"))
406
-
407
- SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
408
- print("βœ… ONNX SESSION READY")
409
-
410
- # =========================================================
411
- # TEXT QUALITY FIXES (NAMES, ACRONYMS, CAMELCASE)
412
- # =========================================================
413
- RE_ALLCAPS = re.compile(r"\b([A-Z]{2,})\b")
414
- RE_CAMEL = re.compile(r"([a-z])([A-Z])")
415
- RE_SENT_SPLIT = re.compile(r'([.,!?;:\n]+)')
416
-
417
- def normalize_names(text: str) -> str:
418
- if not text:
419
- return ""
420
- # AI -> A I
421
- text = RE_ALLCAPS.sub(lambda m: " ".join(list(m.group(1))), text)
422
- # OpenAI -> Open AI
423
- text = RE_CAMEL.sub(r"\1 \2", text)
424
- return text
425
 
426
- @lru_cache(maxsize=10000)
427
- def get_tokens_cached(text: str):
428
- # Your IPA hint behavior from v1
429
- if "Kokoro" in text:
430
- text = text.replace("Kokoro", "kˈOkΙ™ΙΉO")
431
- phonemes, _ = G2P(text)
432
- return tuple(TOKENIZER.get(p, 0) for p in phonemes)
433
-
434
- def tuned_splitter(text: str):
435
- # Fast first audio, bigger later chunks
436
- parts = RE_SENT_SPLIT.split(text)
437
- buf = ""
438
- chunk_idx = 0
439
-
440
- for p in parts:
441
- if p is None:
442
- continue
443
- buf += p
444
-
445
- if chunk_idx == 0:
446
- threshold = 60
447
- elif chunk_idx == 1:
448
- threshold = 120
449
- elif chunk_idx == 2:
450
- threshold = 180
451
- else:
452
- threshold = 280
453
-
454
- if buf and re.search(r"[.,!?;:\n]$", buf) and len(buf) >= threshold:
455
- s = buf.strip()
456
- if s:
457
- yield s
458
- chunk_idx += 1
459
- buf = ""
460
-
461
- s = buf.strip()
462
- if s:
463
- yield s
464
-
465
- # =========================================================
466
- # AUDIO POST (LESS AGGRESSIVE TRIM + CROSSFADE TO REMOVE "DROPS")
467
- # =========================================================
468
- def trim_leading(audio_f32: np.ndarray, threshold=0.01, pad=80) -> np.ndarray:
469
- if audio_f32.size == 0:
470
- return audio_f32
471
- mask = np.abs(audio_f32) > threshold
472
- if not np.any(mask):
473
- return audio_f32
474
- start = int(np.argmax(mask))
475
- start = max(0, start - pad)
476
- return audio_f32[start:]
477
-
478
- def trim_trailing(audio_f32: np.ndarray, threshold=0.01, pad=120) -> np.ndarray:
479
- if audio_f32.size == 0:
480
- return audio_f32
481
- mask = np.abs(audio_f32) > threshold
482
- if not np.any(mask):
483
- return audio_f32
484
- end = int(len(mask) - np.argmax(mask[::-1]))
485
- end = min(len(audio_f32), end + pad)
486
- return audio_f32[:end]
487
-
488
- def float_to_pcm_bytes(audio_f32: np.ndarray) -> bytes:
489
- audio_f32 = np.clip(audio_f32, -1.0, 1.0).astype(np.float32)
490
- pcm = (audio_f32 * 32767.0).astype(np.int16)
491
- return pcm.tobytes()
492
-
493
- def crossfade_bytes_stream(chunks_f32, overlap=1200):
494
- """
495
- overlap=1200 samples ~= 50ms at 24kHz
496
- We hold the last overlap of each chunk, blend into next chunk head,
497
- then stream without clicks or "drops".
498
- """
499
- prev_tail = None
500
-
501
- for i, a in enumerate(chunks_f32):
502
- if a is None or a.size == 0:
503
- continue
504
-
505
- if prev_tail is None:
506
- if a.size <= overlap * 2:
507
- yield float_to_pcm_bytes(a)
508
- prev_tail = None
509
- continue
510
 
511
- body = a[:-overlap]
512
- prev_tail = a[-overlap:]
513
- yield float_to_pcm_bytes(body)
514
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
- if a.size < overlap:
517
- # too small, just append
518
- blended = np.concatenate([prev_tail, a])
519
- prev_tail = None
520
- yield float_to_pcm_bytes(blended)
521
- continue
522
 
523
- fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
524
- fade_in = 1.0 - fade_out
525
- head = a[:overlap]
526
- blended = (prev_tail * fade_out) + (head * fade_in)
527
 
528
- if a.size <= overlap * 2:
529
- # nothing meaningful to hold
530
- out = np.concatenate([blended, a[overlap:]])
531
- prev_tail = None
532
- yield float_to_pcm_bytes(out)
533
- continue
534
 
535
- mid = a[overlap:-overlap]
536
- prev_tail = a[-overlap:]
537
- out = np.concatenate([blended, mid])
538
- yield float_to_pcm_bytes(out)
539
-
540
- if prev_tail is not None and prev_tail.size > 0:
541
- yield float_to_pcm_bytes(prev_tail)
542
-
543
- # =========================================================
544
- # ONNX INFER (FAST)
545
- # =========================================================
546
- def infer_tokens(tokens, voice_vec, speed: float):
547
- ids = tokens[:510]
548
- if not ids:
549
- return None
550
-
551
- # voice_vec shape: (T,1,256)
552
- style = voice_vec[min(len(ids), voice_vec.shape[0] - 1)] # -> (1,256)
553
-
554
- audio = SESSION.run(
555
- None,
556
- {
557
- "input_ids": np.array([[0, *ids, 0]], dtype=np.int64),
558
- "style": style,
559
- "speed": np.array([float(speed)], dtype=np.float32),
560
- },
561
- )[0] # expected shape: (1, N)
562
-
563
- out = audio[0].astype(np.float32, copy=False)
564
- return out
565
-
566
- # =========================================================
567
- # API ONLY (FASTAPI + WS)
568
- # =========================================================
569
- api = FastAPI()
570
 
571
- # Single worker thread for full job generation (tokens + onnx + crossfade)
572
- INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
- # Queue of jobs: each job is 1 full text for 1 websocket
575
- JOB_QUEUE: asyncio.Queue = asyncio.Queue(maxsize=64)
576
-
577
- def resolve_voice(value: str) -> str:
578
- if not value:
579
- return DEFAULT_VOICE_ID
580
- v = VOICE_CHOICES.get(value, value).strip()
581
- if v not in ALLOWED_VOICE_IDS:
582
- return DEFAULT_VOICE_ID
583
- return v
584
-
585
- @api.get("/health")
586
- async def health():
587
- return {
588
- "ok": True,
589
- "engine": "onnxruntime",
590
- "sample_rate": SAMPLE_RATE,
591
- "default_voice": DEFAULT_VOICE_ID,
592
- }
593
-
594
- def warmup_once():
595
  try:
596
- get_voice(DEFAULT_VOICE_ID)
597
- tokens = get_tokens_cached("Hello.") # cached tuple
598
- _ = infer_tokens(tokens, VOICE_CACHE[DEFAULT_VOICE_ID], 1.0)
599
- print("βœ… WARMUP OK")
600
  except Exception as e:
601
  print(f"⚠️ WARMUP FAILED: {e}")
602
 
603
- @api.on_event("startup")
604
- async def startup():
605
- loop = asyncio.get_running_loop()
606
- await loop.run_in_executor(INFERENCE_EXECUTOR, warmup_once)
607
- asyncio.create_task(engine_loop())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
 
609
- async def engine_loop():
 
 
 
610
  print("⚑ API AUDIO PIPELINE STARTED")
611
  loop = asyncio.get_running_loop()
612
 
613
  while True:
614
- ws, voice_id, speed, text = await JOB_QUEUE.get()
615
 
 
616
  if ws.client_state.value > 1:
617
  continue
618
 
619
- # This queue carries PCM frames from the worker thread back to asyncio
620
- frame_q: asyncio.Queue = asyncio.Queue(maxsize=8)
621
- stop_flag = threading.Event()
622
 
623
- def _worker_full_job():
624
  try:
625
- t0 = time.time()
626
-
627
- voice_vec = get_voice(voice_id)
628
-
629
- # Build per-chunk float32 audio list, with light leading trim
630
- audio_chunks = []
631
- first = True
632
-
633
- for chunk in tuned_splitter(text):
634
- if stop_flag.is_set():
635
- break
636
-
637
- # tokenize (cached)
638
- tokens = get_tokens_cached(chunk)
639
- if not tokens:
640
- continue
641
-
642
- a = infer_tokens(tokens, voice_vec, speed)
643
- if a is None or a.size == 0:
644
- continue
645
-
646
- # do NOT aggressively trim every chunk, only leading a bit
647
- if first:
648
- a = trim_leading(a, threshold=0.01, pad=120)
649
- first = False
650
- else:
651
- a = trim_leading(a, threshold=0.01, pad=60)
652
-
653
- audio_chunks.append(a)
654
-
655
- # Push first audio as soon as we have it, no waiting for the full list
656
- if len(audio_chunks) == 1:
657
- for frame in crossfade_bytes_stream(audio_chunks, overlap=1200):
658
- loop.call_soon_threadsafe(frame_q.put_nowait, frame)
659
- audio_chunks.clear()
660
-
661
- # Flush remaining with crossfade
662
- if not stop_flag.is_set():
663
- if audio_chunks:
664
- # trim trailing only at the very end to avoid cutting words mid stream
665
- audio_chunks[-1] = trim_trailing(audio_chunks[-1], threshold=0.01, pad=160)
666
-
667
- for frame in crossfade_bytes_stream(audio_chunks, overlap=1200):
668
- loop.call_soon_threadsafe(frame_q.put_nowait, frame)
669
-
670
  loop.call_soon_threadsafe(frame_q.put_nowait, None)
671
- dt = time.time() - t0
672
- print(f"βœ… job done in {dt:.2f}s")
673
-
674
  except Exception as e:
675
  print(f"API Worker Error: {e}")
676
  try:
@@ -678,7 +503,7 @@ async def engine_loop():
678
  except Exception:
679
  pass
680
 
681
- INFERENCE_EXECUTOR.submit(_worker_full_job)
682
 
683
  first_sent = False
684
  started = time.time()
@@ -689,25 +514,28 @@ async def engine_loop():
689
  break
690
 
691
  if ws.client_state.value > 1:
692
- stop_flag.set()
693
  break
694
 
695
  try:
696
  await ws.send_bytes(frame)
697
  if not first_sent:
 
698
  first_sent = True
699
- print(f"⚑ first audio sent in {time.time() - started:.2f}s")
700
  except Exception:
701
- stop_flag.set()
702
  break
703
 
 
 
 
 
 
 
704
  @api.websocket("/ws/audio")
705
  async def websocket_endpoint(ws: WebSocket):
706
  await ws.accept()
707
 
708
- # per-connection state
709
- voice_id = DEFAULT_VOICE_ID # βœ… default Onyx
710
- speed = DEFAULT_SPEED
711
 
712
  print(f"βœ… Client connected: {ws.client}")
713
 
@@ -726,51 +554,53 @@ async def websocket_endpoint(ws: WebSocket):
726
  try:
727
  data = await ws.receive_json()
728
  except WebSocketDisconnect:
729
- print("❌ Client disconnected")
730
  break
731
- except Exception:
 
732
  break
733
 
734
- # client config
735
- if "config" in data or data.get("type") == "config":
736
- voice_id = resolve_voice(str(data.get("voice", voice_id)))
737
- try:
738
- speed = float(data.get("speed", speed))
739
- except Exception:
740
- speed = DEFAULT_SPEED
741
- # preload voice immediately so the next text has no voice load delay
742
- try:
743
- get_voice(voice_id)
744
- except Exception:
745
- voice_id = DEFAULT_VOICE_ID
746
- get_voice(voice_id)
747
-
748
- # client text
749
- if "text" in data or data.get("type") == "text":
750
- raw = str(data.get("text", ""))
751
- raw = raw.strip()
752
- if not raw:
753
- continue
754
 
755
- # name + acronym fix so it stops skipping brands and people names
756
- raw = normalize_names(raw)
 
 
757
 
758
- # hard cap to prevent one user blocking the box forever
759
- if len(raw) > 6000:
760
- await ws.send_json({"type": "error", "message": "text_too_long", "max_chars": 6000})
761
- continue
762
-
763
- try:
764
- JOB_QUEUE.put_nowait((ws, voice_id, speed, raw))
765
- except asyncio.QueueFull:
766
- await ws.send_json({"type": "error", "message": "server_busy"})
767
-
768
- if "flush" in data or data.get("type") == "flush":
769
- await ws.send_json({"type": "flushed"})
770
 
771
  finally:
772
  heartbeat_task.cancel()
773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  if __name__ == "__main__":
775
- uvicorn.run(api, host="0.0.0.0", port=7860)
776
-
 
295
  # if __name__ == "__main__":
296
  # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
 
 
298
  import re
299
+ import time
300
  import asyncio
 
301
  from concurrent.futures import ThreadPoolExecutor
302
 
303
  import numpy as np
304
+ import gradio as gr
 
 
 
305
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
306
  import uvicorn
307
 
308
+ import torch
309
+ from kokoro import KPipeline
310
+
311
+ # ----------------------------
312
+ # HARD LIMIT CPU THREADS (2 vCPU box)
313
+ # ----------------------------
314
  os.environ.setdefault("OMP_NUM_THREADS", "2")
315
  os.environ.setdefault("MKL_NUM_THREADS", "2")
316
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
317
 
318
+ try:
319
+ torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "2")))
320
+ torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "1")))
321
+ except Exception:
322
+ pass
323
+
324
+ # Optional: uvloop for faster event loop on HF Linux
325
  try:
326
  import uvloop # type: ignore
327
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
328
  except Exception:
329
  pass
330
 
331
+ print("πŸš€ BOOTING KOKORO (OFFICIAL PIPELINE, LOW LATENCY)")
 
 
 
 
 
 
 
332
 
333
+ # ----------------------------
334
+ # VOICES
335
+ # ----------------------------
336
  VOICE_CHOICES = {
337
  "πŸ‡ΊπŸ‡Έ 🚺 Heart": "af_heart", "πŸ‡ΊπŸ‡Έ 🚺 Bella": "af_bella", "πŸ‡ΊπŸ‡Έ 🚺 Nicole": "af_nicole",
338
  "πŸ‡ΊπŸ‡Έ 🚺 Aoede": "af_aoede", "πŸ‡ΊπŸ‡Έ 🚺 Kore": "af_kore", "πŸ‡ΊπŸ‡Έ 🚺 Sarah": "af_sarah",
 
345
  "πŸ‡¬πŸ‡§ 🚹 George": "bm_george", "πŸ‡¬πŸ‡§ 🚹 Fable": "bm_fable", "πŸ‡¬πŸ‡§ 🚹 Lewis": "bm_lewis",
346
  "πŸ‡¬πŸ‡§ 🚹 Daniel": "bm_daniel",
347
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
def voice_to_lang_code(voice_code: str) -> str:
    """Map a Kokoro voice id to its KPipeline language code.

    British voices (prefixes ``bf_`` / ``bm_``) use the ``"b"`` pipeline;
    every other voice falls back to the American ``"a"`` pipeline.
    """
    # str.startswith accepts a tuple of prefixes -- one call instead of two.
    if voice_code.startswith(("bf_", "bm_")):
        return "b"  # British English
    return "a"  # American English
353
+
354
# ----------------------------
# PIPELINES (keep hot in RAM)
# ----------------------------
# One KPipeline per language code, built once at import time and reused
# for every request so there is no per-request model construction cost.
PIPELINES = {lang: KPipeline(lang_code=lang) for lang in ("a", "b")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
# ----------------------------
# TEXT NORMALIZATION (matches the official Kokoro docs)
# ----------------------------
def normalize_text(text: str) -> str:
    """Apply the misaki inline-pronunciation hint for the word "Kokoro".

    The ``[word](/ipa/)`` markdown-link form is the official Kokoro/misaki
    custom-pronunciation syntax. NOTE: the previous string contained
    mojibake ("kˈOkΙ™ΙΉO"); restored to the documented IPA "kˈOkəɹO".
    """
    if not text:
        return ""
    return text.replace("Kokoro", "[Kokoro](/kˈOkəɹO/)")

# ----------------------------
# LOW LATENCY SEGMENTATION
# One pipeline call per request.
# We inject newlines to let split_pattern=r"\n+" split inside Kokoro.
# We also force a small first segment for fast first audio.
# ----------------------------
# Sentence-final punctuation followed by whitespace marks a segment boundary.
_SENT_BOUNDARY = re.compile(r"([.!?;:])\s+")

def inject_newlines_for_fast_stream(text: str) -> str:
    """Rewrite *text* so Kokoro's ``split_pattern=r"\\n+"`` yields small,
    early segments (low time-to-first-audio).

    Returns "" for empty/blank input.
    """
    text = normalize_text(text).strip()
    if not text:
        return ""

    # Sentence boundaries -> newline so the official split_pattern can segment.
    text = _SENT_BOUNDARY.sub(r"\1\n", text)

    # Collapse runs of 3+ newlines to exactly two.
    text = re.sub(r"\n{3,}", "\n\n", text)

    # Guarantee a small first segment for low time-to-first-audio:
    # if one long unbroken line, cut near char 70 at a word boundary.
    if "\n" not in text and len(text) > 90:
        cut = text.rfind(" ", 0, 70)
        if cut < 35:
            cut = 70  # no usable space early enough; hard cut mid-text
        text = text[:cut].strip() + "\n" + text[cut:].strip()

    return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
# ----------------------------
# AUDIO CONVERSION (fast, safe)
# ----------------------------
def audio_to_int16_np(audio):
    """Convert float audio in [-1, 1] (torch tensor or array-like) to an
    int16 numpy array, clamping out-of-range samples first."""
    if isinstance(audio, torch.Tensor):
        clamped = torch.clamp(audio.detach().cpu(), -1.0, 1.0)
        return (clamped * 32767.0).to(torch.int16).numpy()

    clipped = np.clip(np.asarray(audio), -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)

def audio_to_pcm_bytes(audio) -> bytes:
    """Raw 16-bit PCM bytes (native endianness) for websocket streaming."""
    return audio_to_int16_np(audio).tobytes()
413
+
414
# ----------------------------
# OFFICIAL GENERATION PATH (single pipeline call)
# generator = pipeline(text, voice='af_heart', speed=1, split_pattern=r'\n+')
# ----------------------------
def kokoro_generator_full(text: str, voice_code: str, speed: float):
    """Yield raw audio chunks for *text*, one per newline-delimited segment.

    Makes exactly one KPipeline call; segmentation happens inside Kokoro
    via split_pattern, using the newlines injected by
    inject_newlines_for_fast_stream for low time-to-first-audio.
    """
    pipeline = PIPELINES[voice_to_lang_code(voice_code)]
    prepared = inject_newlines_for_fast_stream(text)

    if not prepared:
        return

    # inference_mode: no autograd bookkeeping on the CPU forward pass.
    with torch.inference_mode():
        stream = pipeline(
            prepared,
            voice=voice_code,
            speed=float(speed),
            split_pattern=r"\n+",
        )
        for _graphemes, _phonemes, audio in stream:
            yield audio
435
+
436
# ----------------------------
# WARMUP (pay cold-start cost at boot)
# ----------------------------
def warmup():
    """Generate one tiny utterance at boot so the first real request is fast.

    Best-effort: any failure is logged and ignored so boot still completes.
    """
    try:
        started = time.time()
        # Pull a single audio chunk from the generator, then abandon it.
        next(iter(kokoro_generator_full("Hello.", "af_bella", 1.0)), None)
        print(f"βœ… WARMUP DONE in {time.time() - started:.2f}s")
    except Exception as e:
        print(f"⚠️ WARMUP FAILED: {e}")
447
 
448
# ----------------------------
# GRADIO UI STREAM
# ----------------------------
def gradio_stream(text, voice_name, speed):
    """Stream ``(sample_rate, int16 audio)`` tuples to the Gradio Audio widget.

    Accepts either a UI voice label (key of VOICE_CHOICES) or a raw voice id.

    FIX: the raw text is passed straight to kokoro_generator_full -- it
    already applies normalize_text (via inject_newlines_for_fast_stream).
    Pre-normalizing here applied the "Kokoro" replacement twice, corrupting
    the pronunciation hint into a nested markdown link.
    """
    voice_code = VOICE_CHOICES.get(voice_name, voice_name)

    started = time.time()
    for idx, audio in enumerate(kokoro_generator_full(text, voice_code, speed)):
        if idx == 0:
            print(f"⚑ UI first audio in {time.time() - started:.2f}s")
        yield 24000, audio_to_int16_np(audio)
462
+
463
+ # ----------------------------
464
+ # FASTAPI WS ENGINE
465
+ # Single worker thread for actual generation.
466
+ # Stream frames to client as soon as they exist.
467
+ # No buffering a full list before sending.
468
+ # ----------------------------
469
+ api = FastAPI()
470
 
471
+ INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
472
+ INFERENCE_QUEUE: asyncio.Queue = asyncio.Queue()
473
+
474
+ async def audio_engine_loop():
475
  print("⚑ API AUDIO PIPELINE STARTED")
476
  loop = asyncio.get_running_loop()
477
 
478
  while True:
479
+ ws, voice_code, speed, text = await INFERENCE_QUEUE.get()
480
 
481
+ # Skip dead clients early
482
  if ws.client_state.value > 1:
483
  continue
484
 
485
+ frame_q: asyncio.Queue = asyncio.Queue(maxsize=6)
 
 
486
 
487
+ def _worker():
488
  try:
489
+ for audio in kokoro_generator_full(text, voice_code, speed):
490
+ b = audio_to_pcm_bytes(audio)
491
+ # backpressure aware
492
+ while True:
493
+ try:
494
+ loop.call_soon_threadsafe(frame_q.put_nowait, b)
495
+ break
496
+ except Exception:
497
+ time.sleep(0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  loop.call_soon_threadsafe(frame_q.put_nowait, None)
 
 
 
499
  except Exception as e:
500
  print(f"API Worker Error: {e}")
501
  try:
 
503
  except Exception:
504
  pass
505
 
506
+ INFERENCE_EXECUTOR.submit(_worker)
507
 
508
  first_sent = False
509
  started = time.time()
 
514
  break
515
 
516
  if ws.client_state.value > 1:
 
517
  break
518
 
519
  try:
520
  await ws.send_bytes(frame)
521
  if not first_sent:
522
+ print(f"⚑ API first audio in {time.time() - started:.2f}s")
523
  first_sent = True
 
524
  except Exception:
 
525
  break
526
 
527
+ @api.on_event("startup")
528
+ async def startup():
529
+ loop = asyncio.get_running_loop()
530
+ await loop.run_in_executor(INFERENCE_EXECUTOR, warmup)
531
+ asyncio.create_task(audio_engine_loop())
532
+
533
  @api.websocket("/ws/audio")
534
  async def websocket_endpoint(ws: WebSocket):
535
  await ws.accept()
536
 
537
+ voice_code = "af_bella"
538
+ speed = 1.0
 
539
 
540
  print(f"βœ… Client connected: {ws.client}")
541
 
 
554
  try:
555
  data = await ws.receive_json()
556
  except WebSocketDisconnect:
557
+ print("❌ Client disconnected cleanly")
558
  break
559
+ except Exception as e:
560
+ print(f"⚠️ Connection lost: {e}")
561
  break
562
 
563
+ if "config" in data:
564
+ voice_name = data.get("voice", "πŸ‡ΊπŸ‡Έ 🚺 Bella")
565
+ voice_code = VOICE_CHOICES.get(voice_name, voice_name)
566
+ speed = float(data.get("speed", speed))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
 
568
+ if "text" in data:
569
+ text = normalize_text(data.get("text", ""))
570
+ if text.strip():
571
+ await INFERENCE_QUEUE.put((ws, voice_code, speed, text))
572
 
573
+ if "flush" in data:
574
+ pass
 
 
 
 
 
 
 
 
 
 
575
 
576
  finally:
577
  heartbeat_task.cancel()
578
 
579
# ----------------------------
# GRADIO APP
# ----------------------------
# Two-column layout: inputs + Generate button on the left,
# streaming audio player on the right.
with gr.Blocks(title="Kokoro TTS") as app:
    gr.Markdown("## ⚑ Kokoro-82M (Official Pipeline, Low Latency)")

    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                lines=3,
                label="Input Text",
                value="The system is live. Use the Gradio UI, or connect to /ws/audio.",
            )
            voice_in = gr.Dropdown(
                choices=list(VOICE_CHOICES.keys()),
                label="Voice",
                value="πŸ‡ΊπŸ‡Έ 🚺 Bella",
            )
            speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
            btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_out = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")

    btn.click(
        gradio_stream,
        inputs=[text_in, voice_in, speed_in],
        outputs=[audio_out],
    )

# Serve the Gradio UI at "/" on the same FastAPI app that hosts /ws/audio.
final_app = gr.mount_gradio_app(api, app, path="/")

if __name__ == "__main__":
    uvicorn.run(final_app, host="0.0.0.0", port=7860)