Spaces: Running on Zero
| """ | |
| ONNX-based TTS Gradio Application for Japanese | |
| PyTorch-free implementation using ONNX Runtime | |
| """ | |
| import glob | |
| import os | |
| import tempfile | |
| from time import perf_counter | |
| from typing import Optional | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as ort | |
| import pyopenjtalk | |
| import soundfile as sf | |
try:
    import spaces
except ImportError:
    # Fallback shim for environments without the Hugging Face `spaces`
    # package: provide a no-op GPU decorator so code referencing
    # `spaces.GPU` keeps working locally. Declared as a staticmethod so
    # `spaces.GPU(func)` is an explicit, idiomatic pass-through rather
    # than relying on plain class-attribute access semantics.
    class spaces:
        @staticmethod
        def GPU(func):
            return func
# ============================================================================
# Configuration
# ============================================================================
# Get script directory
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory holding the bundled .onnx model files (sibling "models" folder).
MODELS_DIR = os.path.join(SCRIPT_DIR, "models")
# Model filename used when nothing else is configured or discovered.
DEFAULT_MODEL = "g003_ep5709.onnx"
# Environment overrides: MODEL_PATH / VOCODER_PATH point at ONNX files.
MODEL_PATH = os.getenv("MODEL_PATH", os.path.join(MODELS_DIR, DEFAULT_MODEL))
# External vocoder model path; None means the model is expected to embed one.
VOCODER_PATH = os.getenv("VOCODER_PATH", None)
# USE_GPU=true selects the CUDA execution provider in ONNX Runtime.
USE_GPU = os.getenv("USE_GPU", "false").lower() == "true"
# WAV output sample rate in Hz — presumably matches the trained model's
# output rate; confirm against the model's training config.
SAMPLE_RATE = 22050
# DEBUG=true enables verbose stdout logging throughout the app.
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
def get_available_models():
    """Return the sorted basenames of every .onnx file under MODELS_DIR.

    Falls back to ``[DEFAULT_MODEL]`` when the directory is missing or
    contains no ONNX files, so callers always get a non-empty list.
    """
    if not os.path.exists(MODELS_DIR):
        return [DEFAULT_MODEL]
    found = [
        os.path.basename(path)
        for path in glob.glob(os.path.join(MODELS_DIR, "*.onnx"))
    ]
    return sorted(found) if found else [DEFAULT_MODEL]
| # ============================================================================ | |
| # Text Processing (PyTorch-free) | |
| # ============================================================================ | |
| # Load symbols from matcha | |
| _pad = "_" | |
| _punctuation = ';:,.!?¡¿—…"«»"" ' | |
| _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" | |
| _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" | |
| symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) | |
| _symbol_to_id = {s: i for i, s in enumerate(symbols)} | |
| def text_to_sequence(text): | |
| """Convert text to sequence of IDs""" | |
| sequence = [] | |
| for symbol in text: | |
| if symbol in _symbol_to_id: | |
| sequence.append(_symbol_to_id[symbol]) | |
| else: | |
| sequence.append(0) # Unknown symbol | |
| return sequence | |
def intersperse(sequence, token):
    """Return a new list with *token* placed before, between, and after
    every element of *sequence* (result length: 2*len(sequence) + 1)."""
    result = [token]
    for item in sequence:
        result.append(item)
        result.append(token)
    return result
def process_japanese_text(text: str):
    """Convert raw Japanese text into the (x, x_lengths) arrays the model expects.

    Args:
        text: Japanese input text; must contain non-whitespace characters.

    Returns:
        Tuple of (x, x_lengths): int64 phoneme-ID array of shape (1, T)
        and an int64 length array of shape (1,).

    Raises:
        ValueError: if *text* is empty or whitespace-only.
    """
    if not text.strip():
        raise ValueError("Text cannot be empty")
    # Phonemize, then normalize: drop phoneme separators, map pauses to spaces.
    phonemes = pyopenjtalk.g2p(text, kana=False).replace(" ", "").replace("pau", " ")
    if DEBUG:
        print(f"Input: {text}")
        print(f"Phonemes: {phonemes}")
    # Symbol IDs interleaved with the pad token (id 0), as the model expects.
    ids = intersperse(text_to_sequence(phonemes), 0)
    x = np.array(ids, dtype=np.int64)[np.newaxis, :]
    x_lengths = np.array([x.shape[-1]], dtype=np.int64)
    return x, x_lengths
