auralodyssey committed on
Commit
667ab5c
·
verified ·
1 Parent(s): df0d1ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -259
app.py CHANGED
@@ -298,280 +298,239 @@ import os
298
  import re
299
  import time
300
  import asyncio
301
- import uvloop
 
302
  import numpy as np
303
- import gradio as gr
304
  import torch
305
-
306
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
307
- from concurrent.futures import ThreadPoolExecutor
308
  import uvicorn
309
 
310
- # Official pipeline
311
  from kokoro import KPipeline
312
 
313
- # -------------------------
314
- # CPU + runtime tuning
315
- # -------------------------
316
- # Keep these conservative. HF CPU is usually 2 vCPU.
317
  os.environ.setdefault("OMP_NUM_THREADS", "2")
318
  os.environ.setdefault("MKL_NUM_THREADS", "2")
319
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
320
 
321
- torch.set_num_threads(2)
322
- torch.set_num_interop_threads(1)
 
 
 
323
 
324
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
 
 
 
325
 
326
  SAMPLE_RATE = 24000
327
 
328
- # -------------------------
329
- # Voices (use Kokoro voice ids)
330
- # -------------------------
 
 
 
331
  VOICE_CHOICES = {
332
- "πŸ‡ΊπŸ‡Έ 🚺 Heart": "af_heart",
333
- "πŸ‡ΊπŸ‡Έ 🚺 Bella": "af_bella",
334
- "πŸ‡ΊπŸ‡Έ 🚺 Nicole": "af_nicole",
335
- "πŸ‡ΊπŸ‡Έ 🚺 Aoede": "af_aoede",
336
- "πŸ‡ΊπŸ‡Έ 🚺 Kore": "af_kore",
337
- "πŸ‡ΊπŸ‡Έ 🚺 Sarah": "af_sarah",
338
- "πŸ‡ΊπŸ‡Έ 🚺 Nova": "af_nova",
339
- "πŸ‡ΊπŸ‡Έ 🚺 Sky": "af_sky",
340
- "πŸ‡ΊπŸ‡Έ 🚺 Alloy": "af_alloy",
341
- "πŸ‡ΊπŸ‡Έ 🚺 Jessica": "af_jessica",
342
- "πŸ‡ΊπŸ‡Έ 🚺 River": "af_river",
343
- "πŸ‡ΊπŸ‡Έ 🚹 Michael": "am_michael",
344
- "πŸ‡ΊπŸ‡Έ 🚹 Fenrir": "am_fenrir",
345
- "πŸ‡ΊπŸ‡Έ 🚹 Puck": "am_puck",
346
- "πŸ‡ΊπŸ‡Έ 🚹 Echo": "am_echo",
347
- "πŸ‡ΊπŸ‡Έ 🚹 Eric": "am_eric",
348
- "πŸ‡ΊπŸ‡Έ 🚹 Liam": "am_liam",
349
- "πŸ‡ΊπŸ‡Έ 🚹 Onyx": "am_onyx",
350
- "πŸ‡ΊπŸ‡Έ 🚹 Santa": "am_santa",
351
- "πŸ‡ΊπŸ‡Έ 🚹 Adam": "am_adam",
352
- "πŸ‡¬πŸ‡§ 🚺 Emma": "bf_emma",
353
- "πŸ‡¬πŸ‡§ 🚺 Isabella": "bf_isabella",
354
- "πŸ‡¬πŸ‡§ 🚺 Alice": "bf_alice",
355
- "πŸ‡¬πŸ‡§ 🚺 Lily": "bf_lily",
356
- "πŸ‡¬πŸ‡§ 🚹 George": "bm_george",
357
- "πŸ‡¬πŸ‡§ 🚹 Fable": "bm_fable",
358
- "πŸ‡¬πŸ‡§ 🚹 Lewis": "bm_lewis",
359
  "πŸ‡¬πŸ‡§ 🚹 Daniel": "bm_daniel",
360
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
- DEFAULT_VOICE_UI = "πŸ‡ΊπŸ‡Έ 🚺 Bella"
363
- DEFAULT_VOICE = VOICE_CHOICES[DEFAULT_VOICE_UI]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
- # -------------------------
366
- # Kokoro pipeline (global)
367
- # -------------------------
368
- print("πŸš€ BOOTING KOKORO (OFFICIAL PIPELINE)")
369
- PIPELINE = KPipeline(lang_code="a")
 
 
 
 
 
 
 
 
370
 
371
- # -------------------------
372
- # Helpers
373
- # -------------------------
374
- def _to_numpy_audio(audio):
375
- # Kokoro may return a torch.Tensor or numpy array
376
- if isinstance(audio, torch.Tensor):
377
- return audio.detach().cpu().numpy()
378
- return np.asarray(audio)
379
-
380
- def _float_to_int16(audio_f32):
381
- audio_f32 = np.clip(audio_f32, -1.0, 1.0).astype(np.float32)
382
- return (audio_f32 * 32767.0).astype(np.int16)
383
-
384
- def trim_silence(audio_f32, threshold=0.01, pad=240):
385
- # audio_f32 is float32, shape [N]
386
- if audio_f32.size == 0:
387
- return audio_f32
388
- mask = np.abs(audio_f32) > threshold
389
- if not np.any(mask):
390
- return audio_f32
391
- start = int(np.argmax(mask))
392
- end = int(len(mask) - np.argmax(mask[::-1]))
393
- start = max(0, start - pad)
394
- end = min(len(audio_f32), end + pad)
395
- return audio_f32[start:end]
396
-
397
- def crossfade_concat(a, b, overlap=1200):
398
- # overlap ~ 50ms at 24k
399
- if a is None:
400
- return b
401
- if b is None:
402
- return a
403
- if len(a) < overlap or len(b) < overlap:
404
- return np.concatenate([a, b])
405
-
406
- fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
407
- fade_in = 1.0 - fade_out
408
-
409
- a_tail = a[-overlap:] * fade_out
410
- b_head = b[:overlap] * fade_in
411
-
412
- mixed = a_tail + b_head
413
- return np.concatenate([a[:-overlap], mixed, b[overlap:]])
414
-
415
- def tuned_splitter(text):
416
- # First chunk small for fast first packet, later chunks larger for efficiency
417
- parts = re.split(r"([.,!?;:\n]+)", text)
418
- buf = ""
419
- chunk_idx = 0
420
- for p in parts:
421
- buf += p
422
- if chunk_idx == 0:
423
- threshold = 80
424
- elif chunk_idx == 1:
425
- threshold = 140
426
- elif chunk_idx == 2:
427
- threshold = 220
428
- else:
429
- threshold = 320
430
-
431
- if re.search(r"[.,!?;:\n]$", buf) and len(buf) >= threshold:
432
- s = buf.strip()
433
- if s:
434
- yield s
435
- chunk_idx += 1
436
- buf = ""
437
-
438
- s = buf.strip()
439
- if s:
440
- yield s
441
-
442
- def normalize_names_minimally(text):
443
- # Cheap heuristics to reduce skipped acronyms and CamelCase
444
- # 1) Split ALLCAPS as letters: "AI" -> "A I"
445
- text = re.sub(r"\b([A-Z]{2,})\b", lambda m: " ".join(list(m.group(1))), text)
446
- # 2) Split CamelCase boundaries: "OpenAI" -> "Open AI"
447
- text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
448
- # Keep your Kokoro IPA hint example
449
- text = text.replace("Kokoro", "Kokoro") # keep as-is unless you inject IPA tags in client
450
  return text
451
 
452
- def synthesize_one_chunk(chunk, voice_id, speed):
453
- # Make sure no nested splitting happens inside a chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  with torch.inference_mode():
455
- gen = PIPELINE(
456
- chunk,
457
  voice=voice_id,
458
  speed=float(speed),
459
- split_pattern=r"\n+", # chunk text has no newlines in practice
460
  )
461
- # gen yields (gs, ps, audio)
462
- out_audio = None
463
  for _, _, audio in gen:
464
- audio_np = _to_numpy_audio(audio).astype(np.float32)
465
- audio_np = trim_silence(audio_np)
466
- out_audio = crossfade_concat(out_audio, audio_np, overlap=1200)
467
- return out_audio
468
-
469
- # -------------------------
470
- # Warmup to remove cold start latency
471
- # -------------------------
472
  def warmup():
473
  try:
474
  t0 = time.time()
475
- _ = synthesize_one_chunk("Warmup.", DEFAULT_VOICE, 1.0)
476
- dt = time.time() - t0
477
- print(f"βœ… Warmup done in {dt:.2f}s")
478
  except Exception as e:
479
- print(f"⚠️ Warmup failed: {e}")
480
-
481
- # Run warmup in background thread once
482
- WARMUP_EXECUTOR = ThreadPoolExecutor(max_workers=1)
483
- WARMUP_EXECUTOR.submit(warmup)
484
-
485
- # -------------------------
486
- # Streaming strategy
487
- # -------------------------
488
- def stream_generator(text, voice_ui, speed):
489
- voice_id = VOICE_CHOICES.get(voice_ui, DEFAULT_VOICE)
490
- text = normalize_names_minimally(text)
491
-
492
- print("--- START UI STREAM ---")
493
- first = True
494
-
495
- # Buffer audio after the first packet to reduce gaps from too many tiny yields
496
- buffer_audio = None
497
- buffer_min_seconds = 0.9
498
-
499
- for chunk_idx, chunk in enumerate(tuned_splitter(text)):
500
- t0 = time.time()
501
- audio_f32 = synthesize_one_chunk(chunk, voice_id, speed)
502
- if audio_f32 is None or len(audio_f32) == 0:
503
- continue
504
-
505
- dt = time.time() - t0
506
- print(f"⚑ UI chunk {chunk_idx}: {len(chunk)} chars in {dt:.2f}s")
507
-
508
- if first:
509
- # First packet: yield immediately for low perceived latency
510
- first = False
511
- yield (SAMPLE_RATE, _float_to_int16(audio_f32))
512
- continue
513
-
514
- buffer_audio = crossfade_concat(buffer_audio, audio_f32, overlap=1200)
515
- if buffer_audio is not None:
516
- if len(buffer_audio) >= int(buffer_min_seconds * SAMPLE_RATE):
517
- yield (SAMPLE_RATE, _float_to_int16(buffer_audio))
518
- buffer_audio = None
519
-
520
- if buffer_audio is not None and len(buffer_audio) > 0:
521
- yield (SAMPLE_RATE, _float_to_int16(buffer_audio))
522
 
523
- print("--- END UI STREAM ---")
524
-
525
- # -------------------------
526
- # API (FastAPI + WS)
527
- # -------------------------
528
  api = FastAPI()
529
 
530
- # One inference worker is the right call on 2 vCPU
531
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
532
- INFERENCE_QUEUE = asyncio.Queue()
 
 
 
 
533
 
534
  async def audio_engine_loop():
535
  print("⚑ API AUDIO PIPELINE STARTED")
536
  loop = asyncio.get_running_loop()
537
 
538
  while True:
539
- job = await INFERENCE_QUEUE.get()
540
- text, voice_id, speed, ws = job
541
 
542
- try:
543
- if ws.client_state.value > 1:
544
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
- # Run synthesis in the single worker thread
547
- audio_f32 = await loop.run_in_executor(
548
- INFERENCE_EXECUTOR,
549
- lambda: synthesize_one_chunk(text, voice_id, speed),
550
- )
551
 
552
- if audio_f32 is None or len(audio_f32) == 0:
553
- continue
554
 
555
- pcm = _float_to_int16(audio_f32).tobytes()
556
  try:
557
- await ws.send_bytes(pcm)
558
  except Exception:
559
- pass
560
-
561
- except Exception as e:
562
- print(f"API Engine Error: {e}")
563
 
564
  @api.on_event("startup")
565
  async def startup():
 
 
566
  asyncio.create_task(audio_engine_loop())
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  @api.websocket("/ws/audio")
569
  async def websocket_endpoint(ws: WebSocket):
570
  await ws.accept()
571
 
572
- voice_id = DEFAULT_VOICE
573
- speed = 1.0
574
- loop = asyncio.get_running_loop()
575
 
576
  print(f"βœ… Client connected: {ws.client}")
577
 
@@ -590,54 +549,42 @@ async def websocket_endpoint(ws: WebSocket):
590
  try:
591
  data = await ws.receive_json()
592
  except WebSocketDisconnect:
 
593
  break
594
  except Exception:
595
  break
596
 
597
- if "config" in data:
598
- voice_ui = data.get("voice", DEFAULT_VOICE_UI)
599
- voice_id = VOICE_CHOICES.get(voice_ui, DEFAULT_VOICE)
600
- speed = float(data.get("speed", speed))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
 
602
- if "text" in data:
603
- raw = data["text"]
604
- raw = normalize_names_minimally(raw)
605
-
606
- # First chunk tiny, rest larger, same as UI
607
- for chunk in tuned_splitter(raw):
608
- if not chunk.strip():
609
- continue
610
- await INFERENCE_QUEUE.put((chunk, voice_id, speed, ws))
611
-
612
- if "flush" in data:
613
- pass
614
-
615
- except Exception as e:
616
- print(f"πŸ”₯ Critical WS Error: {e}")
617
  finally:
618
  heartbeat_task.cancel()
619
 
620
- # -------------------------
621
- # Gradio UI
622
- # -------------------------
623
- with gr.Blocks(title="Kokoro TTS") as app:
624
- gr.Markdown("## ⚑ Kokoro-82M (Official Pipeline, Low Latency)")
625
- with gr.Row():
626
- with gr.Column():
627
- text_in = gr.Textbox(
628
- label="Input Text",
629
- lines=4,
630
- value="The system is live. Use the UI or connect to /ws/audio.",
631
- )
632
- voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value=DEFAULT_VOICE_UI, label="Voice")
633
- speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
634
- btn = gr.Button("Generate", variant="primary")
635
- with gr.Column():
636
- audio_out = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")
637
-
638
- btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
639
-
640
- final_app = gr.mount_gradio_app(api, app, path="/")
641
-
642
  if __name__ == "__main__":
