auralodyssey committed on
Commit
d87f15e
·
verified ·
1 Parent(s): eff63e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -365
app.py CHANGED
@@ -1,305 +1,12 @@
1
- # import os
2
- # import json
3
- # import time
4
- # import re
5
- # import numpy as np
6
- # import onnxruntime as ort
7
- # import gradio as gr
8
- # from huggingface_hub import hf_hub_download
9
- # from misaki import en
10
- # from functools import lru_cache
11
- # from fastapi import FastAPI, WebSocket, WebSocketDisconnect
12
- # import asyncio
13
- # import uvloop
14
- # import uvicorn
15
- # from concurrent.futures import ThreadPoolExecutor
16
-
17
- # # --- CONFIGURATION ---
18
- # MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
19
- # MODEL_FILE = "onnx/model.onnx"
20
- # TOKENIZER_FILE = "tokenizer.json"
21
-
22
- # # --- VOICE UI ---
23
- # VOICE_CHOICES = {
24
- # 'πŸ‡ΊπŸ‡Έ 🚺 Heart': 'af_heart', 'πŸ‡ΊπŸ‡Έ 🚺 Bella': 'af_bella', 'πŸ‡ΊπŸ‡Έ 🚺 Nicole': 'af_nicole',
25
- # 'πŸ‡ΊπŸ‡Έ 🚺 Aoede': 'af_aoede', 'πŸ‡ΊπŸ‡Έ 🚺 Kore': 'af_kore', 'πŸ‡ΊπŸ‡Έ 🚺 Sarah': 'af_sarah',
26
- # 'πŸ‡ΊπŸ‡Έ 🚺 Nova': 'af_nova', 'πŸ‡ΊπŸ‡Έ 🚺 Sky': 'af_sky', 'πŸ‡ΊπŸ‡Έ 🚺 Alloy': 'af_alloy',
27
- # 'πŸ‡ΊπŸ‡Έ 🚺 Jessica': 'af_jessica', 'πŸ‡ΊπŸ‡Έ 🚺 River': 'af_river', 'πŸ‡ΊπŸ‡Έ 🚹 Michael': 'am_michael',
28
- # 'πŸ‡ΊπŸ‡Έ 🚹 Fenrir': 'am_fenrir', 'πŸ‡ΊπŸ‡Έ 🚹 Puck': 'am_puck', 'πŸ‡ΊπŸ‡Έ 🚹 Echo': 'am_echo',
29
- # 'πŸ‡ΊπŸ‡Έ 🚹 Eric': 'am_eric', 'πŸ‡ΊπŸ‡Έ 🚹 Liam': 'am_liam', 'πŸ‡ΊπŸ‡Έ 🚹 Onyx': 'am_onyx',
30
- # 'πŸ‡ΊπŸ‡Έ 🚹 Santa': 'am_santa', 'πŸ‡ΊπŸ‡Έ 🚹 Adam': 'am_adam', 'πŸ‡¬πŸ‡§ 🚺 Emma': 'bf_emma',
31
- # 'πŸ‡¬πŸ‡§ 🚺 Isabella': 'bf_isabella', 'πŸ‡¬πŸ‡§ 🚺 Alice': 'bf_alice', 'πŸ‡¬πŸ‡§ 🚺 Lily': 'bf_lily',
32
- # 'πŸ‡¬πŸ‡§ 🚹 George': 'bm_george', 'πŸ‡¬πŸ‡§ 🚹 Fable': 'bm_fable', 'πŸ‡¬πŸ‡§ 🚹 Lewis': 'bm_lewis',
33
- # 'πŸ‡¬πŸ‡§ 🚹 Daniel': 'bm_daniel',
34
- # }
35
-
36
- # # --- ENGINE ---
37
- # print("πŸš€ BOOTING HIGH-RAM ENGINE...")
38
- # # Enable fast networking immediately
39
- # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
40
-
41
- # # 1. Phonemizer
42
- # G2P = en.G2P(trf=False, british=False, fallback=None)
43
-
44
- # # 2. Tokenizer
45
- # vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
46
- # with open(vocab_path, "r", encoding="utf-8") as f:
47
- # data = json.load(f)
48
- # TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
49
-
50
- # # 3. Voices (Lazy Load)
51
- # VOICE_CACHE = {}
52
- # def get_voice(name):
53
- # code = VOICE_CHOICES.get(name, name)
54
- # if code not in VOICE_CACHE:
55
- # try:
56
- # print(f"⬇️ Loading Voice: {code}")
57
- # path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{code}.bin")
58
- # VOICE_CACHE[code] = np.fromfile(path, dtype=np.float32).reshape(-1, 1, 256)
59
- # except:
60
- # if 'af_bella' not in VOICE_CACHE:
61
- # p = hf_hub_download(repo_id=MODEL_REPO, filename="voices/af_bella.bin")
62
- # VOICE_CACHE['af_bella'] = np.fromfile(p, dtype=np.float32).reshape(-1, 1, 256)
63
- # return VOICE_CACHE['af_bella']
64
- # return VOICE_CACHE[code]
65
-
66
- # # 4. ONNX Engine
67
- # model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
68
- # sess_options = ort.SessionOptions()
69
- # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
70
- # sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
71
- # sess_options.intra_op_num_threads = 0
72
- # sess_options.inter_op_num_threads = 0
73
- # SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
74
- # print("βœ… ENGINE READY")
75
-
76
- # # --- CORE LOGIC (Shared by UI and API) ---
77
- # @lru_cache(maxsize=5000)
78
- # def get_tokens(text):
79
- # if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkΙ™ΙΉO")
80
- # phonemes, _ = G2P(text)
81
- # return [TOKENIZER.get(p, 0) for p in phonemes]
82
-
83
- # def trim_silence(audio, threshold=0.01):
84
- # if audio.size == 0: return audio
85
- # mask = np.abs(audio) > threshold
86
- # if not np.any(mask): return audio
87
- # start, end = np.argmax(mask), len(mask) - np.argmax(mask[::-1])
88
- # return audio[max(0, start-50) : min(len(audio), end+50)]
89
-
90
- # def infer(text, voice_name, speed):
91
- # if not text.strip(): return None
92
- # ids = get_tokens(text)[:510]
93
- # if not ids: return None
94
- # voice = get_voice(voice_name)
95
- # style = voice[min(len(ids), voice.shape[0]-1)]
96
- # try:
97
- # audio = SESSION.run(None, {
98
- # "input_ids": np.array([[0] + ids + [0]], dtype=np.int64),
99
- # "style": style,
100
- # "speed": np.array([speed], dtype=np.float32)
101
- # })[0]
102
- # return 24000, (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16)
103
- # except: return None
104
-
105
- # def tuned_splitter(text):
106
- # chunks = re.split(r'([.,!?;:\n]+)', text)
107
- # buffer = ""
108
- # chunk_count = 0
109
- # for part in chunks:
110
- # buffer += part
111
- # if chunk_count == 0: threshold = 50
112
- # elif chunk_count == 1: threshold = 100
113
- # elif chunk_count == 2: threshold = 150
114
- # else: threshold = 250
115
- # if re.search(r'[.,!?;:\n]$', buffer) and len(buffer) >= threshold:
116
- # if buffer.strip():
117
- # yield buffer
118
- # chunk_count += 1
119
- # buffer = ""
120
- # if buffer.strip():
121
- # yield buffer.strip()
122
-
123
- # def stream_generator(text, voice_name, speed):
124
- # print("--- START STREAM ---")
125
- # get_voice(voice_name)
126
- # for i, chunk in enumerate(tuned_splitter(text)):
127
- # t0 = time.time()
128
- # audio = infer(chunk, voice_name, speed)
129
- # if audio:
130
- # dur = time.time() - t0
131
- # print(f"⚑ Chunk {i}: {len(chunk)} chars in {dur:.2f}s")
132
- # yield audio
133
- # print("--- END STREAM ---")
134
-
135
- # # --- UI DEFINITION ---
136
- # with gr.Blocks(title="Kokoro TTS") as app:
137
- # gr.Markdown("## ⚑ Kokoro-82M (High-RAM Tuned)")
138
- # with gr.Row():
139
- # with gr.Column():
140
- # text_in = gr.Textbox(label="Input Text", lines=3, value="The system is live. Use the Gradio UI for testing, or connect to /ws/audio for the API.")
141
- # voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value='πŸ‡ΊπŸ‡Έ 🚺 Bella', label="Voice")
142
- # speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
143
- # btn = gr.Button("Generate", variant="primary")
144
- # with gr.Column():
145
- # audio_out = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")
146
- # btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
147
-
148
- # # --- API INTEGRATION ---
149
- # # --- API INTEGRATION ---
150
- # from concurrent.futures import ThreadPoolExecutor
151
-
152
- # # 1. Define FastAPI
153
- # api = FastAPI()
154
-
155
- # # 2. Define Worker Pools
156
- # # We use max_workers=1 because ONNX is already multithreaded internally.
157
- # # Adding more workers on a 2 vCPU machine will actually SLOW it down due to context switching.
158
- # INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
159
- # G2P_EXECUTOR = ThreadPoolExecutor(max_workers=1)
160
- # INFERENCE_QUEUE = asyncio.Queue()
161
-
162
- # # 3. Background Tasks
163
- # def g2p_task(text):
164
- # # Reuses the exact same G2P/Tokenizer logic as the UI
165
- # if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkΙ™ΙΉO")
166
- # phonemes, _ = G2P(text)
167
- # return [TOKENIZER.get(p, 0) for p in phonemes]
168
-
169
- # # This is the "Engine Room". It pulls tickets and cooks them one by one.
170
- # async def audio_engine_loop():
171
- # print("⚑ API AUDIO PIPELINE STARTED")
172
- # loop = asyncio.get_running_loop()
173
-
174
- # while True:
175
- # # Wait for a ticket (text tokens + websocket connection)
176
- # job = await INFERENCE_QUEUE.get()
177
- # tokens, style, speed, ws = job
178
-
179
- # try:
180
- # # Check if client is still connected before doing heavy math
181
- # # (FastAPI WS state: 1 = Connected, 2/3 = Closing/Closed)
182
- # if ws.client_state.value > 1:
183
- # continue
184
-
185
- # # Reuses the exact same SESSION as the UI
186
- # input_ids = np.array([[0, *tokens[:510], 0]], dtype=np.int64)
187
- # style_vec = style[min(len(tokens), style.shape[0]-1)]
188
-
189
- # # --- CRITICAL FIX: Run blocking math in a separate thread ---
190
- # # This allows the main server to keep talking to the other 59 users
191
- # # while this calculation happens in the background.
192
- # audio = await loop.run_in_executor(
193
- # INFERENCE_EXECUTOR,
194
- # lambda: SESSION.run(None, {
195
- # "input_ids": input_ids,
196
- # "style": style_vec,
197
- # "speed": np.array([speed], dtype=np.float32)
198
- # })[0]
199
- # )
200
-
201
- # # Post-Process (Fast enough to run on main thread)
202
- # pcm_bytes = (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16).tobytes()
203
-
204
- # # Send audio back to the specific user who asked for it
205
- # try:
206
- # await ws.send_bytes(pcm_bytes)
207
- # except Exception:
208
- # # If sending fails, just move on. Don't crash the engine.
209
- # pass
210
-
211
- # except Exception as e:
212
- # print(f"API Engine Error: {e}")
213
-
214
- # @api.on_event("startup")
215
- # async def startup():
216
- # asyncio.create_task(audio_engine_loop())
217
-
218
- # # -------------------------------------------------------
219
- # # ROBUST WEBSOCKET ENDPOINT
220
- # # -------------------------------------------------------
221
- # @api.websocket("/ws/audio")
222
- # async def websocket_endpoint(ws: WebSocket):
223
- # await ws.accept()
224
-
225
- # # Defaults
226
- # voice_key = "af_bella"
227
- # speed = 1.0
228
- # loop = asyncio.get_running_loop()
229
-
230
- # print(f"βœ… Client connected: {ws.client}")
231
-
232
- # # --- HEARTBEAT KEEPER ---
233
- # # This prevents HF Nginx from killing the connection during silence.
234
- # async def keep_alive():
235
- # while True:
236
- # try:
237
- # await asyncio.sleep(15) # Send a ping every 15s
238
- # # We send a text frame as a ping. The browser ignores it or handles it.
239
- # await ws.send_json({"type": "ping"})
240
- # except:
241
- # break
242
-
243
- # heartbeat_task = asyncio.create_task(keep_alive())
244
-
245
- # try:
246
- # while True:
247
- # try:
248
- # # Wait for JSON command
249
- # data = await ws.receive_json()
250
- # except WebSocketDisconnect:
251
- # print("❌ Client disconnected cleanly")
252
- # break # BREAK THE LOOP
253
- # except Exception as e:
254
- # print(f"⚠️ Connection lost: {e}")
255
- # break # BREAK THE LOOP
256
-
257
- # # 1. Config Change
258
- # if "config" in data:
259
- # voice_name = data.get("voice", "πŸ‡ΊπŸ‡Έ 🚺 Bella")
260
- # voice_code = VOICE_CHOICES.get(voice_name, voice_name)
261
- # get_voice(voice_name)
262
- # voice_key = voice_code
263
- # speed = float(data.get("speed", speed))
264
- # # print(f"βš™οΈ Config updated: {voice_key}") # Commented out to reduce log noise
265
-
266
- # # 2. Text Stream
267
- # if "text" in data:
268
- # text = data["text"]
269
- # # The splitter breaks "500 words" into small sentences.
270
- # # These small sentences are added to the queue instantly.
271
- # for chunk in tuned_splitter(text):
272
- # if chunk.strip():
273
- # # Run G2P in thread to avoid blocking input
274
- # tokens = await loop.run_in_executor(G2P_EXECUTOR, g2p_task, chunk)
275
- # if tokens:
276
- # style = VOICE_CACHE.get(voice_key)
277
- # if style is None:
278
- # get_voice(voice_key)
279
- # style = VOICE_CACHE.get(voice_key)
280
-
281
- # # Put the ticket in the global queue
282
- # await INFERENCE_QUEUE.put((tokens, style, speed, ws))
283
-
284
- # if "flush" in data:
285
- # pass
286
-
287
- # except Exception as e:
288
- # print(f"πŸ”₯ Critical WS Error: {e}")
289
- # finally:
290
- # heartbeat_task.cancel() # Clean up the heartbeat task
291
-
292
- # # --- FINAL MOUNT ---
293
- # final_app = gr.mount_gradio_app(api, app, path="/")
294
-
295
- # if __name__ == "__main__":
296
- # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
298
  import json
