| import os |
| import tempfile |
| import threading |
| import time |
| import wave |
| from pathlib import Path |
|
|
# Provide defaults for cache/config locations and runtime knobs (values the
# caller already exported win), before any heavy library is imported below.
for _env_name, _env_default in (
    ("HF_HOME", "/tmp/.cache/huggingface"),
    ("HF_MODULES_CACHE", "/tmp/hf_modules"),
    ("MPLCONFIGDIR", "/tmp/matplotlib"),
    ("ZONOS2_TTS_NORM_CACHE_DIR", "/tmp/zonos2-tts-norm"),
    ("GRADIO_SSR_MODE", "false"),
    ("NUMBA_DISABLE_CUDA", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
):
    os.environ.setdefault(_env_name, _env_default)

# Create every directory-valued setting up front so later writers find it.
for cache_dir in [
    os.environ[key]
    for key in ("HF_HOME", "HF_MODULES_CACHE", "MPLCONFIGDIR", "ZONOS2_TTS_NORM_CACHE_DIR")
]:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
# Heavy imports run only after the environment defaults above are in place.
# Keep the existing order (`spaces` before `torch`); do not reorder.
print("Importing Space runtime dependencies...", flush=True)
import spaces
import gradio as gr
import numpy as np
import torch

print("Importing ZONOS2 modules...", flush=True)
from zonos2.message import TTSSamplingParams
from zonos2.tokenizer.textnorm import (
    SERVER_TO_NEMO_LANG,
    TTSTextNormalizer,
)
from zonos2.tts import TTSLLM
print("Imported ZONOS2 modules.", flush=True)
|
|
# Model repository passed to TTSLLM by _load_model().
MODEL_ID = "Zyphra/ZONOS2"
# Fallback output sample rate (Hz) used when a result does not report one.
SAMPLE_RATE = 44100

# (dropdown label, language code) pairs shown in the UI.
LANGUAGES = [
    ("English (US)", "en_us"),
    ("English (UK)", "en_gb"),
    ("French", "fr_fr"),
    ("German", "de"),
    ("Spanish", "es"),
    ("Italian", "it"),
    ("Portuguese (Brazil)", "pt_br"),
    ("Japanese", "ja"),
    ("Mandarin Chinese", "cmn"),
    ("Korean", "ko"),
]

# (dropdown label, value) pairs: "default" means no explicit bucket; the rest
# are bucket indices "0".."7" ordered from slowest to fastest.
SPEAKING_RATE_BUCKETS = [("Default", "default")] + [
    (label, str(index))
    for index, label in enumerate(
        ("Very slow", "Slow", "Relaxed", "Natural", "Bright", "Fast", "Very fast", "Extreme")
    )
]

# Allow TF32 for CUDA float32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
def _load_model() -> TTSLLM:
    """Build the ZONOS2 TTS engine from ``MODEL_ID`` and log the load time.

    Returns:
        A ``TTSLLM`` instance created with audio decoding enabled.
    """
    print(f"Loading {MODEL_ID} for ZeroGPU inference...", flush=True)
    # NOTE(review): the tuning knobs below are interpreted by TTSLLM; their
    # exact semantics are defined there, not in this file.
    engine_kwargs = dict(
        model_path=MODEL_ID,
        decode_audio=True,
        cuda_graph_max_bs=0,
        max_running_req=4,
        max_extend_tokens=4096,
        memory_ratio=0.75,
        use_pynccl=False,
    )
    t0 = time.perf_counter()
    engine = TTSLLM(**engine_kwargs)
    load_seconds = time.perf_counter() - t0
    print(f"Loaded {MODEL_ID} in {load_seconds:.1f}s", flush=True)
    return engine
|
|
|
|
# Model handle; stays None until synthesize() loads it on first use.
TTS: TTSLLM | None = None
# Held by synthesize() around lazy model loading and generation.
TTS_LOCK = threading.Lock()
# Shared normalizer instance used by _normalize_text().
TEXT_NORMALIZER = TTSTextNormalizer()
|
|
|
|
def _estimate_duration(*args, **kwargs) -> int:
    """Return a GPU time budget in seconds for one ``synthesize`` call.

    Used as the dynamic ``duration`` of ``@spaces.GPU`` on ``synthesize`` and
    expected to receive that function's arguments. ``max_tokens`` is taken
    from the keyword of that name or, failing that, from the fifth positional
    argument; values ``int()`` cannot parse fall back to 768.

    The base allowance is larger while the global model is still unloaded
    (the call will also have to load it). The result is clamped to [60, 300].
    """
    raw = kwargs.get("max_tokens")
    if raw is None and len(args) >= 5:
        raw = args[4]
    try:
        token_budget = int(raw)
    except (TypeError, ValueError):
        token_budget = 768
    overhead = 45 if TTS is not None else 220
    return min(300, max(60, overhead + token_budget // 12))
|
|
|
|
def _pcm_float32_to_wav(audio_bytes: bytes | None, sample_rate: int = SAMPLE_RATE) -> str:
    """Write raw float32 PCM samples to a new 16-bit mono WAV file.

    Args:
        audio_bytes: Buffer of native-endian float32 samples for a single
            channel. ``None`` is treated the same as an empty buffer.
        sample_rate: Frame rate stored in the WAV header.

    Returns:
        Path of the created ``.wav`` file. It is deliberately left on disk,
        because the path is returned to the UI and must exist afterwards.

    Raises:
        gr.Error: If there are no samples to write.
    """
    if audio_bytes is None:
        audio = np.empty(0, dtype=np.float32)
    else:
        audio = np.frombuffer(audio_bytes, dtype=np.float32)
    if audio.size == 0:
        raise gr.Error("The model returned no audio. Try increasing max tokens.")

    # Replace non-finite samples with silence and clip to [-1, 1] so the
    # int16 conversion below cannot overflow and wrap around.
    audio = np.clip(np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0), -1.0, 1.0)
    pcm16 = (audio * 32767.0).astype(np.int16)

    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        with wave.open(wav_path, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
            wav.setframerate(sample_rate)
            wav.writeframes(pcm16.tobytes())
    except BaseException:
        # Don't leave a partially written temp file behind when encoding fails.
        try:
            os.unlink(wav_path)
        except OSError:
            pass  # cleanup is best-effort; the original error is re-raised below
        raise
    return wav_path
|
|
|
|
def _normalize_text(text: str, language: str, enabled: bool) -> str:
    """Return *text* passed through the shared TTS text normalizer, if applicable.

    The input comes back unchanged when normalization is switched off, or when
    *language* has no entry in ``SERVER_TO_NEMO_LANG``.
    """
    should_normalize = enabled and language in SERVER_TO_NEMO_LANG
    return TEXT_NORMALIZER.normalize(text, language) if should_normalize else text
|
|
|
|
| def _speaking_rate_bucket(value: str) -> int | None: |
| if value == "default": |
| return None |
| return int(value) |
|
|
|
|
@spaces.GPU(duration=_estimate_duration)
def synthesize(
    text: str,
    language: str,
    text_normalization: bool,
    speaking_rate: str,
    max_tokens: int,
    temperature: float,
    topk: int,
    top_p: float,
    min_p: float,
    repetition_penalty: float,
    seed: int,
):
    """Generate speech for *text* and return ``(wav_path, status)``.

    Args:
        text: Prompt text; it is stripped, must be non-empty, and must be at
            most 1200 characters.
        language: Language code (see ``LANGUAGES``) used for text normalization.
        text_normalization: Whether to run the text normalizer first.
        speaking_rate: Speaking-rate dropdown value (``"default"`` or ``"0"``-``"7"``).
        max_tokens, temperature, topk, top_p, min_p, repetition_penalty:
            Sampling settings forwarded to ``TTSSamplingParams``.
        seed: Sampling seed; ``None`` or a negative value means no fixed seed.

    Returns:
        Path to a temporary WAV file, plus a human-readable status string.

    Raises:
        gr.Error: On empty or over-long text, or when the model returns no audio.
    """
    global TTS

    text = (text or "").strip()
    if not text:
        raise gr.Error("Enter text to synthesize.")
    if len(text) > 1200:
        raise gr.Error("Keep the prompt under 1200 characters for this Space.")

    normalized = _normalize_text(text, language, text_normalization)
    params = TTSSamplingParams(
        temperature=float(temperature),
        topk=int(topk),
        top_p=float(top_p),
        min_p=float(min_p),
        max_tokens=int(max_tokens),
        repetition_window=50,
        repetition_penalty=float(repetition_penalty),
        repetition_codebooks=8,
        seed=None if seed is None or int(seed) < 0 else int(seed),
    )
    rate_bucket = _speaking_rate_bucket(speaking_rate)

    # Serialize model access: the lock covers lazy loading and generation.
    with TTS_LOCK:
        if TTS is None:
            TTS = _load_model()
        # Run on the model's CUDA stream for this call only, instead of leaving
        # it installed as the thread's current stream afterwards.
        with torch.cuda.stream(TTS.stream):
            # Time generation only, so a first-call model load is not reported
            # as generation time.
            started = time.perf_counter()
            result = TTS.generate_one(
                normalized,
                params,
                decode_audio=True,
                speaking_rate_bucket=rate_bucket,
                quality_buckets=None,
            )
            elapsed = time.perf_counter() - started

    # A present-but-None sample rate would break the WAV header; fall back.
    sample_rate = result.get("sample_rate") or SAMPLE_RATE
    wav_path = _pcm_float32_to_wav(result["audio"], sample_rate)

    # Avoid truthiness on the token container: if it is an array or tensor,
    # `tokens or []` raises instead of testing for emptiness.
    audio_tokens = result.get("audio_tokens")
    frames = 0 if audio_tokens is None else len(audio_tokens)
    eos_frame = result.get("eos_frame")

    status = f"Generated {frames} frames in {elapsed:.1f}s"
    if eos_frame is not None:
        status += f" (EOS frame {eos_frame})"
    if normalized != text:
        status += f"\n\nNormalized text: {normalized}"
    return wav_path, status
|
|
|
|
# Custom page styles passed to launch(). They cap the app width at 1180px and
# center it, and give the status textbox (elem class "compact-status") a
# monospace font.
CSS = """
main, .gradio-container, .gradio-container > .fillable {
max-width: 1180px !important;
margin-inline: auto !important;
}
.compact-status textarea {
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
}
"""
|
|
|
|
# Build the UI: prompt and controls in the wider left column, outputs on the
# right, sampling knobs in a collapsed accordion, then examples and wiring.
with gr.Blocks(title="ZONOS2") as demo:
    gr.Markdown("# ZONOS2")
    with gr.Row():
        with gr.Column(scale=5):
            text = gr.Textbox(
                label="Text",
                value="In the quiet hum of the studio, ZONOS2 turns written words into natural speech.",
                lines=6,
                max_length=1200,
            )
            with gr.Row():
                language = gr.Dropdown(
                    choices=LANGUAGES,
                    value="en_us",
                    label="Language",
                )
                speaking_rate = gr.Dropdown(
                    choices=SPEAKING_RATE_BUCKETS,
                    value="default",
                    label="Speaking rate",
                )
            text_normalization = gr.Checkbox(value=True, label="Text normalization")
            generate = gr.Button("Generate", variant="primary")
        with gr.Column(scale=4):
            audio = gr.Audio(label="Audio", type="filepath", format="wav")
            status = gr.Textbox(
                label="Status",
                lines=5,
                interactive=False,
                elem_classes=["compact-status"],
            )

    with gr.Accordion("Sampling", open=False):
        with gr.Row():
            max_tokens = gr.Slider(
                minimum=128,
                maximum=2048,
                step=64,
                value=768,
                label="Max audio tokens",
            )
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        with gr.Row():
            temperature = gr.Slider(0.1, 2.0, value=1.15, step=0.05, label="Temperature")
            topk = gr.Slider(1, 512, value=106, step=1, label="Top-k")
        with gr.Row():
            top_p = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Top-p")
            min_p = gr.Slider(0.0, 0.5, value=0.18, step=0.01, label="Min-p")
            repetition_penalty = gr.Slider(
                1.0,
                2.0,
                value=1.2,
                step=0.05,
                label="Repetition penalty",
            )

    # Single source of truth for the handler inputs. The order must match
    # synthesize()'s parameters, and _estimate_duration() reads max_tokens as
    # the fifth positional argument.
    synthesize_inputs = [
        text,
        language,
        text_normalization,
        speaking_rate,
        max_tokens,
        temperature,
        topk,
        top_p,
        min_p,
        repetition_penalty,
        seed,
    ]

    # Both examples share every setting after the prompt and its language
    # (normalization, speaking rate, max tokens, sampling knobs, seed).
    example_settings = [True, "default", 512, 1.15, 106, 0.0, 0.18, 1.2, -1]
    gr.Examples(
        examples=[
            [
                "The first explorers landed just after sunrise, carrying maps, coffee, and impossible optimism.",
                "en_us",
                *example_settings,
            ],
            [
                "Le modèle parle avec une voix claire, expressive et naturellement rythmée.",
                "fr_fr",
                *example_settings,
            ],
        ],
        inputs=synthesize_inputs,
    )

    generate.click(
        fn=synthesize,
        inputs=synthesize_inputs,
        outputs=[audio, status],
        api_name="generate",
        concurrency_limit=1,
    )
|
|
|
|
if __name__ == "__main__":
    # Bounded request queue (at most 8 waiting); events without their own
    # concurrency_limit run one at a time. queue() returns the same Blocks.
    demo.queue(max_size=8, default_concurrency_limit=1)
    demo.launch(css=CSS)
|
|