643
- uvicorn.run(final_app, host="0.0.0.0", port=7860)
 
298
  import re
299
  import time
300
  import asyncio
301
+ from concurrent.futures import ThreadPoolExecutor
302
+
303
  import numpy as np
 
304
  import torch
 
305
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
306
  import uvicorn
307
 
 
308
  from kokoro import KPipeline
309
 
310
# ----------------------------
# CPU THREAD CAP (HF free tier is typically 2 vCPU)
# ----------------------------
# setdefault so values supplied by the environment (e.g. Space config) win.
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")

try:
    # Cap torch intra-op / inter-op parallelism; overridable via env vars.
    torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "2")))
    torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "1")))
except Exception:
    # NOTE(review): set_num_interop_threads raises if parallel work has already
    # started, and malformed env values raise ValueError — both are non-fatal here.
    pass

# Optional uvloop (safe to skip if not installed)
try:
    import uvloop  # type: ignore

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except Exception:
    pass

# Kokoro emits 24 kHz mono audio; clients must play PCM at this rate.
SAMPLE_RATE = 24000

print("πŸš€ BOOTING KOKORO API ONLY (OFFICIAL PIPELINE)")
333
+
334
# ----------------------------
# VOICES (UI label -> kokoro voice id)
# Client can send either label or id.
# ----------------------------
# NOTE: labels restored from mojibake (UTF-8 emoji mis-decoded in the source);
# resolve_voice() also accepts the raw ids, so clients need not send labels.
VOICE_CHOICES = {
    "🇺🇸 🚺 Heart": "af_heart", "🇺🇸 🚺 Bella": "af_bella", "🇺🇸 🚺 Nicole": "af_nicole",
    "🇺🇸 🚺 Aoede": "af_aoede", "🇺🇸 🚺 Kore": "af_kore", "🇺🇸 🚺 Sarah": "af_sarah",
    "🇺🇸 🚺 Nova": "af_nova", "🇺🇸 🚺 Sky": "af_sky", "🇺🇸 🚺 Alloy": "af_alloy",
    "🇺🇸 🚺 Jessica": "af_jessica", "🇺🇸 🚺 River": "af_river", "🇺🇸 🚹 Michael": "am_michael",
    "🇺🇸 🚹 Fenrir": "am_fenrir", "🇺🇸 🚹 Puck": "am_puck", "🇺🇸 🚹 Echo": "am_echo",
    "🇺🇸 🚹 Eric": "am_eric", "🇺🇸 🚹 Liam": "am_liam", "🇺🇸 🚹 Onyx": "am_onyx",
    "🇺🇸 🚹 Santa": "am_santa", "🇺🇸 🚹 Adam": "am_adam", "🇬🇧 🚺 Emma": "bf_emma",
    "🇬🇧 🚺 Isabella": "bf_isabella", "🇬🇧 🚺 Alice": "bf_alice", "🇬🇧 🚺 Lily": "bf_lily",
    "🇬🇧 🚹 George": "bm_george", "🇬🇧 🚹 Fable": "bm_fable", "🇬🇧 🚹 Lewis": "bm_lewis",
    "🇬🇧 🚹 Daniel": "bm_daniel",
}
# Valid raw voice ids, used to validate client-supplied values.
ALLOWED_VOICE_IDS = set(VOICE_CHOICES.values())