299
  import time
300
  import re
301
  import numpy as np
 
302
  import gradio as gr
 
 
303
  from functools import lru_cache
304
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
305
  import asyncio
@@ -307,11 +14,10 @@ import uvloop
307
  import uvicorn
308
  from concurrent.futures import ThreadPoolExecutor
309
 
310
- # πŸ”₯ USE KOKORO PIPELINE INSTEAD OF RAW MISAKI
311
- from kokoro import KPipeline
312
-
313
  # --- CONFIGURATION ---
314
- SAMPLE_RATE = 24000
 
 
315
 
316
  # --- VOICE UI ---
317
  VOICE_CHOICES = {
@@ -328,50 +34,75 @@ VOICE_CHOICES = {
328
  }
329
 
330
  # --- ENGINE ---
331
- print("πŸš€ BOOTING KOKORO PIPELINE ENGINE...")
 
332
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
333
 
334
- # Initialize KPipeline - this handles espeak fallback automatically!
335
- PIPELINE = KPipeline(lang_code='a') # 'a' = American English
336
- print("βœ… KOKORO PIPELINE READY")
337
 
338
- # --- CORE LOGIC ---
339
- def generate_audio(text, voice_name, speed):
340
- """Generate audio using KPipeline - handles all phonemes properly!"""
341
- if not text or not text.strip():
342
- return None
343
-
344
- voice = VOICE_CHOICES.get(voice_name, voice_name)
345
-
346
- try:
347
- # KPipeline returns generator of (graphemes, phonemes, audio)
348
- audio_chunks = []
349
- for gs, ps, audio in PIPELINE(text, voice=voice, speed=speed):
350
- if audio is not None and len(audio) > 0:
351
- audio_chunks.append(audio)
352
-
353
- if not audio_chunks:
354
- return None
355
-
356
- # Concatenate all audio chunks
357
- full_audio = np.concatenate(audio_chunks)
358
- return full_audio
359
-
360
- except Exception as e:
361
- print(f"⚠️ Audio generation error: {e}")
362
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  def trim_silence(audio, threshold=0.01):
365
- if audio is None or audio.size == 0:
366
- return audio
367
  mask = np.abs(audio) > threshold
368
- if not np.any(mask):
369
- return audio
370
  start, end = np.argmax(mask), len(mask) - np.argmax(mask[::-1])
371
  return audio[max(0, start-50) : min(len(audio), end+50)]
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  def tuned_splitter(text):
374
- """Split text into chunks for streaming"""
375
  chunks = re.split(r'([.,!?;:\n]+)', text)
376
  buffer = ""
377
  chunk_count = 0
@@ -390,26 +121,23 @@ def tuned_splitter(text):
390
  yield buffer.strip()
391
 
392
  def stream_generator(text, voice_name, speed):
393
- """Generate audio stream for Gradio UI"""
394
- print(f"--- START STREAM: {text[:50]}... ---")
395
  for i, chunk in enumerate(tuned_splitter(text)):
396
  t0 = time.time()
397
- audio = generate_audio(chunk, voice_name, speed)
398
- if audio is not None and len(audio) > 0:
399
- audio = trim_silence(audio)
400
  dur = time.time() - t0
401
  print(f"⚑ Chunk {i}: {len(chunk)} chars in {dur:.2f}s")
402
- # Convert to int16 for audio output
403
- audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
404
- yield (SAMPLE_RATE, audio_int16)
405
  print("--- END STREAM ---")
406
 
407
  # --- UI DEFINITION ---
408
  with gr.Blocks(title="Kokoro TTS") as app:
409
- gr.Markdown("## ⚑ Kokoro-82M with KPipeline (Proper Name Support!)")
410
  with gr.Row():
411
  with gr.Column():
412
- text_in = gr.Textbox(label="Input Text", lines=3, value="Hello! My name is Yaman and I work at Willo. Testing pronunciation of names!")
413
  voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value='πŸ‡ΊπŸ‡Έ 🚺 Bella', label="Voice")
