Spaces:
Runtime error
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from zonos.model import Zonos | |
| from zonos.conditioning import make_cond_dict | |
# Global cache holding the loaded Zonos model (populated lazily by load_model()).
MODEL = None

# Inference device. Fall back to CPU when no GPU is visible so the app can
# still start on CPU-only hosts (a bare "cuda" crashes there at load time).
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
    """
    Load the Zonos model once and cache it in the module-level ``MODEL``.

    Returns:
        The cached Zonos model on ``device``, with gradients disabled and
        in eval mode.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` instead of a hard-coded "cuda" so the
        # rest of the file (which already reads `device`) stays consistent and
        # the app also runs on CPU-only hosts.
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        if device == "cuda":
            # Optional memory saving; only cast when a GPU is present
            # (bf16 on CPU is very slow).
            MODEL.bfloat16()
        print("Model loaded successfully!")
    return MODEL
def tts(text, speaker_audio):
    """
    Synthesize speech for `text`, cloning the voice from `speaker_audio`.

    Args:
        text: Prompt string to synthesize.
        speaker_audio: ``(sample_rate, numpy_array)`` pair as delivered by a
            Gradio Audio component with ``type="numpy"``. The array may be
            integer PCM (e.g. int16) or float, mono ``(n,)`` or
            multi-channel ``(n, channels)``.

    Returns:
        ``(sample_rate, waveform)`` suitable for a Gradio audio output, or
        ``None`` when either input is missing.
    """
    # Guard clauses: nothing to do without both a prompt and reference audio.
    if not text:
        return None
    if speaker_audio is None:
        return None

    model = load_model()
    sr, wav_np = speaker_audio
    wav_tensor = _reference_to_tensor(wav_np)

    # Speaker embedding from the reference clip.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Conditioning for generation.
    cond_dict = make_cond_dict(
        text=text,                 # The text prompt
        speaker=spk_embedding,     # Speaker embedding from reference audio
        language="en-us",          # Hard-coded language; switch if needed
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes, then decode them into raw audio.
    with torch.no_grad():
        # torch.manual_seed(1234)  # uncomment for reproducible sampling
        codes = model.generate(conditioning)

    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())


def _reference_to_tensor(wav_np):
    """
    Convert a Gradio numpy waveform to a float ``(channels, samples)`` tensor.

    Fixes two defects in the original inline conversion:
      * integer PCM (e.g. int16) was cast with ``.float()`` but never scaled,
        feeding values in [-32768, 32767] where [-1, 1] is expected;
      * multi-channel input ``(n, channels)`` became a 3-D tensor after
        ``unsqueeze(0)``, so the transpose heuristic never ran.
    """
    wav_tensor = torch.from_numpy(wav_np)
    if torch.is_floating_point(wav_tensor):
        wav_tensor = wav_tensor.float()
    else:
        # Normalize integer PCM to [-1, 1] using the dtype's full-scale value.
        scale = float(torch.iinfo(wav_tensor.dtype).max)
        wav_tensor = wav_tensor.float() / scale
    if wav_tensor.dim() == 1:
        wav_tensor = wav_tensor.unsqueeze(0)      # (n,) -> (1, n)
    elif wav_tensor.shape[0] > wav_tensor.shape[1]:
        # Gradio delivers (samples, channels); the model side expects the
        # torchaudio convention (channels, samples).
        wav_tensor = wav_tensor.T
    return wav_tensor
def build_demo():
    """Assemble and return the Gradio Blocks UI for the Zonos TTS demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")

        # Inputs side by side: text prompt and voice-cloning reference clip.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            reference_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        run_button = gr.Button("Generate")

        # Playable widget for the synthesized result.
        result_audio = gr.Audio(label="Synthesized Output", type="numpy")

        # Wire the button to the synthesis function.
        run_button.click(
            fn=tts,
            inputs=[prompt_box, reference_clip],
            outputs=result_audio,
        )
    return demo
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional HF Spaces port)
    # and request a public share link.
    app = build_demo()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)