# ✅ DEFAULT VOICE = ONYX
DEFAULT_VOICE_LABEL = "🇺🇸 🚹 Onyx"
DEFAULT_VOICE_ID = VOICE_CHOICES[DEFAULT_VOICE_LABEL]
DEFAULT_SPEED = 1.0
356
+
357
def voice_to_lang_code(voice_id: str) -> str:
    """Return the KPipeline lang_code for *voice_id*.

    British voices are prefixed ``bf_``/``bm_``; everything else (including
    unknown prefixes) falls back to American English.
    """
    # Idiomatic tuple-argument startswith instead of an `or` chain.
    return "b" if voice_id.startswith(("bf_", "bm_")) else "a"
361
+
362
# ----------------------------
# PIPELINES (keep hot in RAM)
# ----------------------------
# One pipeline per language, built once at import so no request pays load cost.
PIPELINES = {
    "a": KPipeline(lang_code="a"),  # American English
    "b": KPipeline(lang_code="b"),  # British English
}
369
 
370
# ----------------------------
# TEXT NORMALIZATION (from your provided docs)
# ----------------------------
_SENT_BOUNDARY = re.compile(r"([.!?;:])\s+")   # sentence-ish punctuation + whitespace
_MULTI_NL = re.compile(r"\n{3,}")              # runs of 3+ newlines
_CAMEL = re.compile(r"([a-z])([A-Z])")         # CamelCase boundary
_ALLCAPS = re.compile(r"\b([A-Z]{2,})\b")      # standalone 2+ letter acronyms