414
  speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
415
  btn = gr.Button("Generate", variant="primary")
@@ -418,59 +146,96 @@ with gr.Blocks(title="Kokoro TTS") as app:
418
  btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
419
 
420
  # --- API INTEGRATION ---
 
 
 
 
421
  api = FastAPI()
422
 
 
 
 
423
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
 
424
  INFERENCE_QUEUE = asyncio.Queue()
425
 
 
 
 
 
 
 
 
 
426
  async def audio_engine_loop():
427
- """Background worker that processes audio requests"""
428
  print("⚑ API AUDIO PIPELINE STARTED")
429
  loop = asyncio.get_running_loop()
430
 
431
  while True:
 
432
  job = await INFERENCE_QUEUE.get()
433
- text, voice, speed, ws = job
434
 
435
  try:
 
 
436
  if ws.client_state.value > 1:
437
  continue
438
 
439
- # Generate audio using KPipeline (in thread to not block)
 
 
 
 
 
 
440
  audio = await loop.run_in_executor(
441
- INFERENCE_EXECUTOR,
442
- lambda: generate_audio(text, voice, speed)
 
 
 
 
443
  )
444
 
445
- if audio is not None and len(audio) > 0:
446
- audio = trim_silence(audio)
447
- pcm_bytes = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
 
 
 
 
 
 
448
 
