"""Gradio demo application for streaming VoXtream2 text-to-speech synthesis."""

import argparse
import json
import os
import tempfile
import uuid
from pathlib import Path

# Disable PyTorch dynamo/inductor globally.
# NOTE: these variables are set before torch is imported on purpose.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"

import torch._dynamo as dynamo

dynamo.config.suppress_errors = True

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from voxtream.config import SpeechGeneratorConfig
from voxtream.generator import SpeechGenerator
from voxtream.utils.generator import (
    DTYPE_MAP,
    existing_file,
    interpolate_speaking_rate_params,
    text_generator,
)

# Minimum amount of audio (in seconds) buffered before a chunk is streamed out.
MIN_CHUNK_SEC = 0.01
# Length (in seconds) of the linear fade-out applied to the end of the audio.
FADE_OUT_SEC = 0.10

CUSTOM_CSS = """
/* overall width */
.gradio-container {max-width: 1100px !important}
/* stack labels tighter and even heights */
#cols .wrap > .form {gap: 10px}
#left-col, #right-col {gap: 14px}
/* make submit centered + bigger */
#submit {width: 260px; margin: 10px auto 0 auto;}
/* make clear align left and look secondary */
#clear {width: 120px;}
/* give audio a little breathing room */
audio {outline: none;}
"""


def float32_to_int16(audio_float32: np.ndarray) -> np.ndarray:
    """
    Convert float32 audio samples (-1.0 to 1.0) to int16 PCM samples.

    Parameters:
        audio_float32 (np.ndarray): Input float32 audio samples.

    Returns:
        np.ndarray: Output int16 audio samples.

    Raises:
        ValueError: If the input array dtype is not float32.
    """
    if audio_float32.dtype != np.float32:
        raise ValueError("Input must be a float32 numpy array")

    # Clip to avoid overflow after scaling
    audio_clipped = np.clip(audio_float32, -1.0, 1.0)

    # Scale and convert
    audio_int16 = (audio_clipped * 32767).astype(np.int16)
    return audio_int16


def _apply_fade_out(audio: np.ndarray, sample_rate) -> np.ndarray:
    """
    Apply a linear fade-out of up to FADE_OUT_SEC seconds to the end of `audio`.

    The array is modified in place (and also returned for convenience), so
    callers should pass an array they own, e.g. the result of np.concatenate.

    Parameters:
        audio (np.ndarray): Audio samples; the fade runs along the first axis.
        sample_rate: Sampling rate used to convert FADE_OUT_SEC to samples.

    Returns:
        np.ndarray: The same array with its tail faded to zero.
    """
    nfade = min(int(sample_rate * FADE_OUT_SEC), audio.shape[0])
    if nfade > 0:
        audio[-nfade:] *= np.linspace(1.0, 0.0, nfade, dtype=np.float32)
    return audio


def _clear_outputs():
    """
    Build the updates that reset the synthesis outputs.

    Returns:
        tuple: An update clearing the value of the first output and an update
        clearing and hiding the second one. It is wired to
        [output_audio, download_btn] in demo_app.
    """
    return gr.update(value=None), gr.update(value=None, visible=False)


