# Source: app.py from a Hugging Face Space, uploaded by multimodalart (HF Staff)
# with huggingface_hub — commit e6a77d8 (verified), 12.8 kB.
import os
# Environment defaults that must be in place before numba / torch are imported
# below; setdefault keeps any value the deployment already provides.
for _env_key, _env_value in (
    ("NUMBA_DISABLE_CUDA", "1"),
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
):
    os.environ.setdefault(_env_key, _env_value)
del _env_key, _env_value
import ctypes
import glob
import site
def _preload_cudart13():
# zonos2's JIT kernels are compiled by the image's CUDA 13 nvcc and link
# libcudart.so.13, but torch 2.9.1 (cu128) only ships cudart 12.
patterns = [f"{sp}/nvidia/**/libcudart.so.13*" for sp in site.getsitepackages()]
patterns += [
"/usr/local/cuda*/targets/*/lib/libcudart.so.13*",
"/usr/local/cuda*/lib64/libcudart.so.13*",
"/usr/lib/x86_64-linux-gnu/libcudart.so.13*",
]
for pattern in patterns:
for lib in sorted(glob.glob(pattern, recursive=True)):
ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)
return
# Preload cudart 13 at import time, ahead of the `spaces` / torch / zonos2
# imports below (presumably so the JIT kernels resolve it from the global
# namespace when they load — confirm ordering requirement).
_preload_cudart13()
import spaces
import hashlib
import random
import threading
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
# Hugging Face repos for the TTS checkpoint and the speaker-embedding model.
MODEL_REPO = "Zyphra/ZONOS2"
SPEAKER_REPO = "marksverdhei/Qwen3-Voice-Embedding-12Hz-1.7B"
# Sample rate of the audio returned to Gradio by generate().
SAMPLE_RATE = 44100
# Codec frames per second of audio (~86.13); generate() multiplies the
# requested max length in seconds by this to get max_tokens.
FRAMES_PER_SECOND = SAMPLE_RATE / 512  # DAC hop length
# Download weights at import time. Only config/weight files are fetched for the
# TTS model. NOTE(review): no `revision=` is pinned, so a repo update changes
# what gets loaded (and unpickled) on the next restart.
MODEL_PATH = snapshot_download(MODEL_REPO, allow_patterns=["*.json", "*.pth", "*.pt", "*.yaml"])
# Return value unused: presumably warms the HF cache so Qwen3SpeakerEmbedding
# can load it later without a network round-trip (confirm).
snapshot_download(SPEAKER_REPO)
import dac as _dac
# Pre-fetch the 44 kHz DAC codec weights at startup (presumably the codec
# zonos2 uses to decode audio tokens — see the hop length above).
_dac.utils.download(model_type="44khz")
from zonos2.message.tts import TTSSamplingParams, TTSUserMsg
from zonos2.tokenizer.textnorm import TTSTextNormalizer
from zonos2.tts import TTSLLM
import socket
from zonos2.engine.config import EngineConfig
_DIST_PORT = None
def _distributed_addr(self):
# Upstream hardcodes tcp://127.0.0.1:23333; when ZeroGPU retries a call in
# a fresh worker while another worker is still mid-init, the fixed port
# collides with EADDRINUSE. Pick a free port once per process instead.
global _DIST_PORT
if _DIST_PORT is None:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
_DIST_PORT = s.getsockname()[1]
return f"tcp://127.0.0.1:{_DIST_PORT}"
# Monkeypatch: every EngineConfig now reports this process's free-port address
# instead of upstream's fixed tcp://127.0.0.1:23333.
EngineConfig.distributed_addr = property(_distributed_addr)
import zonos2.engine.engine as zonos2_engine
from zonos2.models.weight import _normalize_zonos2_state_dict
# Deserialize the 15.3 GB checkpoint once in the main process (mmap keeps it
# page-cache backed); forked GPU workers inherit it copy-on-write, so cold
# engine init skips the ~17s torch.load and only pays the host->device copy.
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# downloaded file; this is only acceptable because MODEL_REPO is a trusted
# source — consider pinning a revision in snapshot_download.
_STATE_DICT = torch.load(
    f"{MODEL_PATH}/model.pth", map_location="cpu", weights_only=False, mmap=True
)
# Some checkpoints wrap the weights under a "model" key; unwrap if present.
if "model" in _STATE_DICT:
    _STATE_DICT = _STATE_DICT["model"]