449
- try:
450
- await ws.send_bytes(pcm_bytes)
451
- except Exception:
452
- pass
453
-
454
  except Exception as e:
455
- print(f"⚠️ API Engine Error: {e}")
456
 
457
  @api.on_event("startup")
458
  async def startup():
459
  asyncio.create_task(audio_engine_loop())
460
 
 
 
 
461
  @api.websocket("/ws/audio")
462
  async def websocket_endpoint(ws: WebSocket):
463
  await ws.accept()
464
 
 
465
  voice_key = "af_bella"
466
  speed = 1.0
 
467
 
468
  print(f"βœ… Client connected: {ws.client}")
469
 
 
 
470
  async def keep_alive():
471
  while True:
472
  try:
473
- await asyncio.sleep(15)
 
474
  await ws.send_json({"type": "ping"})
475
  except:
476
  break
@@ -480,35 +245,49 @@ async def websocket_endpoint(ws: WebSocket):
480
  try:
481
  while True:
482
  try:
 
483
  data = await ws.receive_json()
484
  except WebSocketDisconnect:
485
  print("❌ Client disconnected cleanly")
486
- break
487
  except Exception as e:
488
  print(f"⚠️ Connection lost: {e}")
489
- break
490
 
 
491
  if "config" in data:
492
  voice_name = data.get("voice", "πŸ‡ΊπŸ‡Έ 🚺 Bella")