def demo_app(config: SpeechGeneratorConfig, demo_examples, synthesize_fn):
    """
    Build the Gradio interface and launch it.

    Parameters:
        config (SpeechGeneratorConfig): Generator config; `max_prompt_sec` and
            `max_phone_tokens` are read to label and limit the inputs.
        demo_examples: Example rows passed to gr.Examples. Each row is expected
            to match the order of the synthesis inputs (prompt audio, target
            text, prompt enhancement, VAD, streaming input, speaking rate,
            speaking rate toggle).
        synthesize_fn: Callable taking the seven synthesis inputs and producing
            values for (output_audio, download_btn).
    """
    with gr.Blocks(css=CUSTOM_CSS, title="VoXtream2") as demo:
        gr.Markdown("# VoXtream2 TTS demo")
        gr.Markdown(
            "⚠️ The initial latency can be high due to deployment on ZeroGPU. For faster inference, please try local deployment. For more details, please visit [VoXtream GitHub repo](https://github.com/herimor/voxtream)"
        )

        with gr.Row(equal_height=True, elem_id="cols"):
            with gr.Column(scale=1, elem_id="left-col"):
                prompt_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label=f"Prompt audio (3-10 sec of target voice. Max {config.max_prompt_sec} sec)",
                )
                with gr.Accordion("Advanced options", open=False):
                    prompt_enhancement = gr.Checkbox(
                        label="Prompt enhancement", value=True
                    )
                    voice_activity_detection = gr.Checkbox(
                        label="Voice activity detection", value=True
                    )
                    streaming_input = gr.Checkbox(label="Streaming input", value=False)

            with gr.Column(scale=1, elem_id="right-col"):
                target_text = gr.Textbox(
                    lines=3,
                    max_length=config.max_phone_tokens,
                    label=f"Target text (Required, max {config.max_phone_tokens} chars)",
                    placeholder="What you want the model to say",
                )
                speaking_rate_control = gr.Slider(
                    minimum=1,
                    maximum=7,
                    step=0.1,
                    value=4,
                    label="Speaking rate (syllables per second)",
                )
                enable_speaking_rate = gr.Checkbox(
                    label="Use speaking rate control", value=True
                )
                # The slider is only editable while rate control is enabled.
                enable_speaking_rate.change(
                    fn=lambda enabled: gr.update(interactive=enabled),
                    inputs=enable_speaking_rate,
                    outputs=speaking_rate_control,
                )
                output_audio = gr.Audio(
                    label="Synthesized audio",
                    interactive=False,
                    streaming=True,
                    autoplay=True,
                    show_download_button=False,
                    show_share_button=False,
                    visible=False,
                )
                # appears only when file is ready
                download_btn = gr.DownloadButton(
                    "Download audio",
                    visible=False,
                )

        with gr.Row():
            clear_btn = gr.Button("Clear", elem_id="clear", variant="secondary")
            submit_btn = gr.Button(
                "Submit", elem_id="submit", variant="primary", interactive=False
            )

        # Message box for validation errors
        validation_msg = gr.Markdown("", visible=False)

        # --- Validation logic ---
        def validate_inputs(audio, ttext):
            """Return updates for (validation_msg, submit_btn) from the inputs."""
            if not audio:
                return gr.update(
                    visible=True, value="⚠️ Please provide a prompt audio."
                ), gr.update(interactive=False)
            # Treat a missing (None) text the same as an empty/blank one.
            if not ttext or not ttext.strip():
                return gr.update(
                    visible=True, value="⚠️ Please provide target text."
                ), gr.update(interactive=False)
            return gr.update(visible=False, value=""), gr.update(interactive=True)

        # Live validation whenever inputs change
        for inp in [prompt_audio, target_text]:
            inp.change(
                fn=validate_inputs,
                inputs=[prompt_audio, target_text],
                outputs=[validation_msg, submit_btn],
            )

        # clear outputs before streaming
        submit_btn.click(
            fn=lambda a, t: (
                gr.update(value=None, visible=True),
                gr.update(value=None, visible=False),
            ),
            inputs=[prompt_audio, target_text],
            outputs=[output_audio, download_btn],
            show_progress="hidden",
        ).then(
            fn=synthesize_fn,
            inputs=[
                prompt_audio,
                target_text,
                prompt_enhancement,
                voice_activity_detection,
                streaming_input,
                speaking_rate_control,
                enable_speaking_rate,
            ],
            outputs=[output_audio, download_btn],
        )

        clear_btn.click(
            fn=lambda: (
                gr.update(value=None),
                gr.update(value=""),
                gr.update(value=None, visible=False),  # output_audio
                gr.update(value=None, visible=False),  # download_btn
                gr.update(visible=False, value=""),  # validation_msg
                gr.update(interactive=False),  # submit_btn
            ),
            inputs=[],
            outputs=[
                prompt_audio,
                target_text,
                output_audio,
                download_btn,
                validation_msg,
                submit_btn,
            ],
        )

        # --- Add Examples ---
        gr.Markdown("### Examples")
        ex = gr.Examples(
            examples=demo_examples,
            inputs=[
                prompt_audio,
                target_text,
                prompt_enhancement,
                voice_activity_detection,
                streaming_input,
                speaking_rate_control,
                enable_speaking_rate,
            ],
            outputs=[output_audio, download_btn],
            fn=synthesize_fn,
            cache_examples=False,
        )

        # Selecting an example resets the outputs, then re-validates the inputs.
        ex.dataset.click(
            fn=_clear_outputs,
            inputs=[],
            outputs=[output_audio, download_btn],
            queue=False,
        ).then(
            fn=validate_inputs,
            inputs=[prompt_audio, target_text],
            outputs=[validation_msg, submit_btn],
            queue=False,
        )

    demo.launch()