| # ============================================================================ | |
| # ONNX Model Manager | |
| # ============================================================================ | |
class ONNXModelManager:
    """Manages ONNX model loading and inference.

    Loads the acoustic model with ONNX Runtime and, when the model emits
    mel spectrograms instead of waveforms, an optional external vocoder.
    Exposes a single :meth:`synthesize` entry point.
    """

    def __init__(
        self,
        model_path: str,
        vocoder_path: Optional[str] = None,
        use_gpu: bool = False,
        hop_length: int = 256,
    ):
        """
        Args:
            model_path: Path to the acoustic ONNX model.
            vocoder_path: Optional path to an external vocoder ONNX model;
                ignored when the main model already embeds a vocoder.
            use_gpu: Prefer the CUDA execution provider (CPU is the fallback).
            hop_length: Audio samples per mel frame, used to convert mel
                lengths to waveform lengths with an external vocoder.
                Defaults to 256, the previously hard-coded value.
        """
        self.model_path = model_path
        self.vocoder_path = vocoder_path
        self.use_gpu = use_gpu
        self.hop_length = hop_length
        # Providers are tried in order; CPU is always available as fallback.
        if use_gpu:
            self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self.providers = ["CPUExecutionProvider"]
        self.model = None
        self.vocoder = None
        # Set by _load_model: a 4-input graph is treated as multi-speaker.
        self.is_multi_speaker = False
        # Set by _load_model: True when the first output is named "wav",
        # i.e. the vocoder is baked into the graph.
        self.has_vocoder_embedded = False
        self._load_model()

    def _load_model(self):
        """Load the acoustic model and, if required, the external vocoder."""
        if DEBUG:
            print(f"Loading model from {self.model_path} with providers {self.providers}")
        self.model = ort.InferenceSession(self.model_path, providers=self.providers)
        model_inputs = self.model.get_inputs()
        model_outputs = list(self.model.get_outputs())
        self.is_multi_speaker = len(model_inputs) == 4
        self.has_vocoder_embedded = model_outputs[0].name == "wav"
        if DEBUG:
            print(f"Model loaded: multi_speaker={self.is_multi_speaker}, "
                  f"vocoder_embedded={self.has_vocoder_embedded}")
        # An external vocoder is only needed when the model outputs mels.
        if not self.has_vocoder_embedded and self.vocoder_path:
            if DEBUG:
                print(f"Loading external vocoder from {self.vocoder_path}")
            self.vocoder = ort.InferenceSession(self.vocoder_path, providers=self.providers)

    def synthesize(
        self,
        x: np.ndarray,
        x_lengths: np.ndarray,
        scales: np.ndarray,
        spks: Optional[np.ndarray] = None
    ):
        """Run ONNX inference.

        Args:
            x: int64 phoneme-ID tensor of shape (1, T).
            x_lengths: int64 array of valid lengths per batch item.
            scales: float32 array [temperature, speaking_rate].
            spks: Optional int64 speaker-ID array (multi-speaker models only).

        Returns:
            (waveforms, waveform_lengths) when a vocoder (embedded or
            external) is available, otherwise (mels, mel_lengths).
        """
        inputs = {
            "x": x,
            "x_lengths": x_lengths,
            "scales": scales,
        }
        if self.is_multi_speaker and spks is not None:
            inputs["spks"] = spks
        # Run Matcha inference
        outputs = self.model.run(None, inputs)
        if self.has_vocoder_embedded:
            # End-to-end: model outputs waveform directly
            return outputs[0], outputs[1]  # wav, wav_lengths
        # Model outputs mel spectrogram
        mels, mel_lengths = outputs[0], outputs[1]
        if self.vocoder is not None:
            # Run external vocoder; drop the channel axis it emits.
            vocoder_inputs = {self.vocoder.get_inputs()[0].name: mels}
            wavs = self.vocoder.run(None, vocoder_inputs)[0]
            wavs = wavs.squeeze(1)
            # One mel frame corresponds to hop_length audio samples.
            wav_lengths = mel_lengths * self.hop_length
            return wavs, wav_lengths
        # No vocoder available, return mel
        return mels, mel_lengths