# Apply zonos2's own key normalization so the dict matches what the engine expects.
_STATE_DICT = _normalize_zonos2_state_dict(_STATE_DICT)
def _preloaded_checkpoint_weight(model_path, device):
    """Stand-in for zonos2's ``load_checkpoint_weight`` backed by ``_STATE_DICT``.

    *model_path* is accepted for signature compatibility but ignored. Returns a
    new ``{name: tensor}`` dict where each preloaded tensor has been moved to
    *device*; ``_STATE_DICT`` itself is left untouched.
    """
    weights = {}
    for name, tensor in _STATE_DICT.items():
        weights[name] = tensor.to(device)
    return weights
# Monkeypatch the engine module's loader so engine init copies the preloaded
# state dict instead of reading model.pth again (presumably the engine looks
# this name up on the module at call time — confirm).
zonos2_engine.load_checkpoint_weight = _preloaded_checkpoint_weight
# UI display name -> language code handed to the text normalizer.
LANGUAGES = {
    "English (US)": "en_us",
    "English (UK)": "en_gb",
    "French": "fr_fr",
    "German": "de",
    "Spanish": "es",
    "Italian": "it",
    "Portuguese (BR)": "pt_br",
    "Japanese": "ja",
    "Mandarin": "cmn",
    "Korean": "ko",
}

# Speaking-rate buckets in phonemes/sec; generate() passes the chosen bucket's
# list index to the model, so the order here is significant.
SPEAKING_RATE_BUCKETS = [
    "0-8",
    "8-11",
    "11-14",
    "14-17",
    "17-21",
    "21-28",
    "28-40",
    "40+",
]
# Dropdown options: "Auto" (no bucket) followed by every bucket.
RATE_CHOICES = ["Auto", *SPEAKING_RATE_BUCKETS]