def normalize_text(text: str) -> str:
    """Apply Kokoro-specific markup to *text*; returns "" for falsy input.

    Tags the word "Kokoro" with its inline IPA pronunciation override
    (Kokoro's ``[word](/ipa/)`` markdown-link syntax).
    """
    if not text:
        return ""
    # Fixed: IPA was mojibake ("kˈOkΙ™ΙΉO"); restored to the documented hint.
    return text.replace("Kokoro", "[Kokoro](/kˈOkəɹO/)")

def reduce_name_skips(text: str) -> str:
    """Split acronyms and CamelCase so the TTS does not skip them.

    "AI" -> "A I"; "OpenAI" -> "Open AI". Returns "" for falsy input.
    """
    if not text:
        return ""
    text = _ALLCAPS.sub(lambda m: " ".join(m.group(1)), text)
    text = _CAMEL.sub(r"\1 \2", text)
    return text
389
 
390
def inject_newlines_for_fast_stream(text: str) -> str:
    """Insert newline boundaries so the pipeline can stream small early segments.

    Sentence punctuation becomes a split point; long single-segment inputs are
    force-split near character 70 so the first audio packet arrives quickly.
    """
    prepared = normalize_text(text).strip()
    if not prepared:
        return ""

    prepared = _SENT_BOUNDARY.sub(r"\1\n", prepared)
    prepared = _MULTI_NL.sub("\n\n", prepared)

    # No natural boundary in a long input: carve off a small first segment.
    if "\n" not in prepared and len(prepared) > 90:
        cut = prepared.rfind(" ", 0, 70)
        if cut < 35:
            cut = 70  # no usable space — split mid-word at a fixed offset
        head = prepared[:cut].strip()
        tail = prepared[cut:].strip()
        prepared = head + "\n" + tail

    return prepared