# Registry of loaded managers, keyed by model filename. Models are loaded
# on first request and cached for the lifetime of the process.
model_managers = {}
current_model = None


def get_model_manager(model_name: str) -> ONNXModelManager:
    """Return the cached ONNXModelManager for *model_name*, loading it on demand."""
    global current_model
    try:
        manager = model_managers[model_name]
    except KeyError:
        if DEBUG:
            print(f"Loading new model: {model_name}")
        manager = model_managers[model_name] = ONNXModelManager(
            model_path=os.path.join(MODELS_DIR, model_name),
            vocoder_path=VOCODER_PATH,
            use_gpu=USE_GPU,
        )
    current_model = model_name
    return manager


# Pre-load every discovered model at import time (keeps ZeroGPU warm).
if DEBUG:
    print("Pre-loading all models for ZeroGPU...")
for _model_file in get_available_models():
    get_model_manager(_model_file)
if DEBUG:
    print("All models loaded.")
| # ============================================================================ | |
| # Gradio Interface Functions | |
| # ============================================================================ | |
def synthesise(
    text: str,
    model_name: str,
    speaker_id: int,
    temperature: float,
    speaking_rate: float,
):
    """
    Synthesize speech from Japanese text.

    Args:
        text: Japanese text input.
        model_name: Model filename (resolved under MODELS_DIR).
        speaker_id: Speaker ID (used only by multi-speaker models).
        temperature: Sampling temperature.
        speaking_rate: Speaking rate multiplier (1.0 = normal speed).

    Returns:
        Tuple of (audio_path, info_text) for the Gradio outputs.

    Raises:
        ValueError: if the input text is empty.
    """
    t0 = perf_counter()
    try:
        # Get model manager
        manager = get_model_manager(model_name)
        # Process text
        x, x_lengths = process_japanese_text(text)
        # scales layout expected by the model: [temperature, speaking_rate]
        scales = np.array([temperature, speaking_rate], dtype=np.float32)
        # Prepare speaker ID. Gradio's Number component may deliver a float
        # (or None); cast explicitly and guard before comparing.
        spks = None
        if manager.is_multi_speaker and speaker_id is not None and int(speaker_id) >= 0:
            spks = np.array([int(speaker_id)], dtype=np.int64)
        # Run inference
        outputs, output_lengths = manager.synthesize(x, x_lengths, scales, spks)
        # Trim padding: keep only the valid samples of the first batch item.
        audio = outputs[0][:output_lengths[0]]
        inference_time = perf_counter() - t0
        # Real-time factor; guard against zero-length audio so a degenerate
        # output cannot raise ZeroDivisionError.
        audio_duration_sec = len(audio) / SAMPLE_RATE
        rtf = inference_time / audio_duration_sec if audio_duration_sec > 0 else float("inf")
        if DEBUG:
            print(f"Inference time: {inference_time:.3f}s, "
                  f"Audio duration: {audio_duration_sec:.3f}s, "
                  f"RTF: {rtf:.3f}")
        # Save to a temporary file Gradio can serve (delete=False so the
        # file outlives this function for the Audio component).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            sf.write(fp.name, audio, SAMPLE_RATE, "PCM_24")
            audio_path = fp.name
        # Re-derive the phoneme string for display (same normalization as
        # process_japanese_text).
        phonemes = pyopenjtalk.g2p(text, kana=False)
        phonemes = phonemes.replace(" ", "")
        phonemes = phonemes.replace("pau", " ")
        info = f"Model: {model_name}\n"
        info += f"Speaker ID: {speaker_id if manager.is_multi_speaker else 'N/A (Single speaker)'}\n"
        info += f"Phonemes: {phonemes}\n"
        info += f"RTF: {rtf:.3f}"
        return audio_path, info
    except Exception as e:
        print(f"Error: {e}")
        raise
