auralodyssey committed on
Commit
9bca27a
·
verified ·
1 Parent(s): 667ab5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +343 -158
app.py CHANGED
@@ -295,32 +295,28 @@
295
  # if __name__ == "__main__":
296
  # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
298
- import re
299
  import time
 
300
  import asyncio
 
301
  from concurrent.futures import ThreadPoolExecutor
302
 
303
  import numpy as np
304
- import torch
 
 
 
305
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
306
  import uvicorn
307
 
308
- from kokoro import KPipeline
309
-
310
- # ----------------------------
311
- # CPU THREAD CAP (HF free tier is typically 2 vCPU)
312
- # ----------------------------
313
  os.environ.setdefault("OMP_NUM_THREADS", "2")
314
  os.environ.setdefault("MKL_NUM_THREADS", "2")
315
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
316
 
317
- try:
318
- torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "2")))
319
- torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "1")))
320
- except Exception:
321
- pass
322
-
323
- # Optional uvloop (safe to skip if not installed)
324
  try:
325
  import uvloop # type: ignore
326
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -329,12 +325,13 @@ except Exception:
329
 
330
  SAMPLE_RATE = 24000
331
 
332
- print("๐Ÿš€ BOOTING KOKORO API ONLY (OFFICIAL PIPELINE)")
 
 
 
 
 
333
 
