import os
import sys
import shlex
import subprocess
import ctypes
import types

# Directory holding the CUDA 13 runtime libraries inside the container image.
CUDA_LIBDIR = "/cuda-image/usr/local/cuda-13.0/lib64"

# Best-effort preload of the CUDA runtime with global symbol visibility so
# that extension modules loaded later can resolve it. Failure is tolerated
# (the library may be resolvable by other means), but it is reported instead
# of being swallowed silently.
try:
    ctypes.CDLL(os.path.join(CUDA_LIBDIR, "libcudart.so.13"), mode=ctypes.RTLD_GLOBAL)
except OSError as exc:
    print(f"warning: could not preload libcudart from {CUDA_LIBDIR}: {exc}", file=sys.stderr)

# Prepend the CUDA lib dir for child processes. Empty components are dropped:
# an empty LD_LIBRARY_PATH entry (e.g. a trailing separator when the variable
# was previously unset) is interpreted by the dynamic loader as the current
# working directory.
_existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(
    part for part in (CUDA_LIBDIR, _existing_ld_path) if part
)


def _pip_install(args, env=None):
    """Run ``pip install <args>`` with the interpreter executing this script.

    Using ``sys.executable -m pip`` guarantees the packages land in the
    environment that will import them, even if a different ``pip`` is first
    on PATH.

    Args:
        args: Shell-style argument string appended after ``pip install``.
        env: Optional environment mapping for the child process; ``None``
            inherits the current environment.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *shlex.split(args)],
        env=env,
        check=True,
    )


_pip_install(
    "flash-attn --no-build-isolation",
    env=os.environ | {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
_pip_install(
    "--no-deps https://github.com/state-spaces/mamba/releases/download/v2.3.2.post1/mamba_ssm-2.3.2.post1+cu13torch2.10cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
)
_pip_install(
    "--no-deps https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.2.post1/causal_conv1d-1.6.2.post1+cu13torch2.10cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
)

# flash-attn was installed with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE so the
# `flash_attn_2_cuda` C extension is missing. Stub it so the top-level
# `import flash_attn` succeeds (we only need flash_attn.layers.rotary which is
# pure Triton). The CUDA-backed helpers (flash_attn_with_kvcache, etc.) get
# replaced with SDPA shims below.
_stub = types.ModuleType("flash_attn_2_cuda")
_stub.__file__ = ""
_stub.__spec__ = None


def _stub_missing(*a, **kw):
    """Placeholder for every public attribute of the stub; always raises."""
    raise RuntimeError("flash_attn_2_cuda stub: CUDA flash-attn kernels not available on Blackwell")


# Only raise on calls/attrs the runtime actually tries to use; let dunder
# introspection succeed so libraries like `inspect` keep working.
def _stub_getattr(name):
    """Module-level ``__getattr__`` for the stub.

    Underscore-prefixed names raise ``AttributeError`` (so ``hasattr`` /
    introspection behave normally); any other name resolves to a callable
    that raises ``RuntimeError`` when invoked.
    """
    if name.startswith("_"):
        raise AttributeError(name)
    return _stub_missing


_stub.__getattr__ = _stub_getattr
# setdefault: never replace a real extension module if one is already loaded.
sys.modules.setdefault("flash_attn_2_cuda", _stub)

import spaces  # must come before torch / mamba_ssm / anything CUDA-related

# Replace flash_attn_with_kvcache in mamba_ssm.modules.mha with an SDPA-based
# shim.
The hybrid/transformer Zonos models use rotary kvcache decoding which # normally calls into the CUDA kernel; this shim does the rotary apply + # kvcache update + SDPA in plain torch instead. import torch import torch.nn.functional as _F def _flash_attn_with_kvcache_sdpa( q, k_cache, v_cache, k=None, v=None, rotary_cos=None, rotary_sin=None, cache_seqlens=None, softmax_scale=None, causal=False, rotary_interleaved=False, **kwargs, ): # q: (B, S_q, H, D); k_cache/v_cache: (B, S_max, H_kv, D) bsz, sq, nh, hd = q.shape s_max = k_cache.shape[1] nh_kv = k_cache.shape[2] # Resolve write positions from cache_seqlens (int or tensor of ints per-batch). if isinstance(cache_seqlens, torch.Tensor): start = cache_seqlens.to(torch.long) else: start = torch.full((bsz,), int(cache_seqlens), dtype=torch.long, device=q.device) # Apply rotary to q and the new k if rotary tables are passed in. cache holds # already-rotated keys. if rotary_cos is not None and rotary_sin is not None and k is not None: ro_dim = rotary_cos.shape[-1] * 2 def _apply_rotary(t, positions): # t: (B, S, H, D); positions: (B, S) long cos = rotary_cos[positions] # (B, S, ro_dim/2) sin = rotary_sin[positions] # (B, S, ro_dim/2) t_ro = t[..., :ro_dim] t_pass = t[..., ro_dim:] if rotary_interleaved: t1 = t_ro[..., 0::2] t2 = t_ro[..., 1::2] cos_b = cos.unsqueeze(2) sin_b = sin.unsqueeze(2) o1 = t1 * cos_b - t2 * sin_b o2 = t1 * sin_b + t2 * cos_b out = torch.stack([o1, o2], dim=-1).flatten(-2) else: t1, t2 = t_ro.chunk(2, dim=-1) cos_b = cos.unsqueeze(2).repeat_interleave(2, dim=-1)[..., : ro_dim // 2] sin_b = sin.unsqueeze(2).repeat_interleave(2, dim=-1)[..., : ro_dim // 2] o1 = t1 * cos_b - t2 * sin_b o2 = t1 * sin_b + t2 * cos_b out = torch.cat([o1, o2], dim=-1) return torch.cat([out.to(t.dtype), t_pass], dim=-1) # positions for q: start..start+sq-1; for new k: same. 
offsets = torch.arange(sq, device=q.device).unsqueeze(0) # (1, sq) positions = start.unsqueeze(1) + offsets # (B, sq) q = _apply_rotary(q, positions) k = _apply_rotary(k, positions) # Write k, v into caches at positions [start:start+sq] per batch. if k is not None and v is not None: # Build write index of shape (B, sq) -> use scatter on the seq dim. offsets = torch.arange(sq, device=q.device).unsqueeze(0) # (1, sq) positions = start.unsqueeze(1) + offsets # (B, sq) # Expand to (B, sq, H_kv, D) for scatter on dim=1. idx = positions.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, nh_kv, hd) k_cache.scatter_(1, idx, k.to(k_cache.dtype)) v_cache.scatter_(1, idx, v.to(v_cache.dtype)) # Mask out keys beyond each batch's current end (start + sq). end = start + sq # (B,) arange = torch.arange(s_max, device=q.device).unsqueeze(0) # (1, S_max) key_valid = arange < end.unsqueeze(1) # (B, S_max) # For SDPA, we need (B, H, S, D) layout. q_h = q.transpose(1, 2) # (B, H, S_q, D) # Repeat kv heads to match q heads. rep = nh // nh_kv if rep > 1: k_full = k_cache.repeat_interleave(rep, dim=2) v_full = v_cache.repeat_interleave(rep, dim=2) else: k_full = k_cache v_full = v_cache k_h = k_full.transpose(1, 2) # (B, H, S_max, D) v_h = v_full.transpose(1, 2) # Build a (B, 1, S_q, S_max) attention mask: True = attend. # Causal: position i in q attends to positions <= start + i in key space. q_positions = start.unsqueeze(1) + torch.arange(sq, device=q.device).unsqueeze(0) # (B, sq) causal_mask = arange.unsqueeze(1) <= q_positions.unsqueeze(2) # (B, sq, S_max) attn_mask = key_valid.unsqueeze(1) & causal_mask # (B, sq, S_max) attn_mask = attn_mask.unsqueeze(1) # (B, 1, sq, S_max) out = _F.scaled_dot_product_attention( q_h, k_h, v_h, attn_mask=attn_mask, scale=softmax_scale, is_causal=False, ) return out.transpose(1, 2).contiguous() # (B, S_q, H, D) # Install the shim into mamba_ssm.modules.mha (it already tried to import the # real one and got None, so we just overwrite at module level). 
import mamba_ssm.modules.mha as _mha_mod

_mha_mod.flash_attn_with_kvcache = _flash_attn_with_kvcache_sdpa

import torchaudio
import gradio as gr
from os import getenv

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes

device = "cuda"
MODEL_NAMES = ["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"]
# Both variants are loaded eagerly at import time and frozen for inference.
MODELS = {name: Zonos.from_pretrained(name, device=device) for name in MODEL_NAMES}
for model in MODELS.values():
    model.requires_grad_(False).eval()


def update_ui(model_choice):
    """
    Dynamically show/hide UI elements based on the model's conditioners.

    Returns one ``gr.update`` per component, in the order expected by the
    ``outputs`` list wired up in ``build_interface``. The "Unconditional
    Keys" choices are the model's conditioner names with "espeak" and
    "language_id" removed.
    """
    model = MODELS[model_choice]
    cond_names = [c.name for c in model.prefix_conditioner.conditioners]
    print("Conditioners in this model:", cond_names)

    def shown_if(conditioner_name):
        # Component is visible only when the model has this conditioner.
        return gr.update(visible=(conditioner_name in cond_names))

    text_update = shown_if("espeak")
    language_update = shown_if("espeak")
    speaker_audio_update = shown_if("speaker")
    prefix_audio_update = gr.update(visible=True)
    # One independent update per emotion slider (emotion1 .. emotion8).
    emotion_updates = [shown_if("emotion") for _ in range(8)]
    vq_single_slider_update = shown_if("vqscore_8")
    fmax_slider_update = shown_if("fmax")
    pitch_std_slider_update = shown_if("pitch_std")
    speaking_rate_slider_update = shown_if("speaking_rate")
    dnsmos_slider_update = shown_if("dnsmos_ovrl")
    speaker_noised_checkbox_update = shown_if("speaker_noised")
    unconditional_keys_update = gr.update(
        choices=[name for name in cond_names if name not in ("espeak", "language_id")]
    )

    return (
        text_update,
        language_update,
        speaker_audio_update,
        prefix_audio_update,
        *emotion_updates,
        vq_single_slider_update,
        fmax_slider_update,
        pitch_std_slider_update,
        speaking_rate_slider_update,
        dnsmos_slider_update,
        speaker_noised_checkbox_update,
        unconditional_keys_update,
    )


@spaces.GPU(duration=120)
def generate_audio(
    model_choice,
    text,
    language,
    speaker_audio,
    prefix_audio,
    e1,
    e2,
    e3,
    e4,
    e5,
    e6,
    e7,
    e8,
    vq_single,
    fmax,
    pitch_std,
    speaking_rate,
    dnsmos_ovrl,
    speaker_noised,
    cfg_scale,
    min_p,
    seed,
    randomize_seed,
    unconditional_keys,
    progress=gr.Progress(),
):
    """
    Generates audio based on the provided UI parameters.

    Returns ``((sample_rate, waveform_ndarray), seed)`` where ``seed`` is the
    seed actually used, so the UI can display it. A fresh seed is drawn when
    ``randomize_seed`` is set or when the Seed field was left empty.
    """
    selected_model = MODELS[model_choice]

    speaker_noised_bool = bool(speaker_noised)
    fmax = float(fmax)
    pitch_std = float(pitch_std)
    speaking_rate = float(speaking_rate)
    dnsmos_ovrl = float(dnsmos_ovrl)
    cfg_scale = float(cfg_scale)
    min_p = float(min_p)
    max_new_tokens = 86 * 30

    # A cleared Seed field arrives as None; only convert it when it is
    # actually going to be used.
    if randomize_seed or seed is None:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    else:
        seed = int(seed)
    torch.manual_seed(seed)

    speaker_embedding = None
    if speaker_audio is not None and "speaker" not in unconditional_keys:
        wav, sr = torchaudio.load(speaker_audio)
        speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
        speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)

    audio_prefix_codes = None
    if prefix_audio is not None:
        wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
        wav_prefix = wav_prefix.mean(0, keepdim=True)  # downmix to mono
        wav_prefix = torchaudio.functional.resample(
            wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate
        )
        wav_prefix = wav_prefix.to(device, dtype=torch.float32)
        with torch.autocast(device, dtype=torch.float32):
            audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))

    emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)

    vq_val = float(vq_single)
    vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)

    cond_dict = make_cond_dict(
        text=text,
        language=language,
        speaker=speaker_embedding,
        emotion=emotion_tensor,
        vqscore_8=vq_tensor,
        fmax=fmax,
        pitch_std=pitch_std,
        speaking_rate=speaking_rate,
        dnsmos_ovrl=dnsmos_ovrl,
        speaker_noised=speaker_noised_bool,
        device=device,
        unconditional_keys=unconditional_keys,
    )
    conditioning = selected_model.prepare_conditioning(cond_dict)

    # Rough progress estimate derived from the text length; clamped so an
    # empty text never yields a zero-length progress bar.
    estimated_generation_duration = 30 * len(text) / 400
    estimated_total_steps = max(1, int(estimated_generation_duration * 86))

    def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
        progress((step, estimated_total_steps))
        return True

    codes = selected_model.generate(
        prefix_conditioning=conditioning,
        audio_prefix_codes=audio_prefix_codes,
        max_new_tokens=max_new_tokens,
        cfg_scale=cfg_scale,
        batch_size=1,
        sampling_params=dict(min_p=min_p),
        callback=update_progress,
    )

    wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
    sr_out = selected_model.autoencoder.sampling_rate
    # Keep only the first row if a 2-D multi-row result comes back.
    if wav_out.dim() == 2 and wav_out.size(0) > 1:
        wav_out = wav_out[0:1, :]
    return (sr_out, wav_out.squeeze().numpy()), seed


def build_interface():
    """Build the Gradio Blocks UI and wire its events; returns the Blocks.

    NOTE(review): the source this was restored from had its indentation
    collapsed, so the nesting of the layout containers below is a best-effort
    reconstruction — confirm against the deployed layout.
    """
    with gr.Blocks(theme='ParityError/Interstellar') as demo:
        gr.Markdown("# Zonos v0.1")
        gr.Markdown("State of the art text-to-speech model [[model]](https://huggingface.co/collections/Zyphra/zonos-v01-67ac661c85e1898670823b4f), [[blog]](https://www.zyphra.com/post/beta-release-of-zonos-v0-1), [[Zyphra Audio (hosted service)]](https://maia.zyphra.com/sign-in?redirect_url=https%3A%2F%2Fmaia.zyphra.com%2Faudio) ")

        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Text to Synthesize",
                    value="Zonos uses eSpeak for text to phoneme conversion!",
                    lines=4,
                    max_length=500,  # approximately
                )
                with gr.Row():
                    language = gr.Dropdown(
                        choices=supported_language_codes,
                        value="en-us",
                        label="Language",
                    )
                    model_choice = gr.Dropdown(
                        choices=MODEL_NAMES,
                        value="Zyphra/Zonos-v0.1-transformer",
                        label="Zonos Model Type",
                        info="Select the model variant to use.",
                    )
                speaker_noised_checkbox = gr.Checkbox(
                    label="Denoise Speaker?", value=False
                )
                speaker_audio = gr.Audio(
                    label="Optional Speaker Audio (for cloning)",
                    type="filepath",
                )
                generate_button = gr.Button("Generate Audio")

            with gr.Column():
                output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)

        with gr.Accordion("Toggles", open=True):
            gr.Markdown(
                "### Emotion Sliders\n"
                "Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
                "Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
            )
            with gr.Row():
                emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
                emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
                emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
                emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
            with gr.Row():
                emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
                emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
                emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
                emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")

            gr.Markdown(
                "### Unconditional Toggles\n"
                "Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
                'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
            )
            with gr.Row():
                unconditional_keys = gr.CheckboxGroup(
                    [
                        "speaker",
                        "emotion",
                        "vqscore_8",
                        "fmax",
                        "pitch_std",
                        "speaking_rate",
                        "dnsmos_ovrl",
                        "speaker_noised",
                    ],
                    value=["emotion"],
                    label="Unconditional Keys",
                )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Conditioning Parameters")
                    dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
                    fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
                    vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
                    pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
                    speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")

                with gr.Column():
                    gr.Markdown("## Generation Parameters")
                    cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
                    min_p_slider = gr.Slider(0.0, 1.0, 0.15, 0.01, label="Min P")
                    seed_number = gr.Number(label="Seed", value=420, precision=0)
                    randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)

            prefix_audio = gr.Audio(
                value="assets/silence_100ms.wav",
                label="Optional Prefix Audio (continue from this audio)",
                type="filepath",
            )

        # Components refreshed by update_ui, in the order of its return tuple.
        ui_refresh_outputs = [
            text,
            language,
            speaker_audio,
            prefix_audio,
            emotion1,
            emotion2,
            emotion3,
            emotion4,
            emotion5,
            emotion6,
            emotion7,
            emotion8,
            vq_single_slider,
            fmax_slider,
            pitch_std_slider,
            speaking_rate_slider,
            dnsmos_slider,
            speaker_noised_checkbox,
            unconditional_keys,
        ]

        model_choice.change(
            fn=update_ui,
            inputs=[model_choice],
            outputs=list(ui_refresh_outputs),
        )

        # On page load, trigger the same UI refresh
        demo.load(
            fn=update_ui,
            inputs=[model_choice],
            outputs=list(ui_refresh_outputs),
        )

        # Generate audio on button click
        generate_button.click(
            fn=generate_audio,
            inputs=[
                model_choice,
                text,
                language,
                speaker_audio,
                prefix_audio,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                vq_single_slider,
                fmax_slider,
                pitch_std_slider,
                speaking_rate_slider,
                dnsmos_slider,
                speaker_noised_checkbox,
                cfg_scale_slider,
                min_p_slider,
                seed_number,
                randomize_seed_toggle,
                unconditional_keys,
            ],
            outputs=[output_audio, seed_number],
        )

    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)