Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Gradio Space for Supertonic 3 MLX.""" | |
| from __future__ import annotations | |
| import os | |
| import time | |
| from functools import lru_cache | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import snapshot_download | |
| from supertonic_mlx import SupertonicMLX | |
# Opt in to hf_transfer downloads unless the environment already decides.
# NOTE(review): this runs after `huggingface_hub` has been imported above; if
# the library reads the flag at import time, the default set here may not take
# effect -- confirm, and confirm the `hf_transfer` package is installed.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Hub repository to load; override with the SUPERTONIC_MLX_MODEL env variable.
MODEL_ID = os.getenv("SUPERTONIC_MLX_MODEL", "mlx-community/supertonic-3")

# Language codes offered in the "Language" dropdown.
LANGUAGES = [
    "en", "ko", "ja", "ar", "bg", "cs", "da", "de",
    "el", "es", "et", "fi", "fr", "hi", "hr", "hu",
    "id", "it", "lt", "lv", "nl", "pl", "pt", "ro",
    "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na",
]

# Voice names offered in the "Voice" dropdown: M1..M5 followed by F1..F5.
VOICES = [f"{prefix}{index}" for prefix in "MF" for index in range(1, 6)]
@lru_cache(maxsize=1)
def load_model() -> SupertonicMLX:
    """Download the model snapshot and build the runtime, once per process.

    The function is memoized with ``lru_cache`` so the snapshot download and
    model construction happen only on the first successful call; later calls
    return the same instance. The cache wrapper also provides the
    ``cache_info()`` method that ``synthesize`` uses to report whether the
    model was already loaded. ``lru_cache`` does not store exceptions, so a
    failed load is retried on the next call.

    Returns:
        The ``SupertonicMLX`` instance created by ``from_pretrained`` from the
        local snapshot directory of ``MODEL_ID``.
    """
    # Fetch only the files matching these patterns rather than the whole repo.
    wanted_files = [
        "graphs/*.json",
        "weights/*.npz",
        "voice_styles/*.json",
        "tts.json",
        "unicode_indexer.json",
        "mlx_manifest.json",
    ]
    model_dir = snapshot_download(
        repo_id=MODEL_ID,
        repo_type="model",
        allow_patterns=wanted_files,
    )
    return SupertonicMLX.from_pretrained(model_dir)
def synthesize(text: str, lang: str, voice: str, total_step: int, speed: float, seed: int):
    """Generate speech for *text* and describe how long each stage took.

    Args:
        text: Text to speak; surrounding whitespace is stripped first.
        lang: Language code forwarded to the model's ``synthesize``.
        voice: Voice name resolved through the model's ``get_voice_style``.
        total_step: Forwarded as ``total_step`` after conversion to ``int``.
        speed: Forwarded as ``speed`` after conversion to ``float``.
        seed: Forwarded as ``seed`` after conversion to ``int``.

    Returns:
        An ``(audio, status)`` pair. On success ``audio`` is
        ``(sample_rate, int16_samples)`` and ``status`` lists the timings.
        For blank input, or when any step raises, ``audio`` is ``None`` and
        ``status`` carries the message to display.
    """
    prompt = text.strip() if text else ""
    if not prompt:
        return None, "Enter text to synthesize."

    try:
        started = time.perf_counter()
        # Sample the cache state before loading so the status line can tell a
        # cold load from a cache hit.
        cache_was_warm = load_model.cache_info().currsize > 0

        load_began = time.perf_counter()
        engine = load_model()
        load_seconds = time.perf_counter() - load_began

        voice_style = engine.get_voice_style(voice)

        synth_began = time.perf_counter()
        waveform, durations = engine.synthesize(
            prompt,
            lang,
            voice_style,
            total_step=int(total_step),
            speed=float(speed),
            seed=int(seed),
        )
        synth_seconds = time.perf_counter() - synth_began

        # Only index 0 of both results is used -- presumably a batch axis;
        # confirm against supertonic_mlx.
        rate = engine.sample_rate
        sample_count = int(rate * float(durations[0]))
        # Trim to the reported duration, clamp to [-1, 1], then scale to int16.
        clipped = np.clip(waveform[0, :sample_count], -1.0, 1.0)
        pcm = (clipped * 32767).astype(np.int16)
        audio_seconds = sample_count / rate
        total_seconds = time.perf_counter() - started

        if cache_was_warm:
            load_note = "cached"
        else:
            load_note = f"{load_seconds:.2f}s"

        status = (
            f"Done. elapsed={total_seconds:.2f}s, "
            f"model_load={load_note}, synth={synth_seconds:.2f}s, "
            f"audio={audio_seconds:.2f}s"
        )
        return (rate, pcm), status
    except Exception as err:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {type(err).__name__}: {err}"
# Build the Gradio UI: inputs on the left, generated audio and status on the right.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# this file -- confirm the layout against the deployed Space.
with gr.Blocks(title="Supertonic 3 MLX", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Supertonic 3 MLX
        Lightning-fast multilingual text-to-speech using the community MLX conversion.
        """
    )

    with gr.Row():
        # Input column.
        with gr.Column(scale=1):
            text = gr.Textbox(
                value="Supertonic 3 is running through an MLX graph runtime.",
                label="Text",
                lines=5,
            )
            with gr.Row():
                lang = gr.Dropdown(choices=LANGUAGES, value="en", label="Language")
                voice = gr.Dropdown(choices=VOICES, value="M1", label="Voice")
            with gr.Accordion("Generation Settings", open=False):
                total_step = gr.Slider(
                    minimum=5, maximum=8, value=5, step=1, label="Total Steps"
                )
                speed = gr.Slider(
                    minimum=0.7, maximum=2.0, value=1.05, step=0.05, label="Speed"
                )
                seed = gr.Number(value=0, precision=0, label="Seed")
            btn = gr.Button("Generate", variant="primary")

        # Output column.
        with gr.Column(scale=1):
            audio = gr.Audio(label="Output", type="numpy")
            status = gr.Textbox(label="Status")

    # The same components, in the same order, feed both the button handler
    # and the example rows; the order matches the parameters of `synthesize`.
    controls = [text, lang, voice, total_step, speed, seed]

    btn.click(fn=synthesize, inputs=controls, outputs=[audio, status])

    gr.Examples(
        examples=[
            ["Supertonic 3 is running on CPU fallback.", "en", "M1", 5, 1.05, 0],
            ["こんにちは。MLX 変換をテストしています。", "ja", "F1", 5, 1.05, 1],
            ["오늘은 MLX 변환을 테스트하고 있습니다.", "ko", "F2", 5, 1.05, 2],
        ],
        inputs=controls,
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=8, then start the server.
    queued_app = demo.queue(max_size=8)
    queued_app.launch()