405
 
406
# ----------------------------
# AUDIO CONVERSION
# ----------------------------
def audio_to_int16_np(audio):
    """Convert a float waveform (torch tensor or array-like in [-1, 1]) to int16 numpy."""
    if isinstance(audio, torch.Tensor):
        clamped = torch.clamp(audio.detach().cpu(), -1.0, 1.0)
        return (clamped * 32767.0).to(torch.int16).numpy()

    clipped = np.clip(np.asarray(audio), -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)

def audio_to_pcm_bytes(audio) -> bytes:
    """Serialize a waveform as raw little-endian 16-bit PCM bytes."""
    return audio_to_int16_np(audio).tobytes()
421
+
422
# ----------------------------
# OFFICIAL GENERATION PATH (single pipeline call per request)
# ----------------------------
def kokoro_audio_iter(text: str, voice_id: str, speed: float):
    """Yield audio segments from the official Kokoro pipeline for *text*.

    Picks the language pipeline from the voice prefix, pre-splits the text on
    injected newlines, and yields one audio chunk per generated segment.
    Yields nothing when the prepared text is empty.
    """
    pipeline = PIPELINES[voice_to_lang_code(voice_id)]
    prepared = inject_newlines_for_fast_stream(text)
    if not prepared:
        return

    with torch.inference_mode():
        segments = pipeline(
            prepared,
            voice=voice_id,
            speed=float(speed),
            split_pattern=r"\n+",
        )
        # Pipeline yields (graphemes, phonemes, audio); only audio is needed.
        for _graphemes, _phonemes, audio in segments:
            yield audio