493
  voice_code = VOICE_CHOICES.get(voice_name, voice_name)
 
494
  voice_key = voice_code
495
  speed = float(data.get("speed", speed))
 
496
 
 
497
  if "text" in data:
498
  text = data["text"]
 
 
499
  for chunk in tuned_splitter(text):
500
  if chunk.strip():
501
- await INFERENCE_QUEUE.put((chunk, voice_key, speed, ws))
 
 
 
 
 
 
 
 
 
502
 
503
  if "flush" in data:
504
  pass
505
 
506
  except Exception as e:
507
  print(f"πŸ”₯ Critical WS Error: {e}")
508
- import traceback
509
- traceback.print_exc()
510
  finally:
511
- heartbeat_task.cancel()
512
 
513
  # --- FINAL MOUNT ---
514
  final_app = gr.mount_gradio_app(api, app, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
4
  import re
5
  import numpy as np
6
+ import onnxruntime as ort
7
  import gradio as gr
8
+ from huggingface_hub import hf_hub_download
9
+ from misaki import en
10
  from functools import lru_cache
11
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
12
  import asyncio
 
14
  import uvicorn
15
  from concurrent.futures import ThreadPoolExecutor
16
 
 
 
 
17
  # --- CONFIGURATION ---
18
+ MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
19
+ MODEL_FILE = "onnx/model.onnx"
20
+ TOKENIZER_FILE = "tokenizer.json"
21
 
22
  # --- VOICE UI ---
23
  VOICE_CHOICES = {
 
34
  }
35
 
36
# --- ENGINE ---
print("πŸš€ BOOTING HIGH-RAM ENGINE...")
# Enable fast networking immediately: uvloop replaces the default asyncio event loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# 1. Phonemizer (grapheme-to-phoneme).
# trf=False skips the transformer model; british=False selects American English;
# fallback=None means no secondary phonemizer for out-of-vocabulary words.
G2P = en.G2P(trf=False, british=False, fallback=None)

# 2. Tokenizer: phoneme -> integer id mapping loaded from the repo's tokenizer.json.
vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
with open(vocab_path, "r", encoding="utf-8") as f:
    data = json.load(f)
# Handle both tokenizer.json layouts: {"model": {"vocab": ...}} or a flat {"vocab": ...}.
TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
49
+
50
# 3. Voices (Lazy Load)
# Cache of voice style arrays keyed by voice code (e.g. "af_bella").
# The reshape below fixes the layout to (N, 1, 256): one 256-dim style row
# per token-length bucket.
VOICE_CACHE = {}

def get_voice(name):
    """Return the style array for *name* (UI label or raw voice code).

    Downloads the voice file on first use and caches it in VOICE_CACHE.
    On any download/parse failure, falls back to 'af_bella' so callers
    always receive a usable style array.
    """
    code = VOICE_CHOICES.get(name, name)
    if code not in VOICE_CACHE:
        try:
            print(f"⬇️ Loading Voice: {code}")
            path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{code}.bin")
            VOICE_CACHE[code] = np.fromfile(path, dtype=np.float32).reshape(-1, 1, 256)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
            # no longer swallowed, and the fallback is logged instead of silent.
            print(f"⚠️ Voice load failed for {code} ({e}); falling back to af_bella")
            if 'af_bella' not in VOICE_CACHE:
                p = hf_hub_download(repo_id=MODEL_REPO, filename="voices/af_bella.bin")
                VOICE_CACHE['af_bella'] = np.fromfile(p, dtype=np.float32).reshape(-1, 1, 256)
            return VOICE_CACHE['af_bella']
    return VOICE_CACHE[code]
65
+
66
# 4. ONNX Engine
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Disable spin-waiting so idle intra-op threads don't burn CPU between requests.
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
# 0 = let onnxruntime choose thread counts from the available cores.
sess_options.intra_op_num_threads = 0
sess_options.inter_op_num_threads = 0
SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
print("βœ… ENGINE READY")
75
+
76
# --- CORE LOGIC (Shared by UI and API) ---
@lru_cache(maxsize=5000)
def get_tokens(text):
    """Phonemize *text* and map each phoneme to its vocab id (0 if unknown); LRU-cached."""
    # Hard-coded pronunciation override for the model's own name.
    text = text.replace("Kokoro", "kˈOkΙ™ΙΉO")
    phonemes, _ = G2P(text)
    return [TOKENIZER.get(ph, 0) for ph in phonemes]
82
 
83
def trim_silence(audio, threshold=0.01):
    """Trim leading/trailing samples whose magnitude is <= *threshold*.

    Keeps a 50-sample pad on each side. Returns the input unchanged when it
    is empty or entirely below the threshold.
    """
    if audio.size == 0:
        return audio
    loud = np.flatnonzero(np.abs(audio) > threshold)
    if loud.size == 0:
        return audio
    first, stop = loud[0], loud[-1] + 1
    lo = max(0, first - 50)
    hi = min(len(audio), stop + 50)
    return audio[lo:hi]
89
 
90
def infer(text, voice_name, speed):
    """Synthesize one chunk of *text*.

    Returns (24000, int16 ndarray) on success, or None for empty input,
    empty token lists, or inference failure.
    """
    if not text.strip():
        return None
    ids = get_tokens(text)[:510]  # model context cap (512 incl. the two pad tokens)
    if not ids:
        return None
    voice = get_voice(voice_name)
    # Style row is indexed by token count, clamped to the table size.
    style = voice[min(len(ids), voice.shape[0] - 1)]
    try:
        audio = SESSION.run(None, {
            "input_ids": np.array([[0] + ids + [0]], dtype=np.int64),
            "style": style,
            "speed": np.array([speed], dtype=np.float32),
        })[0]
        return 24000, (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16)
    except Exception as e:
        # Narrowed from a bare `except`; log so synthesis failures aren't silent.
        print(f"⚠️ infer failed: {e}")
        return None
104
+
105
def tuned_splitter(text):
    """Yield punctuation-terminated chunks of *text* with growing size thresholds.

    The first chunks use small length limits (50/100/150 chars) so playback can
    start quickly; every later chunk uses 250 chars for throughput. Any trailing
    remainder is yielded stripped.
    """
    _EARLY_LIMITS = (50, 100, 150)
    buf = ""
    emitted = 0
    for piece in re.split(r'([.,!?;:\n]+)', text):
        buf += piece
        limit = _EARLY_LIMITS[emitted] if emitted < 3 else 250
        if len(buf) >= limit and re.search(r'[.,!?;:\n]$', buf) and buf.strip():
            yield buf
            emitted += 1
            buf = ""
    if buf.strip():
        yield buf.strip()
122
 
123
def stream_generator(text, voice_name, speed):
    """Gradio streaming callback: yield (sample_rate, int16 ndarray) per chunk."""
    print("--- START STREAM ---")
    # Warm the voice cache up front so the first chunk isn't delayed by a download.
    get_voice(voice_name)
    for i, chunk in enumerate(tuned_splitter(text)):
        started = time.time()
        result = infer(chunk, voice_name, speed)
        if result is not None:
            dur = time.time() - started
            print(f"⚑ Chunk {i}: {len(chunk)} chars in {dur:.2f}s")
            yield result
    print("--- END STREAM ---")
134
 
135
  # --- UI DEFINITION ---
136
  with gr.Blocks(title="Kokoro TTS") as app:
137
+ gr.Markdown("## ⚑ Kokoro-82M (High-RAM Tuned)")
138
  with gr.Row():
139
  with gr.Column():
140
+ text_in = gr.Textbox(label="Input Text", lines=3, value="The system is live. Use the Gradio UI for testing, or connect to /ws/audio for the API.")
141
  voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value='πŸ‡ΊπŸ‡Έ 🚺 Bella', label="Voice")
