Akjava committed on
Commit
00a6409
·
1 Parent(s): 9e49fc7
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv/
app.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ONNX-based TTS Gradio Application for Japanese
3
+ PyTorch-free implementation using ONNX Runtime
4
+ """
5
+
6
+ import glob
7
+ import os
8
+ import tempfile
9
+ from time import perf_counter
10
+ from typing import Optional
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import onnxruntime as ort
15
+ import pyopenjtalk
16
+ import soundfile as sf
17
+
18
+ # ============================================================================
19
+ # Configuration
20
+ # ============================================================================
21
+
22
+ # Get script directory
23
# Directory containing this script; the bundled models live in a "models"
# folder next to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(SCRIPT_DIR, "models")
DEFAULT_MODEL = "g003_ep5709.onnx"
# Overridable through the MODEL_PATH environment variable.
# NOTE(review): nothing in the visible code reads MODEL_PATH; models are
# resolved from MODELS_DIR by file name. Confirm whether this is intended.
MODEL_PATH = os.getenv("MODEL_PATH", os.path.join(MODELS_DIR, DEFAULT_MODEL))
# Optional path to an external vocoder ONNX file (None when unset).
VOCODER_PATH = os.getenv("VOCODER_PATH", None)
# GPU execution is enabled only when USE_GPU is "true" (case-insensitive).
USE_GPU = os.getenv("USE_GPU", "false").lower() == "true"
# Sample rate (Hz) used when writing audio and computing its duration.
SAMPLE_RATE = 22050
30
+
31
+
32
def get_available_models():
    """Return the sorted file names of the ``*.onnx`` models in MODELS_DIR.

    Falls back to ``[DEFAULT_MODEL]`` when the directory does not exist or
    contains no ``.onnx`` files, so the result is never empty.
    """
    if not os.path.exists(MODELS_DIR):
        return [DEFAULT_MODEL]

    model_names = sorted(
        os.path.basename(path)
        for path in glob.glob(os.path.join(MODELS_DIR, "*.onnx"))
    )
    return model_names or [DEFAULT_MODEL]
44
+
45
+ # ============================================================================
46
+ # Text Processing (PyTorch-free)
47
+ # ============================================================================
48
+
49
+ # Load symbols from matcha
50
_pad = "_"
_punctuation = ';:,.!?¡¿—…"«»"" '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# The pad symbol comes first, so it has ID 0.
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Some characters occur more than once in the strings above; for those the
# mapping keeps the index of the last occurrence.
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
57
+
58
+
59
def text_to_sequence(text):
    """Convert a string to a list of symbol IDs, one per character.

    Characters that are not in the symbol table are mapped to 0, which is
    also the ID of the pad symbol.
    """
    return [_symbol_to_id.get(symbol, 0) for symbol in text]
68
+
69
+
70
def intersperse(sequence, token):
    """Return ``sequence`` with ``token`` placed before, between and after its items.

    The result has ``2 * len(sequence) + 1`` elements: the original items
    sit at the odd indices and ``token`` fills every even index.
    """
    result = [token] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result
75
+
76
+
77
def process_japanese_text(text: str):
    """Turn Japanese text into the integer arrays fed to the ONNX model.

    Args:
        text: Japanese input text.

    Returns:
        Tuple ``(x, x_lengths)``: ``x`` is an int64 array of shape
        ``(1, length)`` holding the pad-interspersed symbol IDs, and
        ``x_lengths`` is a one-element int64 array containing ``length``.

    Raises:
        ValueError: If ``text`` is empty or only whitespace.
    """
    if not text.strip():
        raise ValueError("Text cannot be empty")

    # Phonemize using pyopenjtalk, then drop the spaces between phonemes and
    # turn each "pau" marker into a single space.
    phonemes = pyopenjtalk.g2p(text, kana=False)
    phonemes = phonemes.replace(" ", "")
    phonemes = phonemes.replace("pau", " ")

    print(f"Input: {text}")
    print(f"Phonemes: {phonemes}")

    # Map phoneme characters to IDs and put ID 0 around every one of them.
    sequence = intersperse(text_to_sequence(phonemes), 0)

    # Add a leading axis of size 1 in front of the sequence.
    x = np.array(sequence, dtype=np.int64)[np.newaxis, :]
    x_lengths = np.array([x.shape[-1]], dtype=np.int64)

    return x, x_lengths
101
+
102
+
103
+ # ============================================================================
104
+ # ONNX Model Manager
105
+ # ============================================================================
106
+
107
class ONNXModelManager:
    """Load an ONNX TTS model (plus an optional external vocoder) and run it.

    After construction:
      * ``is_multi_speaker`` is True when the model declares exactly four inputs.
      * ``has_vocoder_embedded`` is True when the model's first output is
        named ``"wav"``.
      * ``vocoder`` is an inference session only when the model has no
        embedded vocoder and ``vocoder_path`` was given; otherwise None.
    """

    # Multiplier that converts mel lengths into waveform lengths after the
    # external vocoder has run.
    # NOTE(review): presumably the vocoder's hop size; confirm against the
    # vocoder actually deployed.
    HOP_LENGTH = 256

    def __init__(self, model_path: str, vocoder_path: Optional[str] = None, use_gpu: bool = False):
        """Store the configuration and load the model immediately.

        Args:
            model_path: Path to the ONNX model file.
            vocoder_path: Optional path to an external vocoder ONNX file.
            use_gpu: When True, the CUDA provider is requested ahead of CPU.
        """
        self.model_path = model_path
        self.vocoder_path = vocoder_path
        self.use_gpu = use_gpu

        # Select execution providers
        if use_gpu:
            self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self.providers = ["CPUExecutionProvider"]

        self.model = None
        self.vocoder = None
        self.is_multi_speaker = False
        self.has_vocoder_embedded = False

        self._load_model()

    def _load_model(self):
        """Create the inference session(s) and inspect the model signature."""
        print(f"Loading model from {self.model_path} with providers {self.providers}")
        self.model = ort.InferenceSession(self.model_path, providers=self.providers)

        model_inputs = self.model.get_inputs()
        model_outputs = list(self.model.get_outputs())

        # The model's capabilities are inferred from its signature.
        self.is_multi_speaker = len(model_inputs) == 4
        self.has_vocoder_embedded = model_outputs[0].name == "wav"

        print(f"Model loaded: multi_speaker={self.is_multi_speaker}, "
              f"vocoder_embedded={self.has_vocoder_embedded}")

        # Load external vocoder if needed
        if not self.has_vocoder_embedded and self.vocoder_path:
            print(f"Loading external vocoder from {self.vocoder_path}")
            self.vocoder = ort.InferenceSession(self.vocoder_path, providers=self.providers)

    def synthesize(
        self,
        x: np.ndarray,
        x_lengths: np.ndarray,
        scales: np.ndarray,
        spks: Optional[np.ndarray] = None
    ):
        """Run ONNX inference.

        Args:
            x: Symbol ID array passed as the model input ``"x"``.
            x_lengths: Passed as the model input ``"x_lengths"``.
            scales: Passed as the model input ``"scales"``.
            spks: Passed as the model input ``"spks"``. Required for
                multi-speaker models, ignored otherwise.

        Returns:
            Tuple ``(wavs, wav_lengths)`` when the model embeds a vocoder or
            an external vocoder is loaded. Otherwise ``(mels, mel_lengths)``,
            the first two model outputs unchanged.

        Raises:
            ValueError: If the model is multi-speaker and ``spks`` is None.
        """
        inputs = {
            "x": x,
            "x_lengths": x_lengths,
            "scales": scales,
        }

        if self.is_multi_speaker:
            # Fail early with a clear message instead of letting the runtime
            # reject the incomplete input feed.
            if spks is None:
                raise ValueError("spks is required for a multi-speaker model")
            inputs["spks"] = spks

        outputs = self.model.run(None, inputs)

        if self.has_vocoder_embedded:
            # End-to-end: model outputs waveform directly
            return outputs[0], outputs[1]  # wav, wav_lengths

        # Model outputs mel spectrogram
        mels, mel_lengths = outputs[0], outputs[1]

        if self.vocoder is None:
            # No vocoder available, return mel
            return mels, mel_lengths

        # Run external vocoder on the mel output, feeding its first input.
        vocoder_inputs = {self.vocoder.get_inputs()[0].name: mels}
        wavs = self.vocoder.run(None, vocoder_inputs)[0]
        wavs = wavs.squeeze(1)  # drop the size-1 axis at position 1
        wav_lengths = mel_lengths * self.HOP_LENGTH
        return wavs, wav_lengths
184
+
185
+
186
+ # Initialize model managers (one per model)
187
# Cache of loaded managers, keyed by model file name.
model_managers = {}
# File name of the model most recently requested; None until the first request.
current_model = None
189
+
190
+
191
def get_model_manager(model_name: str) -> ONNXModelManager:
    """Return the manager for ``model_name``, loading it on first use.

    Managers are cached in ``model_managers`` by file name, and
    ``current_model`` is updated to the requested name on every call.

    Args:
        model_name: File name of a model inside MODELS_DIR.

    Raises:
        ValueError: If ``model_name`` is empty, ``"."``/``".."``, or contains
            a directory component.
    """
    global current_model

    if model_name not in model_managers:
        # The name is supplied through the web UI and is joined into a file
        # path, so accept only a bare file name that stays inside MODELS_DIR.
        if (
            not model_name
            or model_name in (".", "..")
            or os.path.basename(model_name) != model_name
        ):
            raise ValueError(f"Invalid model name: {model_name!r}")

        print(f"Loading new model: {model_name}")
        model_managers[model_name] = ONNXModelManager(
            model_path=os.path.join(MODELS_DIR, model_name),
            vocoder_path=VOCODER_PATH,
            use_gpu=USE_GPU,
        )

    current_model = model_name
    return model_managers[model_name]
207
+
208
+
209
+ # Initialize default model
210
# Load the default model at import time so it is already in the cache.
get_model_manager(DEFAULT_MODEL)
211
+
212
+ # ============================================================================
213
+ # Gradio Interface Functions
214
+ # ============================================================================
215
+
216
+
217
def synthesise(
    text: str,
    model_name: str,
    speaker_id: int,
    temperature: float,
    speaking_rate: float,
):
    """
    Synthesize speech from Japanese text

    Args:
        text: Japanese text input
        model_name: Model filename
        speaker_id: Speaker ID (used only by multi-speaker models)
        temperature: Sampling temperature
        speaking_rate: Speaking rate multiplier

    Returns:
        Tuple of (audio_path, info): the path of the generated WAV file and a
        multi-line summary (model, speaker, phonemes, RTF).

    Raises:
        RuntimeError: If the model outputs mel spectrograms and no vocoder
            is loaded, so no waveform can be produced.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    t0 = perf_counter()

    try:
        manager = get_model_manager(model_name)

        # Without a vocoder the manager returns mel spectrograms, which must
        # not be written out as if they were audio samples.
        if not manager.has_vocoder_embedded and manager.vocoder is None:
            raise RuntimeError(
                "Model outputs mel spectrograms but no vocoder is available; "
                "set VOCODER_PATH or use a model with an embedded vocoder"
            )

        x, x_lengths = process_japanese_text(text)

        scales = np.array([temperature, speaking_rate], dtype=np.float32)

        # The speaker ID may be None when the UI field is cleared.
        spks = None
        if manager.is_multi_speaker and speaker_id is not None and speaker_id >= 0:
            spks = np.array([int(speaker_id)], dtype=np.int64)

        outputs, output_lengths = manager.synthesize(x, x_lengths, scales, spks)

        # Keep the first item and trim it to its reported length.
        audio = outputs[0][:output_lengths[0]]
        inference_time = perf_counter() - t0

        # Real-time factor = inference time / audio duration; guard against
        # empty audio so the division cannot fail.
        audio_duration_sec = len(audio) / SAMPLE_RATE
        if audio_duration_sec > 0:
            rtf = inference_time / audio_duration_sec
        else:
            rtf = float("inf")

        print(f"Inference time: {inference_time:.3f}s, "
              f"Audio duration: {audio_duration_sec:.3f}s, "
              f"RTF: {rtf:.3f}")

        # Reserve a temporary file name, then write after the handle is
        # closed: reopening a still-open NamedTemporaryFile by name fails on
        # Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            audio_path = fp.name
        sf.write(audio_path, audio, SAMPLE_RATE, "PCM_24")

        # Get phonemes for display
        phonemes = pyopenjtalk.g2p(text, kana=False)
        phonemes = phonemes.replace(" ", "")
        phonemes = phonemes.replace("pau", " ")

        info = f"Model: {model_name}\n"
        info += f"Speaker ID: {speaker_id if manager.is_multi_speaker else 'N/A (Single speaker)'}\n"
        info += f"Phonemes: {phonemes}\n"
        info += f"RTF: {rtf:.3f}"

        return audio_path, info

    except Exception as e:
        print(f"Error: {e}")
        raise
