Spaces:
Runtime error
Runtime error
| import random | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| import traceback | |
| # --- IMPORTANT: replace this import with the real Chatterbox package path --- | |
| # from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES | |
# For the rewritten example we assume the class exists and exposes the APIs used below.
| from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES | |
| # ----------------------------------------------------------------------------- | |
| # Configuration | |
| # ----------------------------------------------------------------------------- | |
# Runtime device selection: prefer CUDA when a GPU is visible.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# UI input limit and default sampling knobs.
MAX_TEXT_CHARS = 300
DEFAULT_EXAGGERATION = 0.5
DEFAULT_TEMPERATURE = 0.8
DEFAULT_CFG_WEIGHT = 0.5

# Global model placeholder; populated lazily on the first inference call.
_MODEL = None

# Per-language example text plus a default reference-audio URL, used to
# pre-fill the UI widgets. Extend this mapping with your own resources.
LANGUAGE_CONFIG = {
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel.",
    },
    "fr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
        "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube.",
    },
    # ... keep or extend languages as you like
}
| # ----------------------------------------------------------------------------- | |
| # Utilities | |
| # ----------------------------------------------------------------------------- | |
def get_default_audio(lang: str) -> str | None:
    """Return the default reference-audio URL for *lang*, or None if the
    language (or its audio entry) is not in LANGUAGE_CONFIG."""
    entry = LANGUAGE_CONFIG.get(lang)
    return entry.get("audio") if entry else None
def get_default_text(lang: str) -> str:
    """Return the example sentence for *lang*, or "" if none is configured."""
    entry = LANGUAGE_CONFIG.get(lang, {})
    return entry.get("text", "")
def set_seed(seed: int) -> None:
    """Seed the `random`, numpy and torch RNGs for reproducible generation."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Same condition as `DEVICE == "cuda"` (DEVICE is derived from this call
    # at import time): also seed the CUDA RNGs when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
| # ----------------------------------------------------------------------------- | |
| # Model loading and inference (lazy and safe) | |
| # ----------------------------------------------------------------------------- | |
def get_or_load_model(force_reload: bool = False):
    """Return the cached ChatterboxMultilingualTTS instance, loading on demand.

    Loading is deferred until the first inference call so module import stays
    cheap (keeps the Spaces startup health check happy).

    Args:
        force_reload: when True, discard any cached instance and load again.

    Raises:
        Exception: re-raises whatever `from_pretrained` raised, after logging.
    """
    global _MODEL
    if force_reload or _MODEL is None:
        try:
            print(f"Initializing ChatterboxMultilingualTTS on device: {DEVICE}")
            model = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
            # Best-effort device move: some wrappers manage their device
            # internally and may not support `.to()` — ignore in that case.
            if hasattr(model, "to"):
                try:
                    model.to(DEVICE)
                except Exception:
                    pass
            _MODEL = model
            print("Model loaded successfully.")
        except Exception:
            print("Failed to load model:")
            traceback.print_exc()
            _MODEL = None
            raise
    return _MODEL
def synthesize_audio(
    text: str,
    language_id: str,
    audio_prompt_path: str | None = None,
    exaggeration: float = DEFAULT_EXAGGERATION,
    temperature: float = DEFAULT_TEMPERATURE,
    seed_num: int = 0,
    cfg_weight: float = DEFAULT_CFG_WEIGHT,
) -> tuple[int, np.ndarray]:
    """Generate speech for *text* in *language_id*.

    Args:
        text: text to synthesize; trimmed and truncated to MAX_TEXT_CHARS.
        language_id: language code understood by the model (e.g. "en").
        audio_prompt_path: optional reference-audio path/URL; falls back to
            the per-language default from LANGUAGE_CONFIG.
        exaggeration, temperature, cfg_weight: sampling controls forwarded to
            `model.generate`.
        seed_num: non-zero seeds all RNGs for reproducibility; 0 = random.

    Returns:
        (sample_rate, waveform) with the waveform as a 1-D numpy array.

    Raises:
        ValueError: if the (trimmed) text is empty.
        RuntimeError: if the model failed to load.
    """
    # Validate input BEFORE touching the model so a bad request fails fast
    # without triggering an expensive (lazy) model load.
    text = (text or "").strip()[:MAX_TEXT_CHARS]
    if not text:
        raise ValueError("Empty text input. Please provide text to synthesize.")

    model = get_or_load_model()
    if model is None:
        raise RuntimeError("Model not loaded. Check server logs for errors.")

    # Seed 0 (the UI default) means "leave the RNGs random".
    if int(seed_num or 0) != 0:
        set_seed(int(seed_num))

    # Prefer the caller-supplied reference audio, else the language default.
    if audio_prompt_path and str(audio_prompt_path).strip():
        chosen_prompt = audio_prompt_path
    else:
        chosen_prompt = get_default_audio(language_id)

    generate_kwargs = {
        "exaggeration": float(exaggeration),
        "temperature": float(temperature),
        "cfg_weight": float(cfg_weight),
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt

    print(f"Generating audio for language={language_id}, text_len={len(text)}, kwargs={generate_kwargs}")
    wav = model.generate(text, language_id=language_id, **generate_kwargs)

    # Sample rate: the model exposes `.sr`; fall back to a common TTS default.
    sr = getattr(model, "sr", 22050)

    # Normalize the waveform to a numpy array for Gradio. Checking for
    # `squeeze` is NOT a reliable torch test (numpy arrays define it too, and
    # `.cpu()` would then raise); a torch tensor is identified by `detach`.
    if hasattr(wav, "detach"):  # torch.Tensor
        arr = wav.detach().cpu().numpy()
    else:
        arr = np.asarray(wav)
    if arr.ndim > 1:
        # Drop the leading batch dimension, matching the original squeeze(0).
        arr = arr.squeeze(0)
    return int(sr), arr
| # ----------------------------------------------------------------------------- | |
| # Hugging Face Spaces GPU-aware entry point | |
| # ----------------------------------------------------------------------------- | |
# Wrap the handler with `spaces.GPU` so Hugging Face Spaces actually allocates
# a GPU for the call — the surrounding comments describe this as the
# "GPU-decorated entry point", but the decorator was never applied (the
# `import spaces` at the top of the file was unused). Fall back to a no-op
# decorator when the `spaces` package is unavailable (e.g. local development).
try:
    import spaces as _spaces  # already imported at module top on Spaces
    _gpu_decorator = _spaces.GPU
except Exception:
    def _gpu_decorator(fn):
        return fn


@_gpu_decorator
def gpu_generate_tts(
    text_input,
    language_id,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    """GPU entry point for Spaces: synthesize speech for the UI inputs.

    Returns a (sample_rate, numpy_waveform) tuple on success. On failure the
    traceback is logged and a human-readable error string is returned instead.
    """
    # Loading the model happens inside synthesize_audio, so the Spaces
    # runtime sees GPU usage within the decorated call.
    try:
        sr, wav = synthesize_audio(
            text_input,
            language_id,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            seed_num=seed_num_input,
            cfg_weight=cfgw_input,
        )
        # Gradio's Audio component accepts a (sr, numpy_array) tuple.
        return (sr, wav)
    except Exception as e:
        print("Error during TTS generation:")
        traceback.print_exc()
        # Gradio will show this string as the output if an exception occurs.
        return f"Error during generation: {str(e)}"
| # ----------------------------------------------------------------------------- | |
| # Gradio UI | |
| # ----------------------------------------------------------------------------- | |
def get_supported_languages_display() -> str:
    """Build a two-row markdown summary of every supported language code."""
    try:
        entries = [
            f"**{name}** (`{code}`)"
            for code, name in sorted(SUPPORTED_LANGUAGES.items())
        ]
    except Exception:
        # SUPPORTED_LANGUAGES unavailable: fall back to a short static list.
        entries = ["**English** (`en`)", "**French** (`fr`)"]
    half = len(entries) // 2 if entries else 1
    first_row = " • ".join(entries[:half])
    second_row = " • ".join(entries[half:])
    return f"""