def main():
    """Parse CLI arguments, build the speech generator and start the demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=existing_file,
        help="Path to the config file",
        default="configs/generator.json",
    )
    parser.add_argument(
        "--spk-rate-config",
        type=existing_file,
        help="Path to the speaking rate config file",
        default="configs/speaking_rate.json",
    )
    parser.add_argument(
        "--examples-config",
        type=existing_file,
        help="Path to the examples config file",
        default="assets/examples.json",
    )
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = SpeechGeneratorConfig(**json.load(f))
    config.hf_token = os.environ.get("TOKEN")

    # Loading speaker encoder. The returned object is not used here.
    # NOTE(review): presumably this pre-fetches the hub repo/weights so that
    # SpeechGenerator can load them later - confirm.
    torch.hub.load(
        config.spk_enc_repo,
        config.spk_enc_model,
        model_name=config.spk_enc_model_name,
        train_type=config.spk_enc_train_type,
        dataset=config.spk_enc_dataset,
        trust_repo=True,
        verbose=False,
    )

    with open(args.spk_rate_config, encoding="utf-8") as f:
        speaking_rate_config = json.load(f)

    with open(args.examples_config, encoding="utf-8") as f:
        examples_config = json.load(f)
    demo_examples = examples_config.get("examples", [])

    speech_generator = SpeechGenerator(config)
    # Number of samples to accumulate before a chunk is sent to the player.
    CHUNK_SIZE = int(config.mimi_sr * MIN_CHUNK_SEC)

    @spaces.GPU
    def synthesize_fn(
        prompt_audio_path,
        target_text,
        prompt_enhancement,
        voice_activity_detection,
        streaming_input,
        speaking_rate_control,
        enable_speaking_rate,
    ):
        """
        Stream synthesized speech to the UI and finally expose a download file.

        This is a generator: every yielded item is a pair of values for
        (output_audio, download_btn). Audio chunks are yielded as
        ``(sample_rate, int16_samples)`` while synthesis runs; the last item
        carries the update that reveals (or hides) the download button.
        """
        # This function is a generator, so a `return <value>` would be
        # discarded silently; the "nothing to do" result has to be yielded.
        if not prompt_audio_path or not target_text:
            yield None, gr.update(value=None, visible=False)
            return

        # Move the models to the GPU on first use inside the GPU context.
        if next(speech_generator.model.parameters()).device.type == "cpu":
            speech_generator.model.to("cuda")
            speech_generator.mimi.to("cuda")
            speech_generator.ctx.mimi_prompt.to("cuda")
            speech_generator.ctx.spk_enc.to("cuda")
            speech_generator.ctx.device = "cuda"
            speech_generator.ctx.dtype = DTYPE_MAP["cuda"]

        if enable_speaking_rate:
            duration_state, weight, cfg_gamma = interpolate_speaking_rate_params(
                speaking_rate_config, speaking_rate_control
            )
        else:
            duration_state, weight, cfg_gamma = None, None, None

        stream = speech_generator.generate_stream(
            prompt_audio_path=Path(prompt_audio_path),
            text=text_generator(target_text) if streaming_input else target_text,
            target_spk_rate_cnt=duration_state,
            spk_rate_weight=weight,
            cfg_gamma=cfg_gamma,
            enhance_prompt=prompt_enhancement,
            apply_vad=voice_activity_detection,
        )

        buffer = []  # frames not yet sent to the player
        buffer_len = 0  # number of samples currently held in `buffer`
        total_buffer = []  # every frame, kept for the downloadable file

        # assumes each frame is a float32 numpy array with samples on axis 0
        # (required by float32_to_int16 / frame.shape[0]) - TODO confirm
        for frame, _ in stream:
            buffer.append(frame)
            total_buffer.append(frame)
            buffer_len += frame.shape[0]

            if buffer_len >= CHUNK_SIZE:
                audio = np.concatenate(buffer)
                yield (config.mimi_sr, float32_to_int16(audio)), None

                # Reset buffer and length
                buffer = []
                buffer_len = 0

        # Handle any remaining audio in the buffer
        if buffer_len > 0:
            # np.concatenate returns a new array, so fading in place is safe.
            final = _apply_fade_out(np.concatenate(buffer), config.mimi_sr)
            yield (config.mimi_sr, float32_to_int16(final)), None

        # Save the full audio to a file for download
        if total_buffer:
            full_audio = _apply_fade_out(np.concatenate(total_buffer), config.mimi_sr)
            # Use the platform temp directory instead of a hard-coded "/tmp".
            file_path = os.path.join(
                tempfile.gettempdir(), f"voxtream_{uuid.uuid4().hex}.wav"
            )
            sf.write(file_path, float32_to_int16(full_audio), config.mimi_sr)
            yield None, gr.update(value=file_path, visible=True)
        else:
            yield None, gr.update(value=None, visible=False)

    demo_app(config, demo_examples, synthesize_fn)


if __name__ == "__main__":
    main()