auralodyssey committed on
Commit
4d29cb7
·
verified ·
1 Parent(s): 9986dc0

main logic for kokoro

Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import sys
4
+ import time
5
+ import json
6
+ import numpy as np
7
+ import tempfile
8
+ from huggingface_hub import snapshot_download
9
+ from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel
10
+ import scipy.io.wavfile as wavfile
11
+ import gradio as gr
12
+
13
+ # Misaki G2P
14
+ try:
15
+ from misaki import en as misaki_en
16
+ except Exception as e:
17
+ print("Misaki import failed", e)
18
+ raise
19
+
20
+ # Config
21
+ HF_REPO = "onnx-community/Kokoro-82M-v1.0-ONNX"
22
+ LOCAL_DIR = "/tmp/kokoro_model"
23
+ ONNX_SUBPATH = "onnx/model_q8f16.onnx" # best CPU quantized file
24
+ VOICES_DIRNAME = "voices"
25
+ SAMPLE_RATE = 24000 # Kokoro uses 24k in README
26
+
27
+ # Ensure local dir
28
+ os.makedirs(LOCAL_DIR, exist_ok=True)
29
+
30
+ def download_repo():
31
+ """Download model files to LOCAL_DIR (cached by HF hub)."""
32
+ # This will download the repo into hf cache and give us a path
33
+ print("Downloading model repo snapshot from HF. This may take several minutes on first run.")
34
+ repo_dir = snapshot_download(repo_id=HF_REPO, cache_dir=LOCAL_DIR, local_dir=LOCAL_DIR, repo_type="model")
35
+ print("Snapshot downloaded to", repo_dir)
36
+ return repo_dir
37
+
38
+ def load_tokenizer_map(repo_dir):
39
+ # tokenizer.json contains mapping from phoneme token text -> id
40
+ tok_path = os.path.join(repo_dir, "tokenizer.json")
41
+ if not os.path.exists(tok_path):
42
+ raise FileNotFoundError(f"tokenizer.json not found at {tok_path}")
43
+ with open(tok_path, "r", encoding="utf-8") as f:
44
+ tok = json.load(f)
45
+ # tokenizer.json may follow HF tokenizers format; we need map: piece -> id
46
+ if "model" in tok and "vocab" in tok["model"]:
47
+ vocab = tok["model"]["vocab"]
48
+ elif "vocab" in tok:
49
+ vocab = tok["vocab"]
50
+ else:
51
+ # attempt fallback
52
+ vocab = tok.get("vocab", {})
53
+ piece_to_id = {}
54
+ if isinstance(vocab, dict):
55
+ # typical mapping piece -> id
56
+ piece_to_id = vocab
57
+ else:
58
+ # try tokens list (rare)
59
+ for i, p in enumerate(vocab):
60
+ piece_to_id[p] = i
61
+ return piece_to_id
62
+
63
+ def make_session(onnx_path):
64
+ sess_opts = SessionOptions()
65
+ sess_opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
66
+ # CPU provider explicit
67
+ sess = InferenceSession(onnx_path, sess_options=sess_opts, providers=["CPUExecutionProvider"])
68
+ return sess
69
+
70
+ # Lazy global
71
+ _repo_dir = None
72
+ _sess = None
73
+ _piece_to_id = None
74
+ _voices_arr = None
75
+
76
+ def ensure_loaded():
77
+ global _repo_dir, _sess, _piece_to_id, _voices_arr
78
+ if _repo_dir is None:
79
+ _repo_dir = download_repo()
80
+ if _piece_to_id is None:
81
+ _piece_to_id = load_tokenizer_map(_repo_dir)
82
+ if _sess is None:
83
+ onnx_path = os.path.join(_repo_dir, ONNX_SUBPATH)
84
+ if not os.path.exists(onnx_path):
85
+ # try alternative names
86
+ candidates = [p for p in os.listdir(os.path.join(_repo_dir, "onnx")) if p.endswith(".onnx")]
87
+ if not candidates:
88
+ raise FileNotFoundError("No ONNX model file found in repo/onnx")
89
+ onnx_path = os.path.join(_repo_dir, "onnx", candidates[0])
90
+ print("Loading onnx model:", onnx_path)
91
+ _sess = make_session(onnx_path)
92
+ if _voices_arr is None:
93
+ # read voices list from voices folder; we'll lazily load per voice later as needed
94
+ voices_path = os.path.join(_repo_dir, VOICES_DIRNAME)
95
+ if not os.path.exists(voices_path):
96
+ raise FileNotFoundError("voices folder not found in repo")
97
+ _voices_arr = {} # dict voice_name -> numpy array
98
+ return
99
+
100
+ def tokens_from_misaki(text):
101
+ # Use misaki to produce phonemes and tokens. misaki returns phonemes, tokens
102
+ # tokens can be a list of ints or token objects. We try to extract ints.
103
+ g2p = misaki_en.G2P(trf=False, british=False, fallback=None)
104
+ phonemes, tokens = g2p(text)
105
+ # tokens may be nested lists, token objects etc.
106
+ flat_ids = []
107
+ for entry in tokens:
108
+ if isinstance(entry, list):
109
+ # nested list of token objects
110
+ for tk in entry:
111
+ if hasattr(tk, "id"):
112
+ flat_ids.append(int(tk.id))
113
+ elif isinstance(tk, int):
114
+ flat_ids.append(int(tk))
115
+ else:
116
+ # fallback: try string repr and map using tokenizer mapping
117
+ token_str = str(tk)
118
+ if token_str in _piece_to_id:
119
+ flat_ids.append(int(_piece_to_id[token_str]))
120
+ else:
121
+ raise ValueError("Unknown token object and not in tokenizer map: " + token_str)
122
+ else:
123
+ if isinstance(entry, int):
124
+ flat_ids.append(int(entry))
125
+ elif hasattr(entry, "id"):
126
+ flat_ids.append(int(entry.id))
127
+ else:
128
+ token_str = str(entry)
129
+ if token_str in _piece_to_id:
130
+ flat_ids.append(int(_piece_to_id[token_str]))
131
+ else:
132
+ raise ValueError("Unknown token and not in tokenizer map: " + token_str)
133
+ # sanity
134
+ if len(flat_ids) > 510:
135
+ raise ValueError("Tokenized length exceeds model context length (<=510).")
136
+ return flat_ids, phonemes
137
+
138
+ def load_voice_vector(repo_dir, voice):
139
+ voices_folder = os.path.join(repo_dir, VOICES_DIRNAME)
140
+ if not os.path.exists(voices_folder):
141
+ raise FileNotFoundError("voices folder missing")
142
+ file_path = os.path.join(voices_folder, f"{voice}.bin")
143
+ if not os.path.exists(file_path):
144
+ raise FileNotFoundError(f"voice file {voice}.bin not found in voices folder")
145
+ arr = np.fromfile(file_path, dtype=np.float32).reshape(-1, 1, 256) # shape checks per README
146
+ return arr
147
+
148
+ def infer_kokoro(text, voice="af_bella", speed=1.0):
149
+ ensure_loaded()
150
+ # get token ids
151
+ token_ids, phonemes = tokens_from_misaki(text)
152
+ repo_dir = _repo_dir
153
+ # load voice vector
154
+ style_arr = load_voice_vector(repo_dir, voice)
155
+ # pick style vector by length tokens; README uses voices[len(tokens)]
156
+ idx = min(len(token_ids), style_arr.shape[0] - 1)
157
+ ref_s = style_arr[idx] # shape (1, 256) expected
158
+ # build input tokens with pad 0 at start and end
159
+ input_ids = np.array([[0] + token_ids + [0]], dtype=np.int64)
160
+ speed_arr = np.ones((1,), dtype=np.float32) * float(speed)
161
+ # ONNX session run
162
+ ort_inputs = {
163
+ "input_ids": input_ids,
164
+ "style": ref_s.astype(np.float32),
165
+ "speed": speed_arr.astype(np.float32),
166
+ }
167
+ out = _sess.run(None, ort_inputs)[0] # expected shape: (1, T)
168
+ # convert to int16 PCM for wav
169
+ audio = np.clip(out[0], -1.0, 1.0)
170
+ # map float32 [-1,1] to int16
171
+ pcm16 = (audio * 32767.0).astype(np.int16)
172
+ # write to temp wav and return path
173
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
174
+ wavfile.write(tmp.name, SAMPLE_RATE, pcm16)
175
+ tmp.close()
176
+ return tmp.name
177
+
178
+ # Gradio UI and API
179
+ with gr.Blocks() as demo:
180
+ gr.Markdown("### Kokoro ONNX TTS CPU Space")
181
+ with gr.Row():
182
+ txt = gr.Textbox(label="Text", value="Hello world", lines=3)
183
+ voice = gr.Dropdown(choices=[], label="Voice (loaded after model)", value="af_bella")
184
+ speed = gr.Slider(0.5, 2.0, value=1.0, step=0.01, label="Speed")
185
+ btn = gr.Button("Synthesize")
186
+ audio_out = gr.Audio(label="Audio", type="file")
187
+
188
+ def on_load():
189
+ ensure_loaded()
190
+ # read voices folder names
191
+ repo_dir = _repo_dir
192
+ voices_list = []
193
+ vf = os.path.join(repo_dir, VOICES_DIRNAME)
194
+ for f in os.listdir(vf):
195
+ if f.endswith(".bin"):
196
+ voices_list.append(f[:-4])
197
+ return gr.Dropdown.update(choices=sorted(voices_list), value=voices_list[0] if voices_list else None)
198
+
199
+ def synth(text_in, voice_in, speed_in):
200
+ if not text_in or not text_in.strip():
201
+ return None
202
+ t0 = time.time()
203
+ wav_path = infer_kokoro(text_in, voice_in, speed_in)
204
+ elapsed = time.time() - t0
205
+ print(f"Generated audio at {wav_path} in {elapsed:.2f}s")
206
+ return wav_path
207
+
208
+ demo.load(on_load)
209
+ btn.click(synth, inputs=[txt, voice, speed], outputs=[audio_out], api_name="/tts")
210
+
211
+ if __name__ == "__main__":
212
+ demo.queue(concurrency_count=1) # keep low concurrency on free CPU space
213
+ demo.launch()