### 🌍 Supported Languages ({len(entries)} total)
{first_row}
{second_row}
"""
# Build the Gradio UI. `demo` is the top-level Blocks app that Spaces serves.
with gr.Blocks(title="Chatterbox Multilingual TTS") as demo:
    gr.Markdown("""
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with optional reference audio styling.
""")
    # Dynamic language list (falls back to en/fr if SUPPORTED_LANGUAGES fails).
    gr.Markdown(get_supported_languages_display())
    with gr.Row():
        with gr.Column(scale=6):
            # Left column: all synthesis inputs, pre-filled for English.
            initial_lang = "en"
            text = gr.Textbox(
                value=get_default_text(initial_lang),
                label="Text to synthesize (max chars {})".format(MAX_TEXT_CHARS),
                max_lines=6,
            )
            language_id = gr.Dropdown(
                choices=list(SUPPORTED_LANGUAGES.keys()),
                value=initial_lang,
                label="Language",
                info="Select the language code for synthesis",
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",  # the handler receives a filesystem path
                label="Reference Audio File (Optional)",
                value=get_default_audio(initial_lang),
            )
            gr.Markdown(
                "💡 **Note**: If you provide a reference audio, make sure it matches the language tag. "
                "For language transfer behavior set CFG weight to 0."
            )
            exaggeration = gr.Slider(0.25, 2.0, value=DEFAULT_EXAGGERATION, step=0.05, label="Exaggeration")
            cfg_weight = gr.Slider(0.0, 1.0, value=DEFAULT_CFG_WEIGHT, step=0.05, label="CFG / Pace Weight")
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5.0, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=4):
            # Right column: outputs. NOTE(review): `status` is never wired to
            # any handler below — presumably meant for error text; confirm.
            audio_output = gr.Audio(label="Output Audio")
            status = gr.Textbox(label="Status / Logs", interactive=False)

    # On language change, reset the reference audio and example text to that
    # language's defaults (the current widget values are received but unused).
    def on_language_change(lang, current_ref, current_text):
        return get_default_audio(lang), get_default_text(lang)

    language_id.change(fn=on_language_change, inputs=[language_id, ref_wav, text], outputs=[ref_wav, text])
    # Hook the GPU-decorated function into Gradio. This is the critical change to make
    # Hugging Face Spaces detect GPU usage at startup.
    run_btn.click(
        fn=gpu_generate_tts,
        inputs=[
            text,
            language_id,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )
| # Expose demo for Spaces. In a Spaces environment, `demo.launch()` is optional but | |
| # harmless; Spaces runs the app for you. Keep a minimal launch for local testing. | |
if __name__ == "__main__":
    # Local-testing entry point. On Spaces the runtime serves `demo` itself,
    # but the guard still needs a real statement: a comment-only body is a
    # SyntaxError that crashes the whole app at import time ("Runtime error").
    demo.launch(server_name="0.0.0.0", server_port=7860)