#!/usr/bin/env python3 """Gradio Space for Supertonic 3 MLX.""" from __future__ import annotations import os import time from functools import lru_cache import gradio as gr import numpy as np from huggingface_hub import snapshot_download from supertonic_mlx import SupertonicMLX os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") MODEL_ID = os.environ.get("SUPERTONIC_MLX_MODEL", "mlx-community/supertonic-3") LANGUAGES = [ "en", "ko", "ja", "ar", "bg", "cs", "da", "de", "el", "es", "et", "fi", "fr", "hi", "hr", "hu", "id", "it", "lt", "lv", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na", ] VOICES = ["M1", "M2", "M3", "M4", "M5", "F1", "F2", "F3", "F4", "F5"] @lru_cache(maxsize=1) def load_model() -> SupertonicMLX: model_dir = snapshot_download( repo_id=MODEL_ID, repo_type="model", allow_patterns=[ "graphs/*.json", "weights/*.npz", "voice_styles/*.json", "tts.json", "unicode_indexer.json", "mlx_manifest.json", ], ) return SupertonicMLX.from_pretrained(model_dir) def synthesize(text: str, lang: str, voice: str, total_step: int, speed: float, seed: int): if not text or not text.strip(): return None, "Enter text to synthesize." try: start = time.perf_counter() was_loaded = load_model.cache_info().currsize > 0 load_start = time.perf_counter() tts = load_model() load_elapsed = time.perf_counter() - load_start style = tts.get_voice_style(voice) infer_start = time.perf_counter() wav, duration = tts.synthesize( text.strip(), lang, style, total_step=int(total_step), speed=float(speed), seed=int(seed), ) infer_elapsed = time.perf_counter() - infer_start samples = int(tts.sample_rate * float(duration[0])) audio = np.clip(wav[0, :samples], -1.0, 1.0) audio_i16 = (audio * 32767).astype(np.int16) audio_duration = samples / tts.sample_rate total_elapsed = time.perf_counter() - start load_text = "cached" if was_loaded else f"{load_elapsed:.2f}s" return ( (tts.sample_rate, audio_i16), ( f"Done. elapsed={total_elapsed:.2f}s, " f"model_load={load_text}, synth={infer_elapsed:.2f}s, " f"audio={audio_duration:.2f}s" ), ) except Exception as exc: return None, f"Error: {type(exc).__name__}: {exc}" with gr.Blocks(title="Supertonic 3 MLX", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # Supertonic 3 MLX Lightning-fast multilingual text-to-speech using the community MLX conversion. """ ) with gr.Row(): with gr.Column(scale=1): text = gr.Textbox( label="Text", lines=5, value="Supertonic 3 is running through an MLX graph runtime.", ) with gr.Row(): lang = gr.Dropdown(LANGUAGES, value="en", label="Language") voice = gr.Dropdown(VOICES, value="M1", label="Voice") with gr.Accordion("Generation Settings", open=False): total_step = gr.Slider(5, 8, value=5, step=1, label="Total Steps") speed = gr.Slider(0.7, 2.0, value=1.05, step=0.05, label="Speed") seed = gr.Number(value=0, precision=0, label="Seed") btn = gr.Button("Generate", variant="primary") with gr.Column(scale=1): audio = gr.Audio(label="Output", type="numpy") status = gr.Textbox(label="Status") btn.click(synthesize, [text, lang, voice, total_step, speed, seed], [audio, status]) gr.Examples( examples=[ ["Supertonic 3 is running on CPU fallback.", "en", "M1", 5, 1.05, 0], ["こんにちは。MLX 変換をテストしています。", "ja", "F1", 5, 1.05, 1], ["오늘은 MLX 변환을 테스트하고 있습니다.", "ko", "F2", 5, 1.05, 2], ], inputs=[text, lang, voice, total_step, speed, seed], ) if __name__ == "__main__": demo.queue(max_size=8).launch()