142
  speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
143
  btn = gr.Button("Generate", variant="primary")
 
146
  btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
147
 
148
# --- API INTEGRATION ---
# (Duplicate section banner and the redundant mid-file
#  `from concurrent.futures import ThreadPoolExecutor` removed —
#  ThreadPoolExecutor is already imported at the top of the file.)

# 1. FastAPI app: hosts the WebSocket API; the Gradio UI is mounted onto it below.
api = FastAPI()

# 2. Worker pools.
# max_workers=1 because onnxruntime is already multithreaded internally;
# extra workers on a small vCPU machine only add context-switch overhead.
INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
G2P_EXECUTOR = ThreadPoolExecutor(max_workers=1)
INFERENCE_QUEUE = asyncio.Queue()

# 3. Background task: text -> token ids.
def g2p_task(text):
    """Tokenize *text* for the API path.

    Delegates to get_tokens so the API and UI share one G2P/tokenizer
    implementation — and its LRU cache — instead of duplicating the logic.
    """
    return get_tokens(text)
168
+
169
# The "Engine Room": pulls jobs off the global queue and synthesizes them one by one.
async def audio_engine_loop():
    """Forever-running consumer: (tokens, style, speed, ws) -> PCM bytes over ws."""
    print("⚑ API AUDIO PIPELINE STARTED")
    loop = asyncio.get_running_loop()

    while True:
        # Wait for a ticket (token ids + style array + speed + websocket connection).
        job = await INFERENCE_QUEUE.get()
        tokens, style, speed, ws = job

        try:
            # Skip the heavy math if the client already left.
            # (Starlette WS state: 1 = CONNECTED, 2/3 = closing/closed.)
            if ws.client_state.value > 1:
                continue

            # Same SESSION as the UI path; 510-token cap plus start/end pad tokens.
            input_ids = np.array([[0, *tokens[:510], 0]], dtype=np.int64)
            # Style row indexed by token count, clamped to the table size.
            style_vec = style[min(len(tokens), style.shape[0]-1)]

            # Run the blocking ONNX call in a worker thread so the event loop
            # stays responsive to the other websocket clients.
            audio = await loop.run_in_executor(
                INFERENCE_EXECUTOR,
                lambda: SESSION.run(None, {
                    "input_ids": input_ids,
                    "style": style_vec,
                    "speed": np.array([speed], dtype=np.float32)
                })[0]
            )

            # Post-process on the event loop (cheap relative to inference).
            # NOTE(review): assumes audio[0] is a float waveform roughly in [-1, 1]
            # — matches the UI path's handling; confirm against the model spec.
            pcm_bytes = (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16).tobytes()

            # Deliver to the requesting client; a failed send must not kill the engine.
            try:
                await ws.send_bytes(pcm_bytes)
            except Exception:
                # If sending fails, just move on. Don't crash the engine.
                pass

        except Exception as e:
            # Keep the consumer alive no matter what a single job does.
            print(f"API Engine Error: {e}")