334
- # ----------------------------
335
- # VOICES (UI label -> kokoro voice id)
336
- # Client can send either label or id.
337
- # ----------------------------
338
  VOICE_CHOICES = {
339
  "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Heart": "af_heart", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Bella": "af_bella", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Nicole": "af_nicole",
340
  "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Aoede": "af_aoede", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Kore": "af_kore", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Sarah": "af_sarah",
@@ -350,140 +347,330 @@ VOICE_CHOICES = {
350
  ALLOWED_VOICE_IDS = set(VOICE_CHOICES.values())
351
 
352
  # โœ… DEFAULT VOICE = ONYX
353
- DEFAULT_VOICE_LABEL = "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšน Onyx"
354
- DEFAULT_VOICE_ID = VOICE_CHOICES[DEFAULT_VOICE_LABEL]
355
  DEFAULT_SPEED = 1.0
356
 
357
def voice_to_lang_code(voice_id: str) -> str:
    """Map a kokoro voice id to its pipeline language code.

    Voice ids carry an accent/gender prefix: "bf_"/"bm_" are British
    female/male; everything else is treated as American.

    Returns:
        "b" for British voices, "a" for all others.
    """
    # str.startswith accepts a tuple of prefixes — one call instead of
    # two `or`-chained calls.
    if voice_id.startswith(("bf_", "bm_")):
        return "b"  # British
    return "a"  # American
361
-
362
- # ----------------------------
363
- # PIPELINES (keep hot in RAM)
364
- # ----------------------------
365
- PIPELINES = {
366
- "a": KPipeline(lang_code="a"),
367
- "b": KPipeline(lang_code="b"),
368
- }
369
 
370
- # ----------------------------
371
- # TEXT NORMALIZATION (from your provided docs)
372
- # ----------------------------
373
- _SENT_BOUNDARY = re.compile(r"([.!?;:])\s+")
374
- _MULTI_NL = re.compile(r"\n{3,}")
375
- _CAMEL = re.compile(r"([a-z])([A-Z])")
376
- _ALLCAPS = re.compile(r"\b([A-Z]{2,})\b")
377
 
378
def normalize_text(text: str) -> str:
    """Prepare raw text for synthesis; falsy input yields "".

    The brand name "Kokoro" is rewritten to a markdown-style phoneme
    hint so the engine pronounces it correctly.
    """
    if text:
        return text.replace("Kokoro", "[Kokoro](/kหˆOkษ™ษนO/)")
    return ""
 
 
 
382
 
383
- def reduce_name_skips(text: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  if not text:
385
  return ""
386
- text = _ALLCAPS.sub(lambda m: " ".join(list(m.group(1))), text)
387
- text = _CAMEL.sub(r"\1 \2", text)
 
 
388
  return text
389
 
390
- def inject_newlines_for_fast_stream(text: str) -> str:
391
- text = normalize_text(text).strip()
392
- if not text:
393
- return ""
394
- text = _SENT_BOUNDARY.sub(r"\1\n", text)
395
- text = _MULTI_NL.sub("\n\n", text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
- # Ensure a small first segment for faster first audio
398
- if "\n" not in text and len(text) > 90:
399
- cut = text.rfind(" ", 0, 70)
400
- if cut < 35:
401
- cut = 70
402
- text = text[:cut].strip() + "\n" + text[cut:].strip()
403
 
404
- return text
 
 
 
405
 
406
- # ----------------------------
407
- # AUDIO CONVERSION
408
- # ----------------------------
409
def audio_to_int16_np(audio):
    """Convert float audio (torch tensor or array-like, nominal range
    [-1, 1]) into int16 PCM samples as a numpy array.

    Values outside [-1, 1] are clamped before scaling by 32767.
    """
    if isinstance(audio, torch.Tensor):
        clamped = torch.clamp(audio.detach().cpu(), -1.0, 1.0)
        return (clamped * 32767.0).to(torch.int16).numpy()
    clamped = np.clip(np.asarray(audio), -1.0, 1.0)
    return (clamped * 32767.0).astype(np.int16)
418
-
419
def audio_to_pcm_bytes(audio) -> bytes:
    """Serialize audio to raw int16 PCM bytes (native byte order)."""
    pcm = audio_to_int16_np(audio)
    return pcm.tobytes()
421
-
422
- # ----------------------------
423
- # OFFICIAL GENERATION PATH (single pipeline call per request)
424
- # ----------------------------
425
def kokoro_audio_iter(text: str, voice_id: str, speed: float):
    # Generator yielding one audio segment per newline-split chunk of `text`,
    # using the pre-built pipeline that matches the voice's accent.
    lang_code = voice_to_lang_code(voice_id)
    pipeline = PIPELINES[lang_code]
    # Normalization inserts newlines so the first chunk is short
    # (faster time-to-first-audio); empty result means nothing to speak.
    prepared = inject_newlines_for_fast_stream(text)
    if not prepared:
        return

    # inference_mode: no autograd bookkeeping during CPU inference.
    with torch.inference_mode():
        gen = pipeline(
            prepared,
            voice=voice_id,
            speed=float(speed),
            split_pattern=r"\n+",
        )
        # NOTE(review): 3-tuple layout (graphemes, phonemes, audio) is assumed
        # from the unpacking here — confirm against the kokoro KPipeline docs.
        for _, _, audio in gen:
            yield audio
441
-
442
def warmup():
    """Prime the TTS pipeline once at boot so the first request is fast.

    Generates at most one segment; failures are logged, never raised.
    """
    t0 = time.time()
    try:
        # Pulling a single item is enough to load weights and caches.
        segments = kokoro_audio_iter("Hello.", DEFAULT_VOICE_ID, 1.0)
        next(segments, None)
        print(f"โœ… WARMUP DONE in {time.time() - t0:.2f}s")
    except Exception as e:
        print(f"โš ๏ธ WARMUP FAILED: {e}")
450
 
451
- # ----------------------------
452
- # FASTAPI APP (API ONLY)
453
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  api = FastAPI()
455
 
 
456
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
457
- INFERENCE_QUEUE: asyncio.Queue = asyncio.Queue(maxsize=64)
 
 
 
 
 
 
 
 
 
 
458
 
459
  @api.get("/health")
460
  async def health():
461
- return {"ok": True, "model": "kokoro", "sample_rate": SAMPLE_RATE, "default_voice": DEFAULT_VOICE_ID}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
- async def audio_engine_loop():
 
 
 
 
 
 
464
  print("โšก API AUDIO PIPELINE STARTED")
465
  loop = asyncio.get_running_loop()
466
 
467
  while True:
468
- ws, voice_id, speed, text = await INFERENCE_QUEUE.get()
469
 
470
  if ws.client_state.value > 1:
471
  continue
472
 
 
473
  frame_q: asyncio.Queue = asyncio.Queue(maxsize=8)
 
474
 
475
- def _worker():
476
  try:
 
 
 
 
 
 
477
  first = True
478
- started = time.time()
479
- for audio in kokoro_audio_iter(text, voice_id, speed):
480
- b = audio_to_pcm_bytes(audio)
481
- loop.call_soon_threadsafe(frame_q.put_nowait, b)
 
 
 
 
 
 
 
 
 
 
 
482
  if first:
 
483
  first = False
484
- dt = time.time() - started
485
- print(f"โšก first audio ready in {dt:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  loop.call_soon_threadsafe(frame_q.put_nowait, None)
 
 
 
487
  except Exception as e:
488
  print(f"API Worker Error: {e}")
489
  try:
@@ -491,7 +678,10 @@ async def audio_engine_loop():
491
  except Exception:
492
  pass
493
 
494
- INFERENCE_EXECUTOR.submit(_worker)
 
 
 
495
 
496
  while True:
497
  frame = await frame_q.get()
@@ -499,36 +689,23 @@ async def audio_engine_loop():
499
  break
500
 
501
  if ws.client_state.value > 1:
 
502
  break
503
 
504
  try:
505
  await ws.send_bytes(frame)
 
 
 
506
  except Exception:
 
507
  break
508
 
509
- @api.on_event("startup")
510
- async def startup():
511
- loop = asyncio.get_running_loop()
512
- await loop.run_in_executor(INFERENCE_EXECUTOR, warmup)
513
- asyncio.create_task(audio_engine_loop())
514
-
515
def resolve_voice(value: str) -> str:
    """Map a client-supplied voice (UI label or raw id) to a valid voice id.

    Empty or unknown values fall back to DEFAULT_VOICE_ID.
    """
    if not value:
        return DEFAULT_VOICE_ID

    # UI labels resolve through VOICE_CHOICES; anything else is treated
    # as a raw voice id (whitespace-stripped).
    candidate = VOICE_CHOICES[value] if value in VOICE_CHOICES else value.strip()
    return candidate if candidate in ALLOWED_VOICE_IDS else DEFAULT_VOICE_ID
527
-
528
  @api.websocket("/ws/audio")
529
  async def websocket_endpoint(ws: WebSocket):
530
  await ws.accept()
531
 
 
532
  voice_id = DEFAULT_VOICE_ID # โœ… default Onyx
533
  speed = DEFAULT_SPEED
534
 
@@ -554,34 +731,42 @@ async def websocket_endpoint(ws: WebSocket):
554
  except Exception:
555
  break
556
 
557
- is_config = ("config" in data) or (data.get("type") == "config")
558
- if is_config:
559
  voice_id = resolve_voice(str(data.get("voice", voice_id)))
560
  try:
561
  speed = float(data.get("speed", speed))
562
  except Exception:
563
  speed = DEFAULT_SPEED
 
 
 
 
 
 
564
 
565
- has_text = ("text" in data) or (data.get("type") == "text")
566
- if has_text:
567
- raw = data.get("text", "")
568
- raw = reduce_name_skips(raw)
569
- raw = normalize_text(raw)
570
-
571
- if raw and raw.strip():
572
- try:
573
- INFERENCE_QUEUE.put_nowait((ws, voice_id, speed, raw))
574
- except asyncio.QueueFull:
575
- try:
576
- await ws.send_json({"type": "error", "message": "server_busy"})
577
- except Exception:
578
- pass
579
 
580
- if "flush" in data or data.get("type") == "flush":
581
  try:
582
- await ws.send_json({"type": "flushed"})
583
- except Exception:
584
- pass
 
 
 
585
 
586
  finally:
587
  heartbeat_task.cancel()
 
295
  # if __name__ == "__main__":
296
  # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
298
+ import json
299
  import time
300
+ import re
301
  import asyncio
302
+ import threading
303
  from concurrent.futures import ThreadPoolExecutor
304
 
305
  import numpy as np
306
+ import onnxruntime as ort
307
+ from huggingface_hub import hf_hub_download
308
+ from misaki import en
309
+ from functools import lru_cache
310
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
311
  import uvicorn
312
 
313
+ # =========================================================
314
+ # HF CPU BOX TUNING (2 vCPU)
315
+ # =========================================================
 
 
316
  os.environ.setdefault("OMP_NUM_THREADS", "2")
317
  os.environ.setdefault("MKL_NUM_THREADS", "2")
318
  os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
319
 
 
 
 
 
 
 
 
320
  try:
321
  import uvloop # type: ignore
322
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
325
 
326
  SAMPLE_RATE = 24000
327
 
328
+ # =========================================================
329
+ # ONNX KOKORO CONFIG (YOUR ONNX STYLE)
330
+ # =========================================================
331
+ MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
332
+ MODEL_FILE = "onnx/model.onnx"
333
+ TOKENIZER_FILE = "tokenizer.json"
334
 
 
 
 
 
335
  VOICE_CHOICES = {
336
  "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Heart": "af_heart", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Bella": "af_bella", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Nicole": "af_nicole",
337
  "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Aoede": "af_aoede", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Kore": "af_kore", "๐Ÿ‡บ๐Ÿ‡ธ ๐Ÿšบ Sarah": "af_sarah",
 
347
  ALLOWED_VOICE_IDS = set(VOICE_CHOICES.values())
348
 
349
  # โœ… DEFAULT VOICE = ONYX
350
+ DEFAULT_VOICE_ID = "am_onyx"
 
351
  DEFAULT_SPEED = 1.0
352
 
353
+ print("๐Ÿš€ BOOTING ONNX KOKORO API (LOW LATENCY, API ONLY)")
 
 
 
 
 
 
 
 
 
 
 
354
 
355
+ # =========================================================
356
+ # 1) G2P
357
+ # =========================================================
358
+ G2P = en.G2P(trf=False, british=False, fallback=None)
 
 
 
359
 
360
+ # =========================================================
361
+ # 2) TOKENIZER
362
+ # =========================================================
363
+ vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
364
+ with open(vocab_path, "r", encoding="utf-8") as f:
365
+ data = json.load(f)
366
+ TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
367
 
368
+ # =========================================================
369
+ # 3) VOICES (LAZY LOAD, CACHE)
370
+ # =========================================================
371
+ VOICE_CACHE = {} # voice_id -> np.ndarray (T,1,256)
372
+
373
def _load_voice_bin(voice_id: str) -> np.ndarray:
    """Fetch a voice style file from the hub and decode it.

    The .bin holds raw float32 style rows; the result is reshaped to
    (T, 1, 256).
    """
    bin_path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{voice_id}.bin")
    raw = np.fromfile(bin_path, dtype=np.float32)
    return raw.reshape(-1, 1, 256)
376
+
377
def get_voice(voice_id_or_label: str) -> np.ndarray:
    """Resolve a UI label or raw id to a cached style array of shape (T, 1, 256).

    Invalid names resolve to the default voice; a failed download falls
    back to the "af_bella" style (loaded at most once).
    """
    vid = VOICE_CHOICES.get(voice_id_or_label, voice_id_or_label).strip()
    if vid not in ALLOWED_VOICE_IDS:
        vid = DEFAULT_VOICE_ID

    cached = VOICE_CACHE.get(vid)
    if cached is not None:
        return cached

    try:
        print(f"โฌ‡๏ธ Loading Voice: {vid}")
        VOICE_CACHE[vid] = _load_voice_bin(vid)
        return VOICE_CACHE[vid]
    except Exception:
        # Last-resort fallback voice; warn only on the first failure.
        if "af_bella" not in VOICE_CACHE:
            print("โš ๏ธ Voice load failed, falling back to af_bella")
            VOICE_CACHE["af_bella"] = _load_voice_bin("af_bella")
        return VOICE_CACHE["af_bella"]
393
+
394
+ # =========================================================
395
+ # 4) ONNX SESSION (TUNED FOR 2 vCPU)
396
+ # =========================================================
397
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
398
+
399
+ sess_options = ort.SessionOptions()
400
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
401
+ sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
402
+
403
+ # On 2 vCPU, keep it tight
404
+ sess_options.intra_op_num_threads = int(os.environ.get("ORT_INTRA_OP_THREADS", "2"))
405
+ sess_options.inter_op_num_threads = int(os.environ.get("ORT_INTER_OP_THREADS", "1"))
406
+
407
+ SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
408
+ print("โœ… ONNX SESSION READY")
409
+
410
+ # =========================================================
411
+ # TEXT QUALITY FIXES (NAMES, ACRONYMS, CAMELCASE)
412
+ # =========================================================
413
# Runs of 2+ capitals ("AI", "NASA") — spoken letter by letter.
RE_ALLCAPS = re.compile(r"\b([A-Z]{2,})\b")
# lowercase->uppercase boundary inside CamelCase words ("OpenAI").
RE_CAMEL = re.compile(r"([a-z])([A-Z])")
# Sentence/clause punctuation boundaries used by the chunk splitter.
RE_SENT_SPLIT = re.compile(r'([.,!?;:\n]+)')

def normalize_names(text: str) -> str:
    """Rewrite acronyms and CamelCase so the voice spells them out.

    "AI" becomes "A I"; "OpenAI" becomes "Open AI". Falsy input yields "".
    """
    if not text:
        return ""
    spelled = RE_ALLCAPS.sub(lambda m: " ".join(m.group(1)), text)
    return RE_CAMEL.sub(r"\1 \2", spelled)
425
 
426
@lru_cache(maxsize=10000)
def get_tokens_cached(text: str):
    """Phonemize `text` and map each phoneme to its vocab id (0 if unknown).

    Memoized per exact input string, so repeated chunks skip G2P entirely.
    """
    # Pronunciation hint carried over from v1: spell the brand name
    # phonetically before G2P.
    spoken = text.replace("Kokoro", "kหˆOkษ™ษนO") if "Kokoro" in text else text
    phonemes, _ = G2P(spoken)
    ids = [TOKENIZER.get(ph, 0) for ph in phonemes]
    return tuple(ids)
433
+
434
def tuned_splitter(text: str):
    """Split text into speakable chunks at punctuation boundaries.

    Early chunks use small size thresholds (fast first audio); later
    chunks grow so inference runs fewer, larger batches. Yields
    stripped, non-empty chunks in order; any leftover text is yielded
    at the end even without a trailing punctuation mark.
    """
    # Minimum chunk sizes for chunks 0, 1, 2; 280 for everything after.
    ramp = (60, 120, 180)
    pending = ""
    emitted = 0

    for piece in RE_SENT_SPLIT.split(text):
        if piece is None:
            continue
        pending += piece

        limit = ramp[emitted] if emitted < len(ramp) else 280

        # Flush only on a punctuation boundary once the chunk is big enough.
        if pending and len(pending) >= limit and re.search(r"[.,!?;:\n]$", pending):
            chunk = pending.strip()
            if chunk:
                yield chunk
            emitted += 1
            pending = ""

    tail = pending.strip()
    if tail:
        yield tail
464
+
465
+ # =========================================================
466
+ # AUDIO POST (LESS AGGRESSIVE TRIM + CROSSFADE TO REMOVE "DROPS")
467
+ # =========================================================
468
def trim_leading(audio_f32: np.ndarray, threshold=0.01, pad=80) -> np.ndarray:
    """Drop leading near-silence.

    Keeps `pad` samples before the first sample whose magnitude exceeds
    `threshold`. Empty or all-silent input is returned unchanged.
    """
    if audio_f32.size == 0:
        return audio_f32
    loud = np.abs(audio_f32) > threshold
    if not loud.any():
        return audio_f32
    first_loud = int(loud.argmax())
    cut = max(0, first_loud - pad)
    return audio_f32[cut:]
477
+
478
def trim_trailing(audio_f32: np.ndarray, threshold=0.01, pad=120) -> np.ndarray:
    """Drop trailing near-silence.

    Keeps `pad` samples after the last sample whose magnitude exceeds
    `threshold`. Empty or all-silent input is returned unchanged.
    """
    if audio_f32.size == 0:
        return audio_f32
    loud = np.abs(audio_f32) > threshold
    if not loud.any():
        return audio_f32
    # One past the last loud sample, found by scanning from the back.
    last_end = len(loud) - int(loud[::-1].argmax())
    keep = min(len(audio_f32), last_end + pad)
    return audio_f32[:keep]
487
+
488
def float_to_pcm_bytes(audio_f32: np.ndarray) -> bytes:
    """Convert float audio in [-1, 1] to raw int16 PCM bytes.

    Out-of-range values are clipped before scaling by 32767.
    """
    clipped = np.clip(audio_f32, -1.0, 1.0).astype(np.float32)
    return (clipped * 32767.0).astype(np.int16).tobytes()
492
+
493
def crossfade_bytes_stream(chunks_f32, overlap=1200):
    """
    overlap=1200 samples ~= 50ms at 24kHz
    We hold the last overlap of each chunk, blend into next chunk head,
    then stream without clicks or "drops".

    Yields int16 PCM byte frames; any tail still held when the input is
    exhausted is flushed as a final frame.
    """
    # Last `overlap` float samples of the previous chunk, not yet sent;
    # None means nothing is currently held.
    prev_tail = None

    for i, a in enumerate(chunks_f32):
        # Skip missing/empty chunks entirely.
        if a is None or a.size == 0:
            continue

        if prev_tail is None:
            # First chunk (or the previous one was flushed whole).
            if a.size <= overlap * 2:
                # Too short to split into body + tail: emit as-is.
                yield float_to_pcm_bytes(a)
                prev_tail = None
                continue

            # Emit everything but the tail; hold the tail for blending.
            body = a[:-overlap]
            prev_tail = a[-overlap:]
            yield float_to_pcm_bytes(body)
            continue

        if a.size < overlap:
            # too small, just append
            blended = np.concatenate([prev_tail, a])
            prev_tail = None
            yield float_to_pcm_bytes(blended)
            continue

        # Linear crossfade: held tail fades out while the new head fades in.
        fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
        fade_in = 1.0 - fade_out
        head = a[:overlap]
        blended = (prev_tail * fade_out) + (head * fade_in)

        if a.size <= overlap * 2:
            # nothing meaningful to hold
            out = np.concatenate([blended, a[overlap:]])
            prev_tail = None
            yield float_to_pcm_bytes(out)
            continue

        # Emit blended head + middle; hold the new tail for the next chunk.
        mid = a[overlap:-overlap]
        prev_tail = a[-overlap:]
        out = np.concatenate([blended, mid])
        yield float_to_pcm_bytes(out)

    # Flush whatever tail is still held once the stream ends.
    if prev_tail is not None and prev_tail.size > 0:
        yield float_to_pcm_bytes(prev_tail)
542
+
543
+ # =========================================================
544
+ # ONNX INFER (FAST)
545
+ # =========================================================
546
def infer_tokens(tokens, voice_vec, speed: float):
    """Run one ONNX inference for a token sequence.

    Args:
        tokens: sequence of vocab ids for one text chunk.
        voice_vec: per-voice style array, shape (T, 1, 256).
        speed: playback speed multiplier fed to the model.

    Returns:
        1-D float32 waveform, or None when `tokens` is empty.
    """
    # Cap at 510 ids; with the two sentinel 0s below that makes 512 —
    # presumably the model's max sequence length (TODO confirm).
    ids = tokens[:510]
    if not ids:
        return None

    # voice_vec shape: (T,1,256)
    # Style row indexed by token count, clamped to the table's last row.
    style = voice_vec[min(len(ids), voice_vec.shape[0] - 1)]  # -> (1,256)

    audio = SESSION.run(
        None,
        {
            # 0 acts as both BOS and EOS sentinel around the ids.
            "input_ids": np.array([[0, *ids, 0]], dtype=np.int64),
            "style": style,
            "speed": np.array([float(speed)], dtype=np.float32),
        },
    )[0]  # expected shape: (1, N)

    # Drop the batch dim; copy=False avoids a copy when already float32.
    out = audio[0].astype(np.float32, copy=False)
    return out
565
+
566
+ # =========================================================
567
+ # API ONLY (FASTAPI + WS)
568
+ # =========================================================
569
  api = FastAPI()
570
 
571
+ # Single worker thread for full job generation (tokens + onnx + crossfade)
572
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
573
+
574
+ # Queue of jobs: each job is 1 full text for 1 websocket
575
+ JOB_QUEUE: asyncio.Queue = asyncio.Queue(maxsize=64)
576
+
577
def resolve_voice(value: str) -> str:
    """Resolve a client-supplied voice (UI label or raw id) to a valid id.

    Empty or unknown values fall back to DEFAULT_VOICE_ID.
    """
    if not value:
        return DEFAULT_VOICE_ID
    candidate = VOICE_CHOICES.get(value, value).strip()
    return candidate if candidate in ALLOWED_VOICE_IDS else DEFAULT_VOICE_ID
584
 
585
@api.get("/health")
async def health():
    """Liveness probe reporting the engine type and audio configuration."""
    info = dict(
        ok=True,
        engine="onnxruntime",
        sample_rate=SAMPLE_RATE,
        default_voice=DEFAULT_VOICE_ID,
    )
    return info
593
+
594
def warmup_once():
    """Run one tiny end-to-end synthesis at boot.

    Loads the default voice, primes the G2P/token cache, and forces the
    ONNX session to allocate, so the first client request is fast.
    Failures are logged, never raised.
    """
    try:
        get_voice(DEFAULT_VOICE_ID)
        hello_tokens = get_tokens_cached("Hello.")
        infer_tokens(hello_tokens, VOICE_CACHE[DEFAULT_VOICE_ID], 1.0)
        print("โœ… WARMUP OK")
    except Exception as e:
        print(f"โš ๏ธ WARMUP FAILED: {e}")
602
 
603
@api.on_event("startup")
async def startup():
    # Run the blocking warmup on the single inference thread so the event
    # loop stays responsive during model load, then start the background
    # consumer that services queued synthesis jobs.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(INFERENCE_EXECUTOR, warmup_once)
    asyncio.create_task(engine_loop())
608
+
609
+ async def engine_loop():
610
  print("โšก API AUDIO PIPELINE STARTED")
611
  loop = asyncio.get_running_loop()
612
 
613
  while True:
614
+ ws, voice_id, speed, text = await JOB_QUEUE.get()
615
 
616
  if ws.client_state.value > 1:
617
  continue
618
 
619
+ # This queue carries PCM frames from the worker thread back to asyncio
620
  frame_q: asyncio.Queue = asyncio.Queue(maxsize=8)
621
+ stop_flag = threading.Event()
622
 
623
+ def _worker_full_job():
624
  try:
625
+ t0 = time.time()
626
+
627
+ voice_vec = get_voice(voice_id)
628
+
629
+ # Build per-chunk float32 audio list, with light leading trim
630
+ audio_chunks = []
631
  first = True
632
+
633
+ for chunk in tuned_splitter(text):
634
+ if stop_flag.is_set():
635
+ break
636
+
637
+ # tokenize (cached)
638
+ tokens = get_tokens_cached(chunk)
639
+ if not tokens:
640
+ continue
641
+
642
+ a = infer_tokens(tokens, voice_vec, speed)
643
+ if a is None or a.size == 0:
644
+ continue
645
+
646
+ # do NOT aggressively trim every chunk, only leading a bit
647
  if first:
648
+ a = trim_leading(a, threshold=0.01, pad=120)
649
  first = False
650
+ else:
651
+ a = trim_leading(a, threshold=0.01, pad=60)
652
+
653
+ audio_chunks.append(a)
654
+
655
+ # Push first audio as soon as we have it, no waiting for the full list
656
+ if len(audio_chunks) == 1:
657
+ for frame in crossfade_bytes_stream(audio_chunks, overlap=1200):
658
+ loop.call_soon_threadsafe(frame_q.put_nowait, frame)
659
+ audio_chunks.clear()
660
+
661
+ # Flush remaining with crossfade
662
+ if not stop_flag.is_set():
663
+ if audio_chunks:
664
+ # trim trailing only at the very end to avoid cutting words mid stream
665
+ audio_chunks[-1] = trim_trailing(audio_chunks[-1], threshold=0.01, pad=160)
666
+
667
+ for frame in crossfade_bytes_stream(audio_chunks, overlap=1200):
668
+ loop.call_soon_threadsafe(frame_q.put_nowait, frame)
669
+
670
  loop.call_soon_threadsafe(frame_q.put_nowait, None)
671
+ dt = time.time() - t0
672
+ print(f"โœ… job done in {dt:.2f}s")
673
+
674
  except Exception as e:
675
  print(f"API Worker Error: {e}")
676
  try:
 
678
  except Exception:
679
  pass
680
 
681
+ INFERENCE_EXECUTOR.submit(_worker_full_job)
682
+
683
+ first_sent = False
684
+ started = time.time()
685
 
686
  while True:
687
  frame = await frame_q.get()
 
689
  break
690
 
691
  if ws.client_state.value > 1:
692
+ stop_flag.set()
693
  break
694
 
695
  try:
696
  await ws.send_bytes(frame)
697
+ if not first_sent:
698
+ first_sent = True
699
+ print(f"โšก first audio sent in {time.time() - started:.2f}s")
700
  except Exception:
701
+ stop_flag.set()
702
  break
703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
  @api.websocket("/ws/audio")
705
  async def websocket_endpoint(ws: WebSocket):
706
  await ws.accept()
707
 
708
+ # per-connection state
709
  voice_id = DEFAULT_VOICE_ID # โœ… default Onyx
710
  speed = DEFAULT_SPEED
711
 
 
731
  except Exception:
732
  break
733
 
734
+ # client config
735
+ if "config" in data or data.get("type") == "config":
736
  voice_id = resolve_voice(str(data.get("voice", voice_id)))
737
  try:
738
  speed = float(data.get("speed", speed))
739
  except Exception:
740
  speed = DEFAULT_SPEED
741
+ # preload voice immediately so the next text has no voice load delay
742
+ try:
743
+ get_voice(voice_id)
744
+ except Exception:
745
+ voice_id = DEFAULT_VOICE_ID
746
+ get_voice(voice_id)
747
 
748
+ # client text
749
+ if "text" in data or data.get("type") == "text":
750
+ raw = str(data.get("text", ""))
751
+ raw = raw.strip()
752
+ if not raw:
753
+ continue
754
+
755
+ # name + acronym fix so it stops skipping brands and people names
756
+ raw = normalize_names(raw)
757
+
758
+ # hard cap to prevent one user blocking the box forever
759
+ if len(raw) > 6000:
760
+ await ws.send_json({"type": "error", "message": "text_too_long", "max_chars": 6000})
761
+ continue
762
 
 
763
  try:
764
+ JOB_QUEUE.put_nowait((ws, voice_id, speed, raw))
765
+ except asyncio.QueueFull:
766
+ await ws.send_json({"type": "error", "message": "server_busy"})
767
+
768
+ if "flush" in data or data.get("type") == "flush":
769
+ await ws.send_json({"type": "flushed"})
770
 
771
  finally:
772
  heartbeat_task.cancel()