# Upper bound for randomly drawn seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Shared text normalizer used by normalize_text().
NORMALIZER = TTSTextNormalizer()
# Run warmup on a background daemon thread so it does not block app startup
# (daemon=True: it will not keep the process alive on exit).
# NOTE(review): a request can reach NORMALIZER.normalize() before warmup()
# finishes — confirm the two are safe to run concurrently.
threading.Thread(target=NORMALIZER.warmup, daemon=True).start()
class ZonosTTSLLM(TTSLLM):
    """TTSLLM that stamps per-request conditioning onto offline user messages.

    The conditioning lives on the instance (generate() assigns it right before
    each call) and is copied onto every ``TTSUserMsg`` as the offline path
    receives it, plumbing speaker-embedding conditioning into that path.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Conditioning for the next request; None means no reference voice.
        self.speaker_embedding = None
        self.clean_speaker_background = False
        self.accurate_mode = True

    def offline_receive_msg(self, blocking: bool = False):
        """Receive messages from the base class and attach the current conditioning."""
        received = super().offline_receive_msg(blocking)
        user_msgs = (m for m in received if isinstance(m, TTSUserMsg))
        for user_msg in user_msgs:
            user_msg.speaker_embedding = self.speaker_embedding
            user_msg.clean_speaker_background = self.clean_speaker_background
            user_msg.accurate_mode = self.accurate_mode
        return received
# Per-process model cache, filled lazily by _get_models() (called from the
# GPU-decorated generate()).
MODELS = {}
# sha256 key of a reference clip -> 2048-dim CPU speaker embedding (see _embed_speaker).
EMBEDDING_CACHE = {}


def _get_models():
    """Build (once) and return the speaker embedder and the TTS engine.

    Each model is created independently. If the TTS engine fails to
    initialize (e.g. a rendezvous-port clash on a retried call), the next call
    reuses the embedder that was already loaded instead of building a second
    copy on the GPU.

    Returns:
        The shared ``MODELS`` dict with ``"embedder"`` and ``"tts"`` keys.
    """
    if "embedder" not in MODELS:
        from zonos2.models.speaker_cloning import Qwen3SpeakerEmbedding

        MODELS["embedder"] = Qwen3SpeakerEmbedding(device="cuda")
    if "tts" not in MODELS:
        MODELS["tts"] = ZonosTTSLLM(
            model_path=MODEL_PATH,
            cuda_graph_max_bs=4,
            num_page_override=65536,
        )
    return MODELS
def _embed_speaker(models, speaker_audio):
sr, wav = speaker_audio
key = hashlib.sha256(wav.tobytes() + str(sr).encode()).hexdigest()
if key in EMBEDDING_CACHE:
return EMBEDDING_CACHE[key]
wav = np.asarray(wav)
if wav.dtype == np.int16:
wav = wav.astype(np.float32) / 32768.0
elif wav.dtype == np.int32:
wav = wav.astype(np.float32) / 2147483648.0
else:
wav = wav.astype(np.float32)
if wav.ndim == 2:
wav = wav.T # (samples, channels) -> (channels, samples)
else:
# The embedder's reflect-pad requires a 2D (channels, samples) input;
# mono uploads arrive 1D.
wav = wav[None, :]
wav_t = torch.from_numpy(wav)
embedder = models["embedder"]
with torch.inference_mode():
output = embedder(wav_t, sr)
candidates = output if isinstance(output, tuple) else (output,)
for candidate in candidates:
candidate = candidate.squeeze(0).to(dtype=torch.float32, device="cpu")
if candidate.numel() == 2048:
embedding = candidate.reshape(2048)
EMBEDDING_CACHE[key] = embedding
return embedding
raise gr.Error("Could not compute a speaker embedding from the reference audio.")
def normalize_text(text, language, apply_normalization):
    """Validate the user's text and, if requested, normalize it for *language*.

    Args:
        text: raw textbox contents (None is treated as empty).
        language: a key of ``LANGUAGES``; only read when normalizing.
        apply_normalization: when false the stripped text is returned as-is.

    Returns:
        The stripped (and optionally normalized) text.

    Raises:
        gr.Error: if the text is empty or longer than 5000 characters.
    """
    stripped = (text or "").strip()
    if not stripped:
        raise gr.Error("Please enter some text to synthesize.")
    if len(stripped) > 5000:
        raise gr.Error("Text is too long — please keep it under 5000 characters.")
    if apply_normalization:
        return NORMALIZER.normalize(stripped, LANGUAGES[language])
    return stripped
def _gpu_duration(
normalized_text, speaker_audio, accurate_mode, clean_background, speaking_rate, max_seconds, *args
):
# ~18s engine init + JIT/embedder headroom, decode measured at ~51 frames/s
# (86.13 frames per audio second -> ~1.7x realtime).
return 75 + 2 * float(max_seconds)
@spaces.GPU(duration=_gpu_duration)
def generate(
    normalized_text,
    speaker_audio,
    accurate_mode,
    clean_background,
    speaking_rate,
    max_seconds,
    seed,
    randomize_seed,
    temperature,
    top_k,
    min_p,
    repetition_penalty,
    progress=gr.Progress(),
):
    """Synthesize speech for already-normalized text on the GPU.

    Args:
        normalized_text: output of normalize_text().
        speaker_audio: ``(sample_rate, samples)`` reference clip, or None for
            no voice cloning.
        accurate_mode, clean_background: conditioning flags forwarded to the model.
        speaking_rate: "Auto" or one of ``SPEAKING_RATE_BUCKETS``.
        max_seconds: cap on generated audio length.
        seed, randomize_seed: sampling seed; a fresh one is drawn when
            randomizing or when the Seed box is empty.
        temperature, top_k, min_p, repetition_penalty: sampling parameters.

    Returns:
        ``((SAMPLE_RATE, float32 samples), seed_used)``. The seed is echoed
        back so the UI shows which seed produced the clip.

    Raises:
        gr.Error: on empty text or when generation produced no audio.
    """
    # Bail out before engine init: if this step runs after normalize_text
    # failed, the State can still hold its initial empty value.
    if not normalized_text:
        raise gr.Error("Please enter some text to synthesize.")

    models = _get_models()
    tts = models["tts"]
    # The scheduler pins its CUDA stream thread-locally at init, but each call
    # may run in a new thread; re-pin or run_forever's stream assert fails.
    torch.cuda.set_stream(tts.stream)

    # An emptied Seed box can arrive as None; draw a seed instead of crashing in int().
    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    progress(0.1, desc="Embedding reference voice...")
    embedding = _embed_speaker(models, speaker_audio) if speaker_audio is not None else None
    # ZonosTTSLLM copies these onto the user message when the engine receives it.
    tts.speaker_embedding = embedding
    tts.clean_speaker_background = bool(clean_background)
    tts.accurate_mode = bool(accurate_mode)

    sampling_params = TTSSamplingParams(
        temperature=float(temperature),
        topk=int(top_k),
        min_p=float(min_p),
        repetition_penalty=float(repetition_penalty),
        max_tokens=int(float(max_seconds) * FRAMES_PER_SECOND),
        seed=seed,
    )
    rate_bucket = None if speaking_rate == "Auto" else SPEAKING_RATE_BUCKETS.index(speaking_rate)

    progress(0.3, desc="Generating speech...")
    result = tts.generate_one(
        normalized_text,
        sampling_params,
        speaking_rate_bucket=rate_bucket,
    )
    if not result["audio"]:
        raise gr.Error("Generation produced no audio — try a different seed or shorter text.")
    # Copy so the returned array owns writable memory rather than viewing the bytes buffer.
    audio = np.frombuffer(result["audio"], dtype=np.float32).copy()
    return (SAMPLE_RATE, audio), seed
css = """
.gradio-container {max-width: 960px !important; margin: 0 auto !important;}
"""

with gr.Blocks(css=css, title="Zonos 2") as demo:
    gr.Markdown(
        """