213
 
214
@api.on_event("startup")
async def startup():
    """Launch the background inference consumer once the event loop exists."""
    # NOTE(review): the Task handle is not stored; consider keeping a reference
    # so it cannot be garbage-collected and can be cancelled on shutdown.
    # on_event is deprecated in newer FastAPI in favor of lifespan handlers.
    asyncio.create_task(audio_engine_loop())
217
 
218
# -------------------------------------------------------
# ROBUST WEBSOCKET ENDPOINT
# -------------------------------------------------------
@api.websocket("/ws/audio")
async def websocket_endpoint(ws: WebSocket):
    """Per-client command loop: accepts JSON messages and enqueues synthesis jobs.

    Message shapes handled:
      {"config": ..., "voice": <UI label>, "speed": <float>}  -> update session defaults
      {"text": <str>}                                         -> split, tokenize, enqueue
      {"flush": ...}                                          -> currently a no-op
    """
    await ws.accept()

    # Per-connection defaults.
    voice_key = "af_bella"
    speed = 1.0
    loop = asyncio.get_running_loop()

    print(f"βœ… Client connected: {ws.client}")

    # --- HEARTBEAT KEEPER ---
    # Prevents the reverse proxy from killing the connection during silence.
    async def keep_alive():
        while True:
            try:
                await asyncio.sleep(15)  # Send a ping every 15s
                # A text frame serves as the ping; clients may ignore it.
                await ws.send_json({"type": "ping"})
            except:
                # Any failure (closed socket, cancellation) ends the heartbeat.
                break

    heartbeat_task = asyncio.create_task(keep_alive())

    try:
        while True:
            try:
                # Wait for the next JSON command from the client.
                data = await ws.receive_json()
            except WebSocketDisconnect:
                print("❌ Client disconnected cleanly")
                break  # BREAK THE LOOP
            except Exception as e:
                print(f"⚠️ Connection lost: {e}")
                break  # BREAK THE LOOP

            # 1. Config change: resolve the UI label to a voice code and prefetch it.
            if "config" in data:
                voice_name = data.get("voice", "πŸ‡ΊπŸ‡Έ 🚺 Bella")
                voice_code = VOICE_CHOICES.get(voice_name, voice_name)
                get_voice(voice_name)
                voice_key = voice_code
                speed = float(data.get("speed", speed))

            # 2. Text stream: split into sentence-sized chunks and enqueue each
            # one immediately so synthesis overlaps with further input.
            if "text" in data:
                text = data["text"]
                for chunk in tuned_splitter(text):
                    if chunk.strip():
                        # Run G2P in a worker thread to avoid blocking input handling.
                        tokens = await loop.run_in_executor(G2P_EXECUTOR, g2p_task, chunk)
                        if tokens:
                            style = VOICE_CACHE.get(voice_key)
                            if style is None:
                                # Lazy-load on first use of this voice.
                                get_voice(voice_key)
                                style = VOICE_CACHE.get(voice_key)

                            # Hand the ticket to the shared engine loop.
                            await INFERENCE_QUEUE.put((tokens, style, speed, ws))

            # 3. Flush: accepted but intentionally a no-op for now.
            if "flush" in data:
                pass

    except Exception as e:
        print(f"πŸ”₯ Critical WS Error: {e}")
    finally:
        heartbeat_task.cancel()  # Clean up the heartbeat task
291
 
292
# --- FINAL MOUNT ---
# Serve the Gradio UI at "/" on the same FastAPI app that exposes /ws/audio.
# NOTE(review): no `if __name__ == "__main__": uvicorn.run(...)` block here —
# presumably the hosting platform serves `final_app` directly; confirm.
final_app = gr.mount_gradio_app(api, app, path="/")