441
+
 
 
 
 
 
 
442
def warmup():
    """Synthesize one tiny utterance so the first real request avoids cold start.

    Best-effort: failures are logged, never raised.
    """
    try:
        started = time.time()
        # Pulling a single segment is enough to load model + default voice.
        next(kokoro_audio_iter("Hello.", DEFAULT_VOICE_ID, 1.0), None)
        elapsed = time.time() - started
        print(f"βœ… WARMUP DONE in {elapsed:.2f}s")
    except Exception as e:
        print(f"⚠️ WARMUP FAILED: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
 
451
# ----------------------------
# FASTAPI APP (API ONLY)
# ----------------------------
api = FastAPI()

# Single inference worker: synthesis is CPU-bound on a ~2 vCPU box.
INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
# Bounded job queue; the websocket handler reports "server_busy" when full.
INFERENCE_QUEUE: asyncio.Queue = asyncio.Queue(maxsize=64)

@api.get("/health")
async def health():
    # Lightweight liveness probe with basic service metadata.
    return {"ok": True, "model": "kokoro", "sample_rate": SAMPLE_RATE, "default_voice": DEFAULT_VOICE_ID}
462
 
463
async def audio_engine_loop():
    """Single consumer of INFERENCE_QUEUE.

    For each (ws, voice_id, speed, text) job, runs synthesis in the one-thread
    INFERENCE_EXECUTOR and streams PCM frames to the job's websocket.

    Fix vs. original: the worker previously pushed frames and the terminating
    ``None`` sentinel with ``frame_q.put_nowait`` via ``call_soon_threadsafe``
    into a queue of maxsize=8 — on QueueFull the frame (or the sentinel) was
    silently dropped, and a lost sentinel left this loop awaiting
    ``frame_q.get()`` forever, wedging the whole engine. The worker now uses a
    blocking thread-safe put (backpressure + guaranteed sentinel), and the
    consumer drains to the sentinel instead of breaking early so the worker
    can never deadlock on a full queue.
    """
    print("⚑ API AUDIO PIPELINE STARTED")
    loop = asyncio.get_running_loop()

    while True:
        ws, voice_id, speed, text = await INFERENCE_QUEUE.get()

        # Skip jobs whose socket already left the CONNECTED state.
        # NOTE(review): relies on WebSocketState ordering (CONNECTED == 1,
        # DISCONNECTED == 2) — confirm against the installed starlette version.
        if ws.client_state.value > 1:
            continue

        # Small bounded buffer between the synthesis thread and this sender.
        frame_q: asyncio.Queue = asyncio.Queue(maxsize=8)

        def _enqueue(item, _q=frame_q):
            # Blocking, thread-safe hand-off into the event loop. Waits while
            # the consumer drains the bounded queue, so nothing is dropped.
            asyncio.run_coroutine_threadsafe(_q.put(item), loop).result()

        def _worker(text=text, voice_id=voice_id, speed=speed, enqueue=_enqueue):
            try:
                first = True
                started = time.time()
                for audio in kokoro_audio_iter(text, voice_id, speed):
                    enqueue(audio_to_pcm_bytes(audio))
                    if first:
                        first = False
                        dt = time.time() - started
                        print(f"⚑ first audio ready in {dt:.2f}s")
                enqueue(None)  # normal end-of-stream sentinel
            except Exception as e:
                print(f"API Worker Error: {e}")
                try:
                    enqueue(None)  # sentinel must always arrive
                except Exception:
                    pass

        INFERENCE_EXECUTOR.submit(_worker)

        # Drain frames until the sentinel; never break early, so the worker's
        # blocking puts always complete and the executor thread is freed.
        client_gone = False
        while True:
            frame = await frame_q.get()
            if frame is None:
                break
            if client_gone or ws.client_state.value > 1:
                client_gone = True
                continue  # client left: discard remaining frames
            try:
                await ws.send_bytes(frame)
            except Exception:
                client_gone = True  # send failed: stop sending, keep draining
 
 
 
508
 
509
@api.on_event("startup")
async def startup():
    # Run warmup in the (single) inference thread and wait for it, so the
    # service only starts accepting work after the model is hot; then launch
    # the long-lived engine consumer task.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(INFERENCE_EXECUTOR, warmup)
    asyncio.create_task(audio_engine_loop())
514
 
515
def resolve_voice(value: str) -> str:
    """Map a UI label or raw voice id to a validated Kokoro voice id.

    Accepts either a VOICE_CHOICES label or a bare id; anything unknown
    (or empty) falls back to DEFAULT_VOICE_ID.
    """
    if not value:
        return DEFAULT_VOICE_ID
    candidate = VOICE_CHOICES.get(value, value.strip())
    return candidate if candidate in ALLOWED_VOICE_IDS else DEFAULT_VOICE_ID
527
+
528
  @api.websocket("/ws/audio")
529
  async def websocket_endpoint(ws: WebSocket):
530
  await ws.accept()
531
 
532
+ voice_id = DEFAULT_VOICE_ID # βœ… default Onyx
533
+ speed = DEFAULT_SPEED
 
534
 
535
  print(f"βœ… Client connected: {ws.client}")
536
 
 
549
  try:
550
  data = await ws.receive_json()
551
  except WebSocketDisconnect:
552
+ print("❌ Client disconnected")
553
  break
554
  except Exception:
555
  break
556
 
557
+ is_config = ("config" in data) or (data.get("type") == "config")
558
+ if is_config:
559
+ voice_id = resolve_voice(str(data.get("voice", voice_id)))
560
+ try:
561
+ speed = float(data.get("speed", speed))
562
+ except Exception:
563
+ speed = DEFAULT_SPEED
564
+
565
+ has_text = ("text" in data) or (data.get("type") == "text")
566
+ if has_text:
567
+ raw = data.get("text", "")
568
+ raw = reduce_name_skips(raw)
569
+ raw = normalize_text(raw)
570
+
571
+ if raw and raw.strip():
572
+ try:
573
+ INFERENCE_QUEUE.put_nowait((ws, voice_id, speed, raw))
574
+ except asyncio.QueueFull:
575
+ try:
576
+ await ws.send_json({"type": "error", "message": "server_busy"})
577
+ except Exception:
578
+ pass
579
+
580
+ if "flush" in data or data.get("type") == "flush":
581
+ try:
582
+ await ws.send_json({"type": "flushed"})
583
+ except Exception:
584
+ pass
585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  finally:
587
  heartbeat_task.cancel()
588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
if __name__ == "__main__":
    # Dev entry point; port 7860 matches the Hugging Face Spaces convention.
    uvicorn.run(api, host="0.0.0.0", port=7860)