# 🗣️ Zonos 2
[Zyphra's ZONOS2](https://huggingface.co/Zyphra/ZONOS2) — an expressive multilingual
text-to-speech model with high-fidelity voice cloning, trained on 6M+ hours of speech.
Upload or record a few seconds of a voice and it will speak your text.
[Blog](https://www.zyphra.com/our-work/zonos2) · [Code](https://github.com/Zyphra/ZONOS2)
"""
    )

    with gr.Row():
        # Left column: text, language, reference voice, and the Generate button.
        with gr.Column():
            text = gr.Textbox(
                label="Text",
                lines=4,
                value="Hello! I am Zonos 2, a text to speech model by Zyphra. I can clone anyone's voice from just a few seconds of audio.",
            )
            language = gr.Dropdown(
                choices=list(LANGUAGES.keys()), value="English (US)", label="Language"
            )
            speaker_audio = gr.Audio(
                label="Reference voice (upload or record)",
                type="numpy",
                sources=["upload", "microphone"],
                value="voices/AmericanFemale.mp3",
            )
            gr.Examples(
                examples=[
                    ["voices/AmericanFemale.mp3"],
                    ["voices/AmericanMale.mp3"],
                    ["voices/BritishFemale.mp3"],
                ],
                inputs=[speaker_audio],
                label="Default voices",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        # Right column: output audio plus the advanced sampling controls.
        with gr.Column():
            audio_out = gr.Audio(label="Generated speech", type="numpy")
            with gr.Accordion("Advanced settings", open=False):
                accurate_mode = gr.Checkbox(
                    value=True,
                    label="Accurate mode",
                    info="Disable for more expressive (less literal) delivery",
                )
                clean_background = gr.Checkbox(
                    value=False,
                    label="Clean reference audio",
                    info="Mark the reference recording as having a clean background",
                )
                normalize_chk = gr.Checkbox(
                    value=True,
                    label="Normalize text",
                    info='Convert written forms to spoken forms ("$5" → "five dollars")',
                )
                speaking_rate = gr.Dropdown(
                    choices=RATE_CHOICES, value="Auto", label="Speaking rate (phonemes/sec)"
                )
                max_seconds = gr.Slider(
                    minimum=2, maximum=60, value=30, step=1, label="Max audio length (seconds)"
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=1.15, step=0.05, label="Temperature"
                )
                top_k = gr.Slider(minimum=1, maximum=1024, value=106, step=1, label="Top-k")
                min_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.18, step=0.01, label="Min-p")
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.05, label="Repetition penalty"
                )
                seed = gr.Number(value=42, precision=0, label="Seed")
                randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

    # Carries normalize_text's output into the GPU step of the chain below.
    normalized_text = gr.State("")

    gr.Examples(
        examples=[
            ["Did you know? The sun is actually a giant ball of plasma — over one million Earths could fit inside it!", "English (US)"],
            ["On the 3rd of March 2026, tickets cost $5.32 each.", "English (US)"],
            ["Bonjour ! Je peux parler plusieurs langues avec une voix naturelle et expressive."],
            ["私は数秒の音声からどんな声でも再現できます。", "Japanese"],
            ["¡Hola! Puedo clonar cualquier voz con solo unos segundos de audio.", "Spanish"],
        ],
        inputs=[text, language],
        label="Example texts",
    )

    # Normalize on CPU first, then synthesize on the GPU. Use `.success()`
    # rather than `.then()`: `.then()` fires even when normalize_text raised
    # (empty / too-long text). That would waste a GPU allocation and could
    # re-synthesize the stale text left in `normalized_text`.
    generate_btn.click(
        fn=normalize_text,
        inputs=[text, language, normalize_chk],
        outputs=[normalized_text],
    ).success(
        fn=generate,
        inputs=[
            normalized_text,
            speaker_audio,
            accurate_mode,
            clean_background,
            speaking_rate,
            max_seconds,
            seed,
            randomize_seed,
            temperature,
            top_k,
            min_p,
            repetition_penalty,
        ],
        outputs=[audio_out, seed],
    )

demo.launch()