# supertonic-3 / app.py
# Author: ailuntz
# Commit message: Use cached snapshot download for model files
# Commit: dfd2f88 (verified)
#!/usr/bin/env python3
"""Gradio Space for Supertonic 3 MLX."""
from __future__ import annotations
import logging
import os
import time
from functools import lru_cache

import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download
from supertonic_mlx import SupertonicMLX
# Ask huggingface_hub to use the hf_transfer download backend unless the
# environment has already chosen a value.
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER into its
# constants when the package is first imported, and that import runs before
# this line, so this default probably has no effect on snapshot_download —
# confirm, and move it above the import if fast transfer is actually wanted
# (hf_transfer must then be installed).
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Hub repository with the MLX conversion; overridable via SUPERTONIC_MLX_MODEL.
MODEL_ID = os.environ.get("SUPERTONIC_MLX_MODEL", "mlx-community/supertonic-3")

# Language codes offered by the UI dropdown, in display order.
LANGUAGES = (
    "en ko ja ar bg cs da de el es et fi "
    "fr hi hr hu id it lt lv nl pl pt ro "
    "ru sk sl sv tr uk vi na"
).split()

# Voice style names offered by the UI dropdown: M1-M5 followed by F1-F5.
VOICES = "M1 M2 M3 M4 M5 F1 F2 F3 F4 F5".split()
@lru_cache(maxsize=1)
def load_model() -> SupertonicMLX:
    """Fetch the model snapshot from the Hub and build the MLX runtime.

    Only the files matched by ``wanted`` are downloaded: graph JSON files,
    ``.npz`` weights, voice-style JSON files and three top-level JSON files.
    ``snapshot_download`` reuses its local cache when the files are already
    present. ``lru_cache(maxsize=1)`` keeps the constructed model, so later
    successful calls return the same instance without rebuilding it.

    Returns:
        A ``SupertonicMLX`` created from the local snapshot directory.
    """
    wanted = [
        "graphs/*.json",
        "weights/*.npz",
        "voice_styles/*.json",
        "tts.json",
        "unicode_indexer.json",
        "mlx_manifest.json",
    ]
    local_snapshot = snapshot_download(
        repo_id=MODEL_ID,
        repo_type="model",
        allow_patterns=wanted,
    )
    return SupertonicMLX.from_pretrained(local_snapshot)
def synthesize(text: str, lang: str, voice: str, total_step: int, speed: float, seed: int):
    """Run one text-to-speech request for the Gradio UI and report timings.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only input is
            rejected; otherwise leading/trailing whitespace is stripped.
        lang: Language code forwarded to ``tts.synthesize``.
        voice: Voice style name resolved via ``tts.get_voice_style``.
        total_step: Generation step count (coerced to ``int``).
        speed: Speed factor (coerced to ``float``).
        seed: Random seed (coerced to ``int``). ``None``, which a cleared
            number field produces, falls back to ``0`` (the UI default).

    Returns:
        ``(audio, status)``. ``audio`` is ``(sample_rate, int16 samples)`` on
        success and ``None`` otherwise; ``status`` is a message for the UI.
        Exceptions are not propagated: they are logged with a traceback and
        summarised in ``status`` so the app stays responsive.
    """
    prompt = text.strip() if text else ""
    if not prompt:
        return None, "Enter text to synthesize."
    try:
        start = time.perf_counter()
        # Probe the cache before calling load_model() so the status line can
        # distinguish a cold load from a cache hit.
        was_loaded = load_model.cache_info().currsize > 0
        load_start = time.perf_counter()
        tts = load_model()
        load_elapsed = time.perf_counter() - load_start
        style = tts.get_voice_style(voice)
        infer_start = time.perf_counter()
        wav, duration = tts.synthesize(
            prompt,
            lang,
            style,
            total_step=int(total_step),
            speed=float(speed),
            seed=0 if seed is None else int(seed),
        )
        infer_elapsed = time.perf_counter() - infer_start
        # duration[0] is the length in seconds of the first item
        # (sample_rate * duration gives a sample count). Clamp at 0 so a
        # negative value can't turn into a tail-dropping slice.
        samples = max(int(tts.sample_rate * float(duration[0])), 0)
        audio = np.clip(wav[0, :samples], -1.0, 1.0)
        audio_i16 = (audio * 32767).astype(np.int16)
        # Measure the trimmed array, not the predicted count, so the reported
        # length never exceeds the samples actually returned.
        audio_duration = len(audio_i16) / tts.sample_rate
        total_elapsed = time.perf_counter() - start
    except Exception as exc:
        # UI boundary: keep serving requests, but don't lose the traceback.
        logging.getLogger(__name__).exception("Speech synthesis failed")
        return None, f"Error: {type(exc).__name__}: {exc}"
    load_text = "cached" if was_loaded else f"{load_elapsed:.2f}s"
    status = (
        f"Done. elapsed={total_elapsed:.2f}s, "
        f"model_load={load_text}, synth={infer_elapsed:.2f}s, "
        f"audio={audio_duration:.2f}s"
    )
    return (tts.sample_rate, audio_i16), status
# Gradio UI. The layout has two equal-width columns: text, language/voice
# pickers, generation settings and the Generate button on the left; output
# audio and a status line on the right. Components render in the order they
# are created inside each `with` context, so statement order defines layout.
with gr.Blocks(title="Supertonic 3 MLX", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Supertonic 3 MLX
Lightning-fast multilingual text-to-speech using the community MLX conversion.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            text = gr.Textbox(
                label="Text",
                lines=5,
                value="Supertonic 3 is running through an MLX graph runtime.",
            )
            with gr.Row():
                lang = gr.Dropdown(LANGUAGES, value="en", label="Language")
                voice = gr.Dropdown(VOICES, value="M1", label="Voice")
            # Collapsed by default; values are passed to synthesize() unchanged.
            with gr.Accordion("Generation Settings", open=False):
                total_step = gr.Slider(5, 8, value=5, step=1, label="Total Steps")
                speed = gr.Slider(0.7, 2.0, value=1.05, step=0.05, label="Speed")
                # precision=0 rounds to an integer. NOTE(review): a cleared
                # field may deliver None to the handler.
                seed = gr.Number(value=0, precision=0, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            # type="numpy" accepts the (sample_rate, int16 array) tuple that
            # synthesize() returns.
            audio = gr.Audio(label="Output", type="numpy")
            status = gr.Textbox(label="Status")
    # Input order must match synthesize(text, lang, voice, total_step, speed, seed);
    # outputs map to its (audio, status) return value.
    btn.click(synthesize, [text, lang, voice, total_step, speed, seed], [audio, status])
    # Each example row lists values in the same order as `inputs`. No fn or
    # outputs are given, so selecting an example presumably only fills the
    # inputs. NOTE(review): the first example mentions a "CPU fallback";
    # confirm this matches how the Space is deployed.
    gr.Examples(
        examples=[
            ["Supertonic 3 is running on CPU fallback.", "en", "M1", 5, 1.05, 0],
            ["こんにちは。MLX 変換をテストしています。", "ja", "F1", 5, 1.05, 1],
            ["오늘은 MLX 변환을 테스트하고 있습니다.", "ko", "F2", 5, 1.05, 2],
        ],
        inputs=[text, lang, voice, total_step, speed, seed],
    )
if __name__ == "__main__":
    # Enable the request queue (at most 8 pending events), then start serving.
    queued_app = demo.queue(max_size=8)
    queued_app.launch()