auralodyssey commited on
Commit
70de827
·
verified ·
1 Parent(s): ee7f838

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +527 -223
app.py CHANGED
@@ -1,25 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
- import json
3
- import time
4
  import re
 
 
 
 
 
5
  import numpy as np
6
- import onnxruntime as ort
7
  import gradio as gr
8
- from huggingface_hub import hf_hub_download
9
- from misaki import en
10
- from functools import lru_cache
11
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
12
- import asyncio
13
- import uvloop
14
  import uvicorn
15
- from concurrent.futures import ThreadPoolExecutor
16
 
17
- # --- CONFIGURATION ---
18
- MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
19
- MODEL_FILE = "onnx/model.onnx"
20
- TOKENIZER_FILE = "tokenizer.json"
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # --- VOICE UI ---
 
 
 
 
 
 
 
23
  VOICE_CHOICES = {
24
  '🇺🇸 🚺 Heart': 'af_heart', '🇺🇸 🚺 Bella': 'af_bella', '🇺🇸 🚺 Nicole': 'af_nicole',
25
  '🇺🇸 🚺 Aoede': 'af_aoede', '🇺🇸 🚺 Kore': 'af_kore', '🇺🇸 🚺 Sarah': 'af_sarah',
@@ -33,263 +344,256 @@ VOICE_CHOICES = {
33
  '🇬🇧 🚹 Daniel': 'bm_daniel',
34
  }
35
 
36
- # --- ENGINE ---
37
- print("🚀 BOOTING HIGH-RAM ENGINE...")
38
- # Enable fast networking immediately
39
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
40
-
41
- # 1. Phonemizer
42
- G2P = en.G2P(trf=False, british=False, fallback=None)
43
-
44
- # 2. Tokenizer
45
- vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
46
- with open(vocab_path, "r", encoding="utf-8") as f:
47
- data = json.load(f)
48
- TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
49
-
50
- # 3. Voices (Lazy Load)
51
- VOICE_CACHE = {}
52
- def get_voice(name):
53
- code = VOICE_CHOICES.get(name, name)
54
- if code not in VOICE_CACHE:
55
- try:
56
- print(f"⬇️ Loading Voice: {code}")
57
- path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{code}.bin")
58
- VOICE_CACHE[code] = np.fromfile(path, dtype=np.float32).reshape(-1, 1, 256)
59
- except:
60
- if 'af_bella' not in VOICE_CACHE:
61
- p = hf_hub_download(repo_id=MODEL_REPO, filename="voices/af_bella.bin")
62
- VOICE_CACHE['af_bella'] = np.fromfile(p, dtype=np.float32).reshape(-1, 1, 256)
63
- return VOICE_CACHE['af_bella']
64
- return VOICE_CACHE[code]
65
-
66
- # 4. ONNX Engine
67
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
68
- sess_options = ort.SessionOptions()
69
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
70
- sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
71
- sess_options.intra_op_num_threads = 0
72
- sess_options.inter_op_num_threads = 0
73
- SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
74
- print("✅ ENGINE READY")
75
-
76
- # --- CORE LOGIC (Shared by UI and API) ---
77
- @lru_cache(maxsize=5000)
78
- def get_tokens(text):
79
- if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkəɹO")
80
- phonemes, _ = G2P(text)
81
- return [TOKENIZER.get(p, 0) for p in phonemes]
82
-
83
- def trim_silence(audio, threshold=0.01):
84
- if audio.size == 0: return audio
85
- mask = np.abs(audio) > threshold
86
- if not np.any(mask): return audio
87
- start, end = np.argmax(mask), len(mask) - np.argmax(mask[::-1])
88
- return audio[max(0, start-50) : min(len(audio), end+50)]
89
-
90
- def infer(text, voice_name, speed):
91
- if not text.strip(): return None
92
- ids = get_tokens(text)[:510]
93
- if not ids: return None
94
- voice = get_voice(voice_name)
95
- style = voice[min(len(ids), voice.shape[0]-1)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  try:
97
- audio = SESSION.run(None, {
98
- "input_ids": np.array([[0] + ids + [0]], dtype=np.int64),
99
- "style": style,
100
- "speed": np.array([speed], dtype=np.float32)
101
- })[0]
102
- return 24000, (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16)
103
- except: return None
104
-
105
- def tuned_splitter(text):
106
- chunks = re.split(r'([.,!?;:\n]+)', text)
107
- buffer = ""
108
- chunk_count = 0
109
- for part in chunks:
110
- buffer += part
111
- if chunk_count == 0: threshold = 50
112
- elif chunk_count == 1: threshold = 100
113
- elif chunk_count == 2: threshold = 150
114
- else: threshold = 250
115
- if re.search(r'[.,!?;:\n]$', buffer) and len(buffer) >= threshold:
116
- if buffer.strip():
117
- yield buffer
118
- chunk_count += 1
119
- buffer = ""
120
- if buffer.strip():
121
- yield buffer.strip()
122
-
123
- def stream_generator(text, voice_name, speed):
124
- print("--- START STREAM ---")
125
- get_voice(voice_name)
126
- for i, chunk in enumerate(tuned_splitter(text)):
127
- t0 = time.time()
128
- audio = infer(chunk, voice_name, speed)
129
- if audio:
130
- dur = time.time() - t0
131
- print(f"⚡ Chunk {i}: {len(chunk)} chars in {dur:.2f}s")
132
- yield audio
133
- print("--- END STREAM ---")
134
-
135
- # --- UI DEFINITION ---
136
- with gr.Blocks(title="Kokoro TTS") as app:
137
- gr.Markdown("## ⚡ Kokoro-82M (High-RAM Tuned)")
138
- with gr.Row():
139
- with gr.Column():
140
- text_in = gr.Textbox(label="Input Text", lines=3, value="The system is live. Use the Gradio UI for testing, or connect to /ws/audio for the API.")
141
- voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value='🇺🇸 🚺 Bella', label="Voice")
142
- speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
143
- btn = gr.Button("Generate", variant="primary")
144
- with gr.Column():
145
- audio_out = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")
146
- btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
147
 
148
- # --- API INTEGRATION ---
149
- # --- API INTEGRATION ---
150
- from concurrent.futures import ThreadPoolExecutor
151
 
152
- # 1. Define FastAPI
153
- api = FastAPI()
154
 
155
- # 2. Define Worker Pools
156
- # We use max_workers=1 because ONNX is already multithreaded internally.
157
- # Adding more workers on a 2 vCPU machine will actually SLOW it down due to context switching.
158
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
159
  G2P_EXECUTOR = ThreadPoolExecutor(max_workers=1)
160
- INFERENCE_QUEUE = asyncio.Queue()
161
-
162
- # 3. Background Tasks
163
- def g2p_task(text):
164
- # Reuses the exact same G2P/Tokenizer logic as the UI
165
- if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkəɹO")
166
- phonemes, _ = G2P(text)
167
- return [TOKENIZER.get(p, 0) for p in phonemes]
168
 
169
- # This is the "Engine Room". It pulls tickets and cooks them one by one.
170
  async def audio_engine_loop():
171
  print("⚡ API AUDIO PIPELINE STARTED")
 
172
  loop = asyncio.get_running_loop()
173
-
174
  while True:
175
- # Wait for a ticket (text tokens + websocket connection)
176
  job = await INFERENCE_QUEUE.get()
177
- tokens, style, speed, ws = job
178
-
179
- try:
180
- # Check if client is still connected before doing heavy math
181
- # (FastAPI WS state: 1 = Connected, 2/3 = Closing/Closed)
 
 
182
  if ws.client_state.value > 1:
183
- continue
184
 
185
- # Reuses the exact same SESSION as the UI
186
- input_ids = np.array([[0, *tokens[:510], 0]], dtype=np.int64)
187
- style_vec = style[min(len(tokens), style.shape[0]-1)]
188
-
189
- # --- CRITICAL FIX: Run blocking math in a separate thread ---
190
- # This allows the main server to keep talking to the other 59 users
191
- # while this calculation happens in the background.
192
- audio = await loop.run_in_executor(
193
- INFERENCE_EXECUTOR,
194
- lambda: SESSION.run(None, {
195
- "input_ids": input_ids,
196
- "style": style_vec,
197
- "speed": np.array([speed], dtype=np.float32)
198
- })[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  )
200
-
201
- # Post-Process (Fast enough to run on main thread)
202
- pcm_bytes = (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16).tobytes()
203
-
204
- # Send audio back to the specific user who asked for it
205
- try:
206
- await ws.send_bytes(pcm_bytes)
207
- except Exception:
208
- # If sending fails, just move on. Don't crash the engine.
209
- pass
210
-
211
- except Exception as e:
212
- print(f"API Engine Error: {e}")
213
 
214
  @api.on_event("startup")
215
  async def startup():
216
  asyncio.create_task(audio_engine_loop())
217
 
218
- # -------------------------------------------------------
219
- # ROBUST WEBSOCKET ENDPOINT
220
- # -------------------------------------------------------
221
  @api.websocket("/ws/audio")
222
  async def websocket_endpoint(ws: WebSocket):
223
  await ws.accept()
224
-
225
- # Defaults
226
- voice_key = "af_bella"
227
  speed = 1.0
228
  loop = asyncio.get_running_loop()
229
-
230
- print(f"✅ Client connected: {ws.client}")
231
 
232
- # --- HEARTBEAT KEEPER ---
233
- # This prevents HF Nginx from killing the connection during silence.
234
- async def keep_alive():
235
- while True:
236
- try:
237
- await asyncio.sleep(15) # Send a ping every 15s
238
- # We send a text frame as a ping. The browser ignores it or handles it.
239
- await ws.send_json({"type": "ping"})
240
- except:
241
- break
242
-
243
- heartbeat_task = asyncio.create_task(keep_alive())
244
 
245
  try:
246
  while True:
247
  try:
248
- # Wait for JSON command
249
  data = await ws.receive_json()
250
  except WebSocketDisconnect:
251
  print("❌ Client disconnected cleanly")
252
- break # BREAK THE LOOP
253
  except Exception as e:
254
  print(f"⚠️ Connection lost: {e}")
255
- break # BREAK THE LOOP
256
 
257
- # 1. Config Change
258
  if "config" in data:
259
  voice_name = data.get("voice", "🇺🇸 🚺 Bella")
260
  voice_code = VOICE_CHOICES.get(voice_name, voice_name)
261
- get_voice(voice_name)
262
- voice_key = voice_code
263
  speed = float(data.get("speed", speed))
264
- # print(f"⚙️ Config updated: {voice_key}") # Commented out to reduce log noise
265
-
266
- # 2. Text Stream
267
  if "text" in data:
268
  text = data["text"]
269
- # The splitter breaks "500 words" into small sentences.
270
- # These small sentences are added to the queue instantly.
271
- for chunk in tuned_splitter(text):
272
- if chunk.strip():
273
- # Run G2P in thread to avoid blocking input
274
- tokens = await loop.run_in_executor(G2P_EXECUTOR, g2p_task, chunk)
275
- if tokens:
276
- style = VOICE_CACHE.get(voice_key)
277
- if style is None:
278
- get_voice(voice_key)
279
- style = VOICE_CACHE.get(voice_key)
280
-
281
- # Put the ticket in the global queue
282
- await INFERENCE_QUEUE.put((tokens, style, speed, ws))
283
-
 
 
 
 
 
 
 
 
 
284
  if "flush" in data:
 
285
  pass
286
 
287
  except Exception as e:
288
  print(f"🔥 Critical WS Error: {e}")
289
- finally:
290
- heartbeat_task.cancel() # Clean up the heartbeat task
291
 
292
- # --- FINAL MOUNT ---
293
  final_app = gr.mount_gradio_app(api, app, path="/")
294
 
295
  if __name__ == "__main__":
 
1
+ # import os
2
+ # import json
3
+ # import time
4
+ # import re
5
+ # import numpy as np
6
+ # import onnxruntime as ort
7
+ # import gradio as gr
8
+ # from huggingface_hub import hf_hub_download
9
+ # from misaki import en
10
+ # from functools import lru_cache
11
+ # from fastapi import FastAPI, WebSocket, WebSocketDisconnect
12
+ # import asyncio
13
+ # import uvloop
14
+ # import uvicorn
15
+ # from concurrent.futures import ThreadPoolExecutor
16
+
17
+ # # --- CONFIGURATION ---
18
+ # MODEL_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
19
+ # MODEL_FILE = "onnx/model.onnx"
20
+ # TOKENIZER_FILE = "tokenizer.json"
21
+
22
+ # # --- VOICE UI ---
23
+ # VOICE_CHOICES = {
24
+ # '🇺🇸 🚺 Heart': 'af_heart', '🇺🇸 🚺 Bella': 'af_bella', '🇺🇸 🚺 Nicole': 'af_nicole',
25
+ # '🇺🇸 🚺 Aoede': 'af_aoede', '🇺🇸 🚺 Kore': 'af_kore', '🇺🇸 🚺 Sarah': 'af_sarah',
26
+ # '🇺🇸 🚺 Nova': 'af_nova', '🇺🇸 🚺 Sky': 'af_sky', '🇺🇸 🚺 Alloy': 'af_alloy',
27
+ # '🇺🇸 🚺 Jessica': 'af_jessica', '🇺🇸 🚺 River': 'af_river', '🇺🇸 🚹 Michael': 'am_michael',
28
+ # '🇺🇸 🚹 Fenrir': 'am_fenrir', '🇺🇸 🚹 Puck': 'am_puck', '🇺🇸 🚹 Echo': 'am_echo',
29
+ # '🇺🇸 🚹 Eric': 'am_eric', '🇺🇸 🚹 Liam': 'am_liam', '🇺🇸 🚹 Onyx': 'am_onyx',
30
+ # '🇺🇸 🚹 Santa': 'am_santa', '🇺🇸 🚹 Adam': 'am_adam', '🇬🇧 🚺 Emma': 'bf_emma',
31
+ # '🇬🇧 🚺 Isabella': 'bf_isabella', '🇬🇧 🚺 Alice': 'bf_alice', '🇬🇧 🚺 Lily': 'bf_lily',
32
+ # '🇬🇧 🚹 George': 'bm_george', '🇬🇧 🚹 Fable': 'bm_fable', '🇬🇧 🚹 Lewis': 'bm_lewis',
33
+ # '🇬🇧 🚹 Daniel': 'bm_daniel',
34
+ # }
35
+
36
+ # # --- ENGINE ---
37
+ # print("🚀 BOOTING HIGH-RAM ENGINE...")
38
+ # # Enable fast networking immediately
39
+ # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
40
+
41
+ # # 1. Phonemizer
42
+ # G2P = en.G2P(trf=False, british=False, fallback=None)
43
+
44
+ # # 2. Tokenizer
45
+ # vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=TOKENIZER_FILE)
46
+ # with open(vocab_path, "r", encoding="utf-8") as f:
47
+ # data = json.load(f)
48
+ # TOKENIZER = data["model"]["vocab"] if "model" in data else data.get("vocab", {})
49
+
50
+ # # 3. Voices (Lazy Load)
51
+ # VOICE_CACHE = {}
52
+ # def get_voice(name):
53
+ # code = VOICE_CHOICES.get(name, name)
54
+ # if code not in VOICE_CACHE:
55
+ # try:
56
+ # print(f"⬇️ Loading Voice: {code}")
57
+ # path = hf_hub_download(repo_id=MODEL_REPO, filename=f"voices/{code}.bin")
58
+ # VOICE_CACHE[code] = np.fromfile(path, dtype=np.float32).reshape(-1, 1, 256)
59
+ # except:
60
+ # if 'af_bella' not in VOICE_CACHE:
61
+ # p = hf_hub_download(repo_id=MODEL_REPO, filename="voices/af_bella.bin")
62
+ # VOICE_CACHE['af_bella'] = np.fromfile(p, dtype=np.float32).reshape(-1, 1, 256)
63
+ # return VOICE_CACHE['af_bella']
64
+ # return VOICE_CACHE[code]
65
+
66
+ # # 4. ONNX Engine
67
+ # model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
68
+ # sess_options = ort.SessionOptions()
69
+ # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
70
+ # sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
71
+ # sess_options.intra_op_num_threads = 0
72
+ # sess_options.inter_op_num_threads = 0
73
+ # SESSION = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
74
+ # print("✅ ENGINE READY")
75
+
76
+ # # --- CORE LOGIC (Shared by UI and API) ---
77
+ # @lru_cache(maxsize=5000)
78
+ # def get_tokens(text):
79
+ # if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkəɹO")
80
+ # phonemes, _ = G2P(text)
81
+ # return [TOKENIZER.get(p, 0) for p in phonemes]
82
+
83
+ # def trim_silence(audio, threshold=0.01):
84
+ # if audio.size == 0: return audio
85
+ # mask = np.abs(audio) > threshold
86
+ # if not np.any(mask): return audio
87
+ # start, end = np.argmax(mask), len(mask) - np.argmax(mask[::-1])
88
+ # return audio[max(0, start-50) : min(len(audio), end+50)]
89
+
90
+ # def infer(text, voice_name, speed):
91
+ # if not text.strip(): return None
92
+ # ids = get_tokens(text)[:510]
93
+ # if not ids: return None
94
+ # voice = get_voice(voice_name)
95
+ # style = voice[min(len(ids), voice.shape[0]-1)]
96
+ # try:
97
+ # audio = SESSION.run(None, {
98
+ # "input_ids": np.array([[0] + ids + [0]], dtype=np.int64),
99
+ # "style": style,
100
+ # "speed": np.array([speed], dtype=np.float32)
101
+ # })[0]
102
+ # return 24000, (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16)
103
+ # except: return None
104
+
105
+ # def tuned_splitter(text):
106
+ # chunks = re.split(r'([.,!?;:\n]+)', text)
107
+ # buffer = ""
108
+ # chunk_count = 0
109
+ # for part in chunks:
110
+ # buffer += part
111
+ # if chunk_count == 0: threshold = 50
112
+ # elif chunk_count == 1: threshold = 100
113
+ # elif chunk_count == 2: threshold = 150
114
+ # else: threshold = 250
115
+ # if re.search(r'[.,!?;:\n]$', buffer) and len(buffer) >= threshold:
116
+ # if buffer.strip():
117
+ # yield buffer
118
+ # chunk_count += 1
119
+ # buffer = ""
120
+ # if buffer.strip():
121
+ # yield buffer.strip()
122
+
123
+ # def stream_generator(text, voice_name, speed):
124
+ # print("--- START STREAM ---")
125
+ # get_voice(voice_name)
126
+ # for i, chunk in enumerate(tuned_splitter(text)):
127
+ # t0 = time.time()
128
+ # audio = infer(chunk, voice_name, speed)
129
+ # if audio:
130
+ # dur = time.time() - t0
131
+ # print(f"⚡ Chunk {i}: {len(chunk)} chars in {dur:.2f}s")
132
+ # yield audio
133
+ # print("--- END STREAM ---")
134
+
135
+ # # --- UI DEFINITION ---
136
+ # with gr.Blocks(title="Kokoro TTS") as app:
137
+ # gr.Markdown("## ⚡ Kokoro-82M (High-RAM Tuned)")
138
+ # with gr.Row():
139
+ # with gr.Column():
140
+ # text_in = gr.Textbox(label="Input Text", lines=3, value="The system is live. Use the Gradio UI for testing, or connect to /ws/audio for the API.")
141
+ # voice_in = gr.Dropdown(list(VOICE_CHOICES.keys()), value='🇺🇸 🚺 Bella', label="Voice")
142
+ # speed_in = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
143
+ # btn = gr.Button("Generate", variant="primary")
144
+ # with gr.Column():
145
+ # audio_out = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")
146
+ # btn.click(stream_generator, inputs=[text_in, voice_in, speed_in], outputs=[audio_out])
147
+
148
+ # # --- API INTEGRATION ---
149
+ # # --- API INTEGRATION ---
150
+ # from concurrent.futures import ThreadPoolExecutor
151
+
152
+ # # 1. Define FastAPI
153
+ # api = FastAPI()
154
+
155
+ # # 2. Define Worker Pools
156
+ # # We use max_workers=1 because ONNX is already multithreaded internally.
157
+ # # Adding more workers on a 2 vCPU machine will actually SLOW it down due to context switching.
158
+ # INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
159
+ # G2P_EXECUTOR = ThreadPoolExecutor(max_workers=1)
160
+ # INFERENCE_QUEUE = asyncio.Queue()
161
+
162
+ # # 3. Background Tasks
163
+ # def g2p_task(text):
164
+ # # Reuses the exact same G2P/Tokenizer logic as the UI
165
+ # if "Kokoro" in text: text = text.replace("Kokoro", "kˈOkəɹO")
166
+ # phonemes, _ = G2P(text)
167
+ # return [TOKENIZER.get(p, 0) for p in phonemes]
168
+
169
+ # # This is the "Engine Room". It pulls tickets and cooks them one by one.
170
+ # async def audio_engine_loop():
171
+ # print("⚡ API AUDIO PIPELINE STARTED")
172
+ # loop = asyncio.get_running_loop()
173
+
174
+ # while True:
175
+ # # Wait for a ticket (text tokens + websocket connection)
176
+ # job = await INFERENCE_QUEUE.get()
177
+ # tokens, style, speed, ws = job
178
+
179
+ # try:
180
+ # # Check if client is still connected before doing heavy math
181
+ # # (FastAPI WS state: 1 = Connected, 2/3 = Closing/Closed)
182
+ # if ws.client_state.value > 1:
183
+ # continue
184
+
185
+ # # Reuses the exact same SESSION as the UI
186
+ # input_ids = np.array([[0, *tokens[:510], 0]], dtype=np.int64)
187
+ # style_vec = style[min(len(tokens), style.shape[0]-1)]
188
+
189
+ # # --- CRITICAL FIX: Run blocking math in a separate thread ---
190
+ # # This allows the main server to keep talking to the other 59 users
191
+ # # while this calculation happens in the background.
192
+ # audio = await loop.run_in_executor(
193
+ # INFERENCE_EXECUTOR,
194
+ # lambda: SESSION.run(None, {
195
+ # "input_ids": input_ids,
196
+ # "style": style_vec,
197
+ # "speed": np.array([speed], dtype=np.float32)
198
+ # })[0]
199
+ # )
200
+
201
+ # # Post-Process (Fast enough to run on main thread)
202
+ # pcm_bytes = (np.clip(trim_silence(audio[0]), -1.0, 1.0) * 32767).astype(np.int16).tobytes()
203
+
204
+ # # Send audio back to the specific user who asked for it
205
+ # try:
206
+ # await ws.send_bytes(pcm_bytes)
207
+ # except Exception:
208
+ # # If sending fails, just move on. Don't crash the engine.
209
+ # pass
210
+
211
+ # except Exception as e:
212
+ # print(f"API Engine Error: {e}")
213
+
214
+ # @api.on_event("startup")
215
+ # async def startup():
216
+ # asyncio.create_task(audio_engine_loop())
217
+
218
+ # # -------------------------------------------------------
219
+ # # ROBUST WEBSOCKET ENDPOINT
220
+ # # -------------------------------------------------------
221
+ # @api.websocket("/ws/audio")
222
+ # async def websocket_endpoint(ws: WebSocket):
223
+ # await ws.accept()
224
+
225
+ # # Defaults
226
+ # voice_key = "af_bella"
227
+ # speed = 1.0
228
+ # loop = asyncio.get_running_loop()
229
+
230
+ # print(f"✅ Client connected: {ws.client}")
231
+
232
+ # # --- HEARTBEAT KEEPER ---
233
+ # # This prevents HF Nginx from killing the connection during silence.
234
+ # async def keep_alive():
235
+ # while True:
236
+ # try:
237
+ # await asyncio.sleep(15) # Send a ping every 15s
238
+ # # We send a text frame as a ping. The browser ignores it or handles it.
239
+ # await ws.send_json({"type": "ping"})
240
+ # except:
241
+ # break
242
+
243
+ # heartbeat_task = asyncio.create_task(keep_alive())
244
+
245
+ # try:
246
+ # while True:
247
+ # try:
248
+ # # Wait for JSON command
249
+ # data = await ws.receive_json()
250
+ # except WebSocketDisconnect:
251
+ # print("❌ Client disconnected cleanly")
252
+ # break # BREAK THE LOOP
253
+ # except Exception as e:
254
+ # print(f"⚠️ Connection lost: {e}")
255
+ # break # BREAK THE LOOP
256
+
257
+ # # 1. Config Change
258
+ # if "config" in data:
259
+ # voice_name = data.get("voice", "🇺🇸 🚺 Bella")
260
+ # voice_code = VOICE_CHOICES.get(voice_name, voice_name)
261
+ # get_voice(voice_name)
262
+ # voice_key = voice_code
263
+ # speed = float(data.get("speed", speed))
264
+ # # print(f"⚙️ Config updated: {voice_key}") # Commented out to reduce log noise
265
+
266
+ # # 2. Text Stream
267
+ # if "text" in data:
268
+ # text = data["text"]
269
+ # # The splitter breaks "500 words" into small sentences.
270
+ # # These small sentences are added to the queue instantly.
271
+ # for chunk in tuned_splitter(text):
272
+ # if chunk.strip():
273
+ # # Run G2P in thread to avoid blocking input
274
+ # tokens = await loop.run_in_executor(G2P_EXECUTOR, g2p_task, chunk)
275
+ # if tokens:
276
+ # style = VOICE_CACHE.get(voice_key)
277
+ # if style is None:
278
+ # get_voice(voice_key)
279
+ # style = VOICE_CACHE.get(voice_key)
280
+
281
+ # # Put the ticket in the global queue
282
+ # await INFERENCE_QUEUE.put((tokens, style, speed, ws))
283
+
284
+ # if "flush" in data:
285
+ # pass
286
+
287
+ # except Exception as e:
288
+ # print(f"🔥 Critical WS Error: {e}")
289
+ # finally:
290
+ # heartbeat_task.cancel() # Clean up the heartbeat task
291
+
292
+ # # --- FINAL MOUNT ---
293
+ # final_app = gr.mount_gradio_app(api, app, path="/")
294
+
295
+ # if __name__ == "__main__":
296
+ # uvicorn.run(final_app, host="0.0.0.0", port=7860)
297
  import os
 
 
298
  import re
299
+ import time
300
+ import asyncio
301
+ from functools import lru_cache
302
+ from concurrent.futures import ThreadPoolExecutor
303
+
304
  import numpy as np
 
305
  import gradio as gr
 
 
 
306
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
 
307
  import uvicorn
 
308
 
309
+ # Kokoro official inference lib (PyTorch)
310
+ from kokoro import KPipeline
311
+
312
+ # -----------------------------
313
+ # CONFIG
314
+ # -----------------------------
315
+ KOKORO_REPO_ID = os.getenv("KOKORO_REPO_ID", "hexgrad/Kokoro-82M")
316
+ AUDIO_SR = 24000
317
+
318
+ # Split early to reduce latency on long paragraphs
319
+ # Sentences or newlines
320
+ SPLIT_PATTERN = os.getenv("KOKORO_SPLIT_PATTERN", r"(?<=[.!?])\s+|\n+")
321
+
322
+ # Hard safety caps for HF free tier
323
+ MAX_QUEUE = int(os.getenv("MAX_QUEUE", "100"))
324
+ MAX_CHUNKS_PER_UTTERANCE = int(os.getenv("MAX_CHUNKS_PER_UTTERANCE", "120"))
325
 
326
+ # Keep CPU thread usage predictable on 2 vCPU
327
+ os.environ.setdefault("OMP_NUM_THREADS", "2")
328
+ os.environ.setdefault("MKL_NUM_THREADS", "2")
329
+ os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
330
+
331
+ # -----------------------------
332
+ # VOICES
333
+ # -----------------------------
334
  VOICE_CHOICES = {
335
  '🇺🇸 🚺 Heart': 'af_heart', '🇺🇸 🚺 Bella': 'af_bella', '🇺🇸 🚺 Nicole': 'af_nicole',
336
  '🇺🇸 🚺 Aoede': 'af_aoede', '🇺🇸 🚺 Kore': 'af_kore', '🇺🇸 🚺 Sarah': 'af_sarah',
 
344
  '🇬🇧 🚹 Daniel': 'bm_daniel',
345
  }
346
 
347
+ def _is_uk_voice(voice_code: str) -> bool:
348
+ return voice_code.startswith("bf_") or voice_code.startswith("bm_")
349
+
350
+ # -----------------------------
351
+ # BOOT
352
+ # -----------------------------
353
+ print("🚀 BOOTING KOKORO (OFFICIAL PIPELINE)")
354
+
355
+ # 1) One shared model instance for both pipelines (loads weights once)
356
+ PIPE_A = KPipeline(lang_code="a", repo_id=KOKORO_REPO_ID, trf=False, device="cpu")
357
+ MODEL = PIPE_A.model
358
+ PIPE_B = KPipeline(lang_code="b", repo_id=KOKORO_REPO_ID, trf=False, device="cpu", model=MODEL)
359
+
360
+ # 2) Quiet pipelines for fast G2P + chunking without inference
361
+ QUIET_A = KPipeline(lang_code="a", repo_id=KOKORO_REPO_ID, trf=False, model=False)
362
+ QUIET_B = KPipeline(lang_code="b", repo_id=KOKORO_REPO_ID, trf=False, model=False)
363
+
364
+ # 3) Voice cache (on device)
365
+ VOICE_PACK_CACHE = {}
366
+
367
def _pick_pipes(voice_code: str):
    """Select the (inference pipeline, quiet G2P pipeline) pair for a voice.

    British voices route to the lang_code="b" pipelines; every other voice
    uses the American ("a") pair.
    """
    wants_uk = _is_uk_voice(voice_code)
    return (PIPE_B, QUIET_B) if wants_uk else (PIPE_A, QUIET_A)
371
+
372
def get_voice_pack(voice_code: str):
    """Return the style-vector pack for *voice_code*, memoized in VOICE_PACK_CACHE.

    KPipeline keeps an internal voice cache as well, but pinning our own
    reference avoids re-resolving the voice on every request.
    """
    try:
        return VOICE_PACK_CACHE[voice_code]
    except KeyError:
        pipe, _ = _pick_pipes(voice_code)
        pack = pipe.load_voice(voice_code)
        VOICE_PACK_CACHE[voice_code] = pack
        return pack
380
+
381
# -----------------------------
# TEXT NORMALIZATION
# -----------------------------
# Markdown-link syntax KPipeline accepts for forcing a pronunciation.
_KOKORO_IPA = "[Kokoro](/kˈOkəɹO/)"

# Single-character substitutions for tokens the G2P tends to skip.
_CHAR_FIXES = str.maketrans({"&": " and ", "@": " at ", "_": " "})

def normalize_text(text: str) -> str:
    """Pre-clean *text* for G2P.

    Expands symbols (& -> and), splits CamelCase words, spells out short
    all-caps acronyms, and pins the pronunciation of "Kokoro". Returns ""
    for empty input; whitespace is collapsed to single spaces.
    """
    if not text:
        return ""

    cleaned = text.strip().translate(_CHAR_FIXES)
    # Break CamelCase apart to reduce OOD risk: OpenAI -> Open AI.
    cleaned = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", cleaned)
    # Spell out short acronyms letter by letter: CEO -> C E O.
    cleaned = re.sub(r"\b([A-Z]{2,6})\b", lambda m: " ".join(m.group(1)), cleaned)
    # Force the canonical pronunciation of the model's own name.
    cleaned = re.sub(r"\bKokoro\b", _KOKORO_IPA, cleaned)
    # Collapse any whitespace runs the substitutions introduced.
    return re.sub(r"\s+", " ", cleaned).strip()
409
+
410
+ # -----------------------------
411
+ # CHUNKING: text -> phoneme chunks
412
+ # -----------------------------
413
@lru_cache(maxsize=2000)
def _split_segments(text: str):
    """Split *text* on SPLIT_PATTERN into non-empty, stripped segments.

    Returns a tuple rather than a list: lru_cache hands the *same* cached
    object back to every caller, so the cached value must be immutable —
    otherwise one caller mutating the result would corrupt later cache hits.
    Callers only iterate the result, so this is backward-compatible.
    """
    parts = re.split(SPLIT_PATTERN, text)
    return tuple(p.strip() for p in parts if p and p.strip())
418
+
419
def text_to_phoneme_chunks(text: str, voice_code: str):
    """Convert raw text into a list of phoneme strings ready for inference.

    Uses the model-free ("quiet") pipeline for G2P + tokenization, so this
    is cheap enough to run in the G2P executor. Output is capped at
    MAX_CHUNKS_PER_UTTERANCE as a free-tier safety valve.
    """
    _, quiet = _pick_pipes(voice_code)
    normalized = normalize_text(text)
    if not normalized:
        return []

    phoneme_chunks = []
    for segment in _split_segments(normalized):
        # g2p yields (phoneme_str, tokens); only the token stream is needed.
        _, tokens = quiet.g2p(segment)
        # en_tokenize yields (graphemes, phonemes, token_chunk) triples.
        for _, phonemes, _ in quiet.en_tokenize(tokens):
            if not phonemes:
                continue
            phoneme_chunks.append(phonemes)
            if len(phoneme_chunks) >= MAX_CHUNKS_PER_UTTERANCE:
                return phoneme_chunks
    return phoneme_chunks
437
+
438
+ # -----------------------------
439
+ # INFERENCE: phonemes -> audio
440
+ # -----------------------------
441
def infer_phonemes(ps: str, voice_code: str, speed: float):
    """Synthesize one phoneme chunk into int16 PCM samples.

    Runs the shared Kokoro model via the voice-appropriate pipeline (same
    internal path as KPipeline.generate_from_tokens); the waveform is
    clipped to [-1, 1] and scaled to 16-bit.
    """
    pipe, _ = _pick_pipes(voice_code)
    voice_pack = get_voice_pack(voice_code)

    waveform = pipe.infer(ps, voice=voice_pack, speed=speed)

    # Depending on the kokoro version this is a torch tensor or an ndarray.
    try:
        import torch
        if torch.is_tensor(waveform):
            waveform = waveform.detach().cpu().numpy()
    except Exception:
        pass  # torch unavailable — assume numpy-compatible already

    samples = np.clip(np.asarray(waveform, dtype=np.float32), -1.0, 1.0)
    return (samples * 32767.0).astype(np.int16)
461
 
462
+ # -----------------------------
463
+ # EXECUTORS + QUEUE (HF free tier safe)
464
+ # -----------------------------
465
  INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
466
  G2P_EXECUTOR = ThreadPoolExecutor(max_workers=1)
467
+ INFERENCE_QUEUE = asyncio.Queue(maxsize=MAX_QUEUE)
 
 
 
 
 
 
 
468
 
 
469
async def audio_engine_loop():
    """Single global consumer: pull jobs off INFERENCE_QUEUE and stream PCM.

    Each job is a dict: {"ws": WebSocket, "voice": str, "speed": float,
    "chunks": list of phoneme strings}. All chunks of one job are sent
    before the next job starts, so one utterance is never interleaved with
    another user's audio.

    Fix: the entire job is now processed inside try/except — previously a
    malformed job dict (KeyError) raised *outside* the inner try and killed
    this coroutine permanently, silencing the whole API.
    """
    print("⚡ API AUDIO PIPELINE STARTED")

    loop = asyncio.get_running_loop()

    while True:
        job = await INFERENCE_QUEUE.get()
        try:
            ws = job["ws"]
            voice_code = job["voice"]
            speed = job["speed"]

            for ps in job["chunks"]:
                # Stop early once the client is closing/closed
                # (FastAPI WS state: 1 = connected, 2/3 = closing/closed).
                if ws.client_state.value > 1:
                    break

                # Positional args instead of a closure: each chunk binds its
                # own `ps` (avoids the late-binding lambda pitfall, B023).
                pcm16 = await loop.run_in_executor(
                    INFERENCE_EXECUTOR, infer_phonemes, ps, voice_code, speed
                )
                await ws.send_bytes(pcm16.tobytes())
        except Exception as e:
            # A bad job or a dead socket must never kill the engine loop.
            print(f"API Engine Error: {e}")
        finally:
            INFERENCE_QUEUE.task_done()
+
496
+ # -----------------------------
497
+ # GRADIO UI (streaming)
498
+ # -----------------------------
499
def gradio_stream(text: str, voice_name: str, speed: float):
    """Gradio generator: yield (sample_rate, int16 PCM) per phoneme chunk."""
    voice_code = VOICE_CHOICES.get(voice_name, voice_name)
    # Warm the voice pack up front so it isn't counted in per-chunk timing.
    get_voice_pack(voice_code)

    for i, ps in enumerate(text_to_phoneme_chunks(text, voice_code)):
        t0 = time.time()
        pcm16 = infer_phonemes(ps, voice_code, float(speed))
        dt = time.time() - t0
        print(f"⚡ UI chunk {i}: {len(ps)} phonemes in {dt:.2f}s")
        yield (AUDIO_SR, pcm16)
510
+
511
# Gradio front-end: a text box, voice/speed controls, and a streaming
# audio sink wired to gradio_stream.
with gr.Blocks(title="Kokoro TTS (Official)") as app:
    gr.Markdown("## ⚡ Kokoro-82M (Official Pipeline, HF Free Tier Safe)")
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Text",
                lines=3,
                value="The system is live. Use Gradio for testing, or connect to /ws/audio for the API."
            )
            voice_dropdown = gr.Dropdown(list(VOICE_CHOICES.keys()), value="🇺🇸 🚺 Bella", label="Voice")
            speed_slider = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            stream_player = gr.Audio(streaming=True, autoplay=True, label="Audio Stream")
    # Chunks are yielded straight into the audio component as synthesized.
    generate_btn.click(gradio_stream, inputs=[input_box, voice_dropdown, speed_slider], outputs=[stream_player])
526
+
527
+ # -----------------------------
528
+ # FASTAPI + WEBSOCKET
529
+ # -----------------------------
530
api = FastAPI()

# Keep a strong reference to the engine task: the event loop holds only a
# weak reference to tasks, so discarding the create_task() result means the
# background pipeline can be garbage-collected and silently disappear.
_engine_task = None

@api.on_event("startup")
async def startup():
    """Launch the global audio engine consumer once the loop is running."""
    global _engine_task
    _engine_task = asyncio.create_task(audio_engine_loop())
535
 
 
 
 
536
@api.websocket("/ws/audio")
async def websocket_endpoint(ws: WebSocket):
    """Streaming TTS API over a WebSocket.

    Client -> server JSON messages:
      {"config": ..., "voice": <name or code>, "speed": <float>} — update
          per-connection voice/speed (voice pack is pre-warmed).
      {"text": <str>} — synthesize; audio comes back as binary frames of
          raw mono int16 PCM at AUDIO_SR.
      {"flush": ...} — no-op; buffering is client controlled.
    """
    await ws.accept()

    # Per-connection defaults.
    voice_code = "af_bella"
    speed = 1.0
    loop = asyncio.get_running_loop()

    print(f"✅ Client connected: {ws.client}")

    try:
        while True:
            try:
                data = await ws.receive_json()
            except WebSocketDisconnect:
                print("❌ Client disconnected cleanly")
                break
            except Exception as e:
                print(f"⚠️ Connection lost: {e}")
                break

            if "config" in data:
                voice_name = data.get("voice", "🇺🇸 🚺 Bella")
                voice_code = VOICE_CHOICES.get(voice_name, voice_name)
                speed = float(data.get("speed", speed))
                get_voice_pack(voice_code)  # pre-warm so the first chunk is fast

            if "text" in data:
                text = data["text"]

                # Build the whole utterance up front so chunks are never
                # interleaved across users. Positional executor args avoid
                # the late-binding lambda pitfall.
                phoneme_chunks = await loop.run_in_executor(
                    G2P_EXECUTOR, text_to_phoneme_chunks, text, voice_code
                )

                if not phoneme_chunks:
                    continue

                # BUG FIX: `await queue.put()` never raises QueueFull — it
                # blocks until space frees up, so the backpressure branch
                # below was dead code. put_nowait() raises immediately when
                # the queue is at MAX_QUEUE.
                try:
                    INFERENCE_QUEUE.put_nowait({
                        "ws": ws,
                        "voice": voice_code,
                        "speed": speed,
                        "chunks": phoneme_chunks,
                    })
                except asyncio.QueueFull:
                    # Hard backpressure on HF free tier
                    try:
                        await ws.send_json({"type": "error", "message": "Server busy. Try again."})
                    except Exception:
                        pass

            if "flush" in data:
                # Client controlled. No server side buffering needed here.
                pass

    except Exception as e:
        print(f"🔥 Critical WS Error: {e}")
 
 
595
 
596
+ # Mount gradio on FastAPI
597
  final_app = gr.mount_gradio_app(api, app, path="/")
598
 
599
  if __name__ == "__main__":