| # ============================================================================ | |
| # Gradio Application | |
| # ============================================================================ | |
def create_gradio_interface():
    """Create the Gradio Blocks interface.

    Returns:
        gr.Blocks: the assembled demo, ready for .launch().
    """
    # Models discovered on disk populate the dropdown choices.
    available_models = get_available_models()
    with gr.Blocks(
        title="🍵 Matcha-TTS ONNX (Japanese)",
    ) as demo:
        gr.Markdown(
            """
            # 🍵 Matcha-TTS ONNX - Japanese Text-to-Speech
            ### PyTorch-free implementation using ONNX Runtime
            """
        )
        with gr.Row():
            with gr.Column():
                # Model Selection
                model_dropdown = gr.Dropdown(
                    label="モデル / Model",
                    choices=available_models,
                    value=DEFAULT_MODEL if DEFAULT_MODEL in available_models else available_models[0],
                    interactive=True
                )
                text_input = gr.Textbox(
                    label="日本語テキスト / Japanese Text",
                    value="こんにちは、世界!",
                    lines=3,
                    placeholder="日本語のテキストを入力してください..."
                )
                # Speaker ID (ignored by single-speaker models)
                speaker_id = gr.Number(
                    label="Speaker ID (スピーカーID)",
                    value=0,
                    minimum=0,
                    maximum=99,
                    precision=0,
                    info="単一スピーカーモデルでは無視されます"
                )
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature (温度)",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.667,
                        info="サンプリングのランダム性"
                    )
                    speaking_rate = gr.Slider(
                        label="Speaking Rate (話速)",
                        minimum=0.1,
                        maximum=5.0,
                        step=0.1,
                        value=1.0,
                        info="1.0 = 標準速度"
                    )
                with gr.Row():
                    synthesise_btn = gr.Button(
                        "🎵 音声生成 / Synthesize",
                        variant="primary",
                        size="lg"
                    )
                    clear_btn = gr.Button(
                        "クリア / Clear",
                        variant="secondary"
                    )
            with gr.Column():
                audio_output = gr.Audio(
                    label="生成音声 / Generated Audio",
                    type="filepath"
                )
                info_output = gr.Textbox(
                    label="情報 / Information",
                    lines=5,
                    interactive=False
                )
        # Examples
        gr.Examples(
            examples=[
                ["こんにちは、世界!", "g003_ep5709.onnx", 0, 0.667, 1.0],
                ["本日は晴天なり。", "g003_ep5709.onnx", 0, 0.667, 1.0],
                ["日本語の音声合成をテストしています。", "g003_ep5709.onnx", 0, 0.667, 1.0],
                ["人工知能の進化は目覚ましいものがあります。", "g003_ep5709.onnx", 0, 0.667, 1.0],
            ],
            inputs=[text_input, model_dropdown, speaker_id, temperature, speaking_rate],
            label="例文 / Examples"
        )
        # Event handlers
        synthesise_btn.click(
            fn=synthesise,
            inputs=[text_input, model_dropdown, speaker_id, temperature, speaking_rate],
            outputs=[audio_output, info_output]
        )
        # BUG FIX: the clear handler previously returned three values
        # (None, None, "") for only two output components, which Gradio
        # rejects at runtime ("too many output values"). Return exactly
        # one value per output component.
        clear_btn.click(
            fn=lambda: (None, ""),
            outputs=[audio_output, info_output]
        )
        gr.Markdown(
            """
            ---
            ### 情報 / Information
            - **モデル**: ONNX (PyTorch-free)
            - **サンプルレート**: 22050 Hz
            - **音素化**: pyopenjtalk
            - **推論**: ONNX Runtime
            - **モデル自動切り替え**: 選択したモデルを自動的にロード
            ### Speaker ID について
            - **単一スピーカーモデル**: Speaker ID は無視されます
            - **マルチスピーカーモデル**: Speaker ID で話者を切り替え
            """
        )
    return demo
# ============================================================================
# Main
# ============================================================================
if __name__ == "__main__":
    # Build the UI and serve it; 0.0.0.0 makes the app reachable from
    # outside the container on port 7860.
    demo = create_gradio_interface()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
    demo.launch(**launch_options)