289
+
290
+
291
+ # ============================================================================
292
+ # Gradio Application
293
+ # ============================================================================
294
+
295
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the TTS demo.

    The layout has an input column (model, text, speaker ID, temperature,
    speaking rate, buttons) and an output column (audio player, info box),
    followed by clickable examples and a Markdown footer.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    # Get available models
    available_models = get_available_models()
    # Prefer DEFAULT_MODEL, otherwise the first model found, so that the
    # dropdown and the examples always refer to a listed choice.
    default_model = DEFAULT_MODEL if DEFAULT_MODEL in available_models else available_models[0]

    with gr.Blocks(
        title="🍵 Matcha-TTS ONNX (Japanese)",
    ) as demo:
        gr.Markdown(
            """
            # 🍵 Matcha-TTS ONNX - Japanese Text-to-Speech

            ### PyTorch-free implementation using ONNX Runtime
            """
        )

        with gr.Row():
            with gr.Column():
                # Model Selection
                model_dropdown = gr.Dropdown(
                    label="モデル / Model",
                    choices=available_models,
                    value=default_model,
                    interactive=True
                )

                text_input = gr.Textbox(
                    label="日本語テキスト / Japanese Text",
                    value="こんにちは、世界!",
                    lines=3,
                    placeholder="日本語のテキストを入力してください..."
                )

                # Speaker ID
                speaker_id = gr.Number(
                    label="Speaker ID (スピーカーID)",
                    value=0,
                    minimum=0,
                    maximum=99,
                    precision=0,
                    info="単一スピーカーモデルでは無視されます"
                )

                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature (温度)",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.667,
                        info="サンプリングのランダム性"
                    )

                    speaking_rate = gr.Slider(
                        label="Speaking Rate (話速)",
                        minimum=0.1,
                        maximum=5.0,
                        step=0.1,
                        value=1.0,
                        info="1.0 = 標準速度"
                    )

                with gr.Row():
                    synthesise_btn = gr.Button(
                        "🎵 音声生成 / Synthesize",
                        variant="primary",
                        size="lg"
                    )
                    clear_btn = gr.Button(
                        "クリア / Clear",
                        variant="secondary"
                    )

            with gr.Column():
                audio_output = gr.Audio(
                    label="生成音声 / Generated Audio",
                    type="filepath"
                )

                info_output = gr.Textbox(
                    label="情報 / Information",
                    lines=5,
                    interactive=False
                )

        # Examples
        gr.Examples(
            examples=[
                ["こんにちは、世界!", default_model, 0, 0.667, 1.0],
                ["本日は晴天なり。", default_model, 0, 0.667, 1.0],
                ["日本語の音声合成をテストしています。", default_model, 0, 0.667, 1.0],
                ["人工知能の進化は目覚ましいものがあります。", default_model, 0, 0.667, 1.0],
            ],
            inputs=[text_input, model_dropdown, speaker_id, temperature, speaking_rate],
            label="例文 / Examples"
        )

        # Event handlers
        synthesise_btn.click(
            fn=synthesise,
            inputs=[text_input, model_dropdown, speaker_id, temperature, speaking_rate],
            outputs=[audio_output, info_output]
        )

        # One return value per output component: clear the audio player and
        # empty the info box.
        clear_btn.click(
            fn=lambda: (None, ""),
            outputs=[audio_output, info_output]
        )

        gr.Markdown(
            """
            ---
            ### 情報 / Information

            - **モデル**: ONNX (PyTorch-free)
            - **サンプルレート**: 22050 Hz
            - **音素化**: pyopenjtalk
            - **推論**: ONNX Runtime
            - **モデル自動切り替え**: 選択したモデルを自動的にロード

            ### Speaker ID について
            - **単一スピーカーモデル**: Speaker ID は無視されます
            - **マルチスピーカーモデル**: Speaker ID で話者を切り替え
            """
        )

    return demo
423
+
424
+
425
+ # ============================================================================
426
+ # Main
427
+ # ============================================================================
428
+
429
if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces at port 7860,
    # without a public share link.
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
models/g003_ep5709.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ff5be57a656822250aabd0b32a7b942332de3d1a7fe6dacbe87ac7b4075c9af
3
+ size 140821217
models/g003_ep5709_qint8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1980f50cf9e30b728fc6c10075d698b8aee8d63144e619090502c95185467bf2
3
+ size 43394106
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ onnxruntime-gpu
4
+ pyopenjtalk
5
+ soundfile