# zonos2 / app.py — Hugging Face Space entry point for ZONOS2 text-to-speech.
# Last commit: "Move Gradio CSS to launch" (1a0d910, verified) by Mike0021.
# File size: 9.54 kB.
import os
import tempfile
import threading
import time
import wave
from pathlib import Path
# Default cache/config locations (all under /tmp) and runtime flags.
# setdefault() keeps any value already supplied by the environment.
# NOTE(review): these run before the third-party imports below, presumably so
# those libraries see the settings when they are imported — keep this ordering.
os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
os.environ.setdefault("ZONOS2_TTS_NORM_CACHE_DIR", "/tmp/zonos2-tts-norm")
os.environ.setdefault("GRADIO_SSR_MODE", "false")
os.environ.setdefault("NUMBA_DISABLE_CUDA", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Ensure each configured cache directory exists (parents created as needed;
# existing directories are left alone).
for cache_dir in (
    os.environ["HF_HOME"],
    os.environ["HF_MODULES_CACHE"],
    os.environ["MPLCONFIGDIR"],
    os.environ["ZONOS2_TTS_NORM_CACHE_DIR"],
):
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
# Third-party and project imports come after the environment setup above.
# The flushed prints mark import progress in the Space logs, which helps
# locate slow or failing imports.
print("Importing Space runtime dependencies...", flush=True)
import spaces
import gradio as gr
import numpy as np
import torch
print("Importing ZONOS2 modules...", flush=True)
from zonos2.message import TTSSamplingParams
from zonos2.tokenizer.textnorm import SERVER_TO_NEMO_LANG, TTSTextNormalizer
from zonos2.tts import TTSLLM
print("Imported ZONOS2 modules.", flush=True)
# Model identifier passed to TTSLLM(model_path=...).
MODEL_ID = "Zyphra/ZONOS2"
# Output rate in Hz; used when a generation result does not report its own.
SAMPLE_RATE = 44100

# (display label, language code) choices for the language dropdown.
LANGUAGES = [
    ("English (US)", "en_us"),
    ("English (UK)", "en_gb"),
    ("French", "fr_fr"),
    ("German", "de"),
    ("Spanish", "es"),
    ("Italian", "it"),
    ("Portuguese (Brazil)", "pt_br"),
    ("Japanese", "ja"),
    ("Mandarin Chinese", "cmn"),
    ("Korean", "ko"),
]

# (display label, value) choices for the speaking-rate dropdown: "default"
# means no bucket; otherwise the value is the bucket index as a string,
# from "0" (very slow) to "7" (extreme).
SPEAKING_RATE_BUCKETS = [("Default", "default")] + [
    (rate_label, str(bucket_index))
    for bucket_index, rate_label in enumerate(
        (
            "Very slow",
            "Slow",
            "Relaxed",
            "Natural",
            "Bright",
            "Fast",
            "Very fast",
            "Extreme",
        )
    )
]

# Let CUDA matmuls use TF32, trading a little precision for speed on GPUs
# that support it.
torch.backends.cuda.matmul.allow_tf32 = True
def _load_model() -> TTSLLM:
    """Instantiate the TTSLLM engine for MODEL_ID and log how long it took.

    Returns:
        The constructed engine, with audio decoding enabled.
    """
    print(f"Loading {MODEL_ID} for ZeroGPU inference...", flush=True)
    load_start = time.perf_counter()
    engine_options = {
        "model_path": MODEL_ID,
        "decode_audio": True,
        # NOTE(review): the remaining values are TTSLLM engine settings whose
        # exact semantics live in zonos2.tts — confirm there before tuning.
        "cuda_graph_max_bs": 0,
        "max_running_req": 4,
        "max_extend_tokens": 4096,
        "memory_ratio": 0.75,
        "use_pynccl": False,
    }
    engine = TTSLLM(**engine_options)
    load_seconds = time.perf_counter() - load_start
    print(f"Loaded {MODEL_ID} in {load_seconds:.1f}s", flush=True)
    return engine
# Lazily created engine: None until synthesize() first calls _load_model().
# _estimate_duration() also checks it to request a longer GPU window before
# the model has been loaded.
TTS: TTSLLM | None = None
# Single shared text normalizer, constructed once at import time.
TEXT_NORMALIZER = TTSTextNormalizer()
# Guards the lazy load and generation on TTS so only one caller uses it at a time.
TTS_LOCK = threading.Lock()
def _estimate_duration(*args, **kwargs) -> int:
    """Return the GPU time budget (seconds) to request for one synthesis call.

    Passed to ``spaces.GPU(duration=...)`` and expected to receive the same
    arguments as ``synthesize``: ``max_tokens`` is taken from the keyword of
    that name, or else from positional slot 4. Unparseable values fall back
    to 768 tokens.

    The budget is a fixed base (220 s while the model is not yet loaded, 45 s
    afterwards) plus one second per 12 tokens, clamped to [60, 300].
    """
    requested = kwargs.get("max_tokens")
    if requested is None:
        requested = args[4] if len(args) > 4 else None

    try:
        token_budget = int(requested)
    except (TypeError, ValueError):
        token_budget = 768

    # A cold start must also cover _load_model(), hence the larger base.
    startup_seconds = 45 if TTS is not None else 220
    estimate = startup_seconds + token_budget // 12
    return max(60, min(estimate, 300))
def _pcm_float32_to_wav(audio_bytes: bytes, sample_rate: int = SAMPLE_RATE) -> str:
    """Write float32 PCM samples to a temporary mono 16-bit WAV file.

    Args:
        audio_bytes: Raw float32 samples in native byte order (any bytes-like
            object). NOTE(review): assumed to be a single channel nominally in
            [-1, 1] — confirm against TTSLLM.generate_one's "audio" output.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        Path of the new ``.wav`` file. It is created with ``delete=False`` and
        is not removed here; cleanup is left to the caller.

    Raises:
        gr.Error: If ``audio_bytes`` is None or holds no samples.
    """
    no_audio_msg = "The model returned no audio. Try increasing max tokens."
    if audio_bytes is None:
        raise gr.Error(no_audio_msg)
    audio = np.frombuffer(audio_bytes, dtype=np.float32)
    if audio.size == 0:
        raise gr.Error(no_audio_msg)

    # Replace NaN/inf with silence and clamp before scaling so the int16
    # conversion cannot overflow.
    audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
    audio = np.clip(audio, -1.0, 1.0)
    # WAV PCM data is little-endian; "<i2" makes that explicit so the file is
    # valid on big-endian hosts too (identical bytes on little-endian ones).
    audio_i16 = (audio * 32767.0).astype("<i2")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        wav_path = handle.name
    try:
        with wave.open(wav_path, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(sample_rate)
            wav.writeframes(audio_i16.tobytes())
    except Exception:
        # Don't leave a half-written temp file behind when writing fails.
        Path(wav_path).unlink(missing_ok=True)
        raise
    return wav_path
def _normalize_text(text: str, language: str, enabled: bool) -> str:
if not enabled:
return text
if language not in SERVER_TO_NEMO_LANG:
return text
return TEXT_NORMALIZER.normalize(text, language)
def _speaking_rate_bucket(value: str) -> int | None:
if value == "default":
return None
return int(value)
@spaces.GPU(duration=_estimate_duration)
def synthesize(
    text: str,
    language: str,
    text_normalization: bool,
    speaking_rate: str,
    max_tokens: int,
    temperature: float,
    topk: int,
    top_p: float,
    min_p: float,
    repetition_penalty: float,
    seed: int,
):
    """Generate speech for ``text`` and return ``(wav_path, status_message)``.

    Runs under ``spaces.GPU`` with a time budget from ``_estimate_duration``.
    The first call also loads the model into the module-level ``TTS``.

    Args:
        text: Input text; stripped, must be non-empty and at most 1200 chars.
        language: Language code; here it only selects text normalization.
        text_normalization: Whether to normalize the text before synthesis.
        speaking_rate: "default" or a bucket index string (see
            ``_speaking_rate_bucket``).
        max_tokens, temperature, topk, top_p, min_p, repetition_penalty:
            Forwarded to ``TTSSamplingParams``.
        seed: Negative or None selects no fixed seed; otherwise used as-is.

    Returns:
        Path to the generated WAV file and a human-readable status string.

    Raises:
        gr.Error: For empty or over-long text, or when no audio is produced.
    """
    global TTS

    text = (text or "").strip()
    if not text:
        raise gr.Error("Enter text to synthesize.")
    if len(text) > 1200:
        raise gr.Error("Keep the prompt under 1200 characters for this Space.")

    # NOTE(review): `language` is only used for normalization; it is not
    # passed to generate_one below.
    normalized = _normalize_text(text, language, text_normalization)
    params = TTSSamplingParams(
        temperature=float(temperature),
        topk=int(topk),
        top_p=float(top_p),
        min_p=float(min_p),
        max_tokens=int(max_tokens),
        # Repetition settings other than the penalty are fixed for this Space.
        repetition_window=50,
        repetition_penalty=float(repetition_penalty),
        repetition_codebooks=8,
        seed=None if seed is None or int(seed) < 0 else int(seed),
    )

    with TTS_LOCK:
        if TTS is None:
            TTS = _load_model()
        # Make the engine's own CUDA stream current before generating.
        torch.cuda.set_stream(TTS.stream)
        # Time only the generation: the one-off model load above would
        # otherwise inflate the "Generated ... in Xs" figure on the first call.
        started = time.perf_counter()
        result = TTS.generate_one(
            normalized,
            params,
            decode_audio=True,
            speaking_rate_bucket=_speaking_rate_bucket(speaking_rate),
            quality_buckets=None,
        )
        elapsed = time.perf_counter() - started

    wav_path = _pcm_float32_to_wav(result["audio"], result.get("sample_rate", SAMPLE_RATE))
    # `is None` instead of `or []` so an array-like token container does not
    # trigger an ambiguous truth-value check.
    audio_tokens = result.get("audio_tokens")
    frames = 0 if audio_tokens is None else len(audio_tokens)
    eos_frame = result.get("eos_frame")

    status = f"Generated {frames} frames in {elapsed:.1f}s"
    if eos_frame is not None:
        status += f" (EOS frame {eos_frame})"
    if normalized != text:
        status += f"\n\nNormalized text: {normalized}"
    return wav_path, status
# Stylesheet passed to launch(css=CSS) in the __main__ guard. It caps the app
# width at 1180px and centres it. It also gives the status box (elem class
# "compact-status") a monospace font.
CSS = """
main, .gradio-container, .gradio-container > .fillable {
max-width: 1180px !important;
margin-inline: auto !important;
}
.compact-status textarea {
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
}
"""
# Gradio UI: text and voice controls on the left, output audio and status on
# the right, advanced sampling controls in a collapsed accordion, then
# examples and the Generate event wiring.
with gr.Blocks(title="ZONOS2") as demo:
    gr.Markdown("# ZONOS2")
    with gr.Row():
        with gr.Column(scale=5):
            text = gr.Textbox(
                label="Text",
                value="In the quiet hum of the studio, ZONOS2 turns written words into natural speech.",
                lines=6,
                # Same 1200-character limit that synthesize() enforces.
                max_length=1200,
            )
            with gr.Row():
                language = gr.Dropdown(
                    choices=LANGUAGES,
                    value="en_us",
                    label="Language",
                )
                speaking_rate = gr.Dropdown(
                    choices=SPEAKING_RATE_BUCKETS,
                    value="default",
                    label="Speaking rate",
                )
            text_normalization = gr.Checkbox(value=True, label="Text normalization")
            generate = gr.Button("Generate", variant="primary")
        with gr.Column(scale=4):
            audio = gr.Audio(label="Audio", type="filepath", format="wav")
            status = gr.Textbox(
                label="Status",
                lines=5,
                interactive=False,
                elem_classes=["compact-status"],
            )

    with gr.Accordion("Sampling", open=False):
        with gr.Row():
            max_tokens = gr.Slider(
                minimum=128,
                maximum=2048,
                step=64,
                value=768,
                label="Max audio tokens",
            )
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        with gr.Row():
            temperature = gr.Slider(0.1, 2.0, value=1.15, step=0.05, label="Temperature")
            topk = gr.Slider(1, 512, value=106, step=1, label="Top-k")
        with gr.Row():
            top_p = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Top-p")
            min_p = gr.Slider(0.0, 0.5, value=0.18, step=0.01, label="Min-p")
        repetition_penalty = gr.Slider(
            1.0,
            2.0,
            value=1.2,
            step=0.05,
            label="Repetition penalty",
        )

    # Single source of truth for the argument order. It must match
    # synthesize()'s positional parameters and the example rows below, and
    # _estimate_duration() reads max_tokens from position 4.
    synth_inputs = [
        text,
        language,
        text_normalization,
        speaking_rate,
        max_tokens,
        temperature,
        topk,
        top_p,
        min_p,
        repetition_penalty,
        seed,
    ]

    gr.Examples(
        examples=[
            [
                "The first explorers landed just after sunrise, carrying maps, coffee, and impossible optimism.",
                "en_us",
                True,
                "default",
                512,
                1.15,
                106,
                0.0,
                0.18,
                1.2,
                -1,
            ],
            [
                "Le modèle parle avec une voix claire, expressive et naturellement rythmée.",
                "fr_fr",
                True,
                "default",
                512,
                1.15,
                106,
                0.0,
                0.18,
                1.2,
                -1,
            ],
        ],
        inputs=synth_inputs,
    )

    generate.click(
        fn=synthesize,
        inputs=synth_inputs,
        outputs=[audio, status],
        api_name="generate",
        # One generation at a time for this event.
        concurrency_limit=1,
    )
if __name__ == "__main__":
    # Enable the request queue (at most 8 pending requests; event listeners
    # default to a concurrency limit of 1), then start the server with the
    # custom stylesheet. Blocks.queue() configures `demo` in place.
    demo.queue(max_size=8, default_concurrency_limit=1)
    demo.launch(css=CSS)