# app.py — Chatterbox Multilingual TTS Gradio demo (revision dfa8c4e)
import random
import numpy as np
import torch
import gradio as gr
import spaces
import traceback
# --- IMPORTANT: replace this import with the real Chatterbox package path ---
# from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
# For the rewritten example we assume the class exists and exposes the APIs used below.
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Compute device for inference: CUDA when available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hard upper bound on input characters; synthesize_audio truncates longer text.
MAX_TEXT_CHARS = 300
# Default generation parameters; the UI sliders start at these values.
DEFAULT_EXAGGERATION = 0.5
DEFAULT_TEMPERATURE = 0.8
DEFAULT_CFG_WEIGHT = 0.5
# Global model placeholder (lazy-loaded on the first inference call).
_MODEL = None
# Per-language example text and default reference-audio URL used to populate
# the UI. Extend or replace with your own resources.
LANGUAGE_CONFIG = {
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
    },
    "fr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
        "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
    },
    # ... keep or extend languages as you like
}
# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
def get_default_audio(lang: str) -> str | None:
    """Return the bundled reference-audio URL for *lang*, or None if absent."""
    config = LANGUAGE_CONFIG.get(lang)
    if config is None:
        return None
    return config.get("audio")
def get_default_text(lang: str) -> str:
    """Return the example sentence for *lang*; empty string when unknown."""
    config = LANGUAGE_CONFIG.get(lang)
    if config is None:
        return ""
    return config.get("text", "")
def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducible output."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if DEVICE == "cuda":
        # Seed the active CUDA device plus any additional GPUs.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
# -----------------------------------------------------------------------------
# Model loading and inference (lazy and safe)
# -----------------------------------------------------------------------------
def get_or_load_model(force_reload: bool = False):
    """Return the cached ChatterboxMultilingualTTS model, loading on first use.

    Lazy loading keeps module import cheap, which matters on Hugging Face
    Spaces: the model should only be materialized inside a GPU-decorated call.

    Args:
        force_reload: When True, discard any cached instance and load again.

    Returns:
        The loaded model instance (also cached in the module-global _MODEL).

    Raises:
        Exception: Re-raises whatever the underlying loader raises; the
            traceback is printed first so it shows up in the server logs.
    """
    global _MODEL
    if _MODEL is not None and not force_reload:
        return _MODEL
    try:
        print(f"Initializing ChatterboxMultilingualTTS on device: {DEVICE}")
        _MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        # Some model wrappers manage device placement internally; only call
        # .to() when it exists, and tolerate it failing.
        if hasattr(_MODEL, "to"):
            try:
                _MODEL.to(DEVICE)
            except Exception:
                pass
        print("Model loaded successfully.")
        return _MODEL
    except Exception:
        # (was `except Exception as e` with `e` unused)
        print("Failed to load model:")
        traceback.print_exc()
        _MODEL = None
        raise
def synthesize_audio(
    text: str,
    language_id: str,
    audio_prompt_path: str | None = None,
    exaggeration: float = DEFAULT_EXAGGERATION,
    temperature: float = DEFAULT_TEMPERATURE,
    seed_num: int = 0,
    cfg_weight: float = DEFAULT_CFG_WEIGHT,
) -> tuple[int, np.ndarray]:
    """Generate speech for *text* using the lazily-loaded multilingual model.

    Args:
        text: Input text; stripped and truncated to MAX_TEXT_CHARS.
        language_id: Language code understood by the model (e.g. "en").
        audio_prompt_path: Optional reference-audio path; falls back to the
            language's bundled sample when empty.
        exaggeration: Expressiveness control forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        seed_num: When non-zero, seeds all RNGs for reproducible output.
        cfg_weight: Classifier-free-guidance / pacing weight.

    Returns:
        (sample_rate, waveform) where waveform is a 1-d numpy array.

    Raises:
        RuntimeError: If the model failed to load.
        ValueError: If the text is empty after stripping.
    """
    model = get_or_load_model()
    if model is None:
        raise RuntimeError("Model not loaded. Check server logs for errors.")
    if seed_num and int(seed_num) != 0:
        set_seed(int(seed_num))
    text = (text or "").strip()[:MAX_TEXT_CHARS]
    if not text:
        raise ValueError("Empty text input. Please provide text to synthesize.")
    # Prefer a caller-supplied prompt; otherwise fall back to the bundled
    # per-language sample (may be None, in which case no prompt is sent).
    chosen_prompt = (
        audio_prompt_path
        if audio_prompt_path and str(audio_prompt_path).strip()
        else get_default_audio(language_id)
    )
    generate_kwargs = {
        "exaggeration": float(exaggeration),
        "temperature": float(temperature),
        "cfg_weight": float(cfg_weight),
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt
    print(f"Generating audio for language={language_id}, text_len={len(text)}, kwargs={generate_kwargs}")
    wav = model.generate(text, language_id=language_id, **generate_kwargs)
    # The model may or may not expose its sample rate; 22050 is the fallback.
    sr = getattr(model, "sr", 22050)
    # Normalize the output to a 1-d numpy array for Gradio. The previous
    # hasattr-based check called .cpu() on anything with .squeeze(), which
    # raised for plain numpy arrays and then silently skipped the squeeze in
    # the fallback path (returning e.g. a (1, n) array).
    if torch.is_tensor(wav):
        arr = wav.detach().cpu().numpy()
    else:
        arr = np.asarray(wav)
    arr = np.squeeze(arr)
    return int(sr), arr
# -----------------------------------------------------------------------------
# Hugging Face Spaces GPU-aware entry point
# -----------------------------------------------------------------------------
@spaces.GPU
def gpu_generate_tts(
    text_input,
    language_id,
    audio_prompt_path_input,
    exaggeration_input,
    temperature_input,
    seed_num_input,
    cfgw_input,
):
    """GPU-decorated entry point required by Hugging Face Spaces.

    Loading and running the model inside this function ensures the Spaces
    runtime attributes GPU usage to it. Returns a (sample_rate, waveform)
    tuple suitable for a gr.Audio output.
    """
    try:
        sr, wav = synthesize_audio(
            text_input,
            language_id,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            seed_num=seed_num_input,
            cfg_weight=cfgw_input,
        )
        # Gradio Audio accepts either a path or a (sr, numpy_array) tuple.
        return (sr, wav)
    except Exception as e:
        print("Error during TTS generation:")
        traceback.print_exc()
        # A gr.Audio output cannot render a plain string (the old code
        # returned one, which itself errored in the UI); raising gr.Error is
        # Gradio's supported way to surface the message to the user.
        raise gr.Error(f"Error during generation: {str(e)}") from e
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
def get_supported_languages_display() -> str:
    """Build a two-line markdown summary of all supported language codes."""
    try:
        entries = [
            f"**{name}** (`{code}`)"
            for code, name in sorted(SUPPORTED_LANGUAGES.items())
        ]
    except Exception:
        # SUPPORTED_LANGUAGES may be unavailable — degrade to a short list.
        entries = ["**English** (`en`)", "**French** (`fr`)"]
    half = len(entries) // 2 if entries else 1
    top_row = " • ".join(entries[:half])
    bottom_row = " • ".join(entries[half:])
    return f"""
### 🌍 Supported Languages ({len(entries)} total)
{top_row}
{bottom_row}
"""
with gr.Blocks(title="Chatterbox Multilingual TTS") as demo:
    gr.Markdown("""
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with optional reference audio styling.
""")
    gr.Markdown(get_supported_languages_display())
    with gr.Row():
        with gr.Column(scale=6):
            initial_lang = "en"
            text = gr.Textbox(
                value=get_default_text(initial_lang),
                label="Text to synthesize (max chars {})".format(MAX_TEXT_CHARS),
                max_lines=6,
            )
            language_id = gr.Dropdown(
                choices=list(SUPPORTED_LANGUAGES.keys()),
                value=initial_lang,
                label="Language",
                info="Select the language code for synthesis",
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value=get_default_audio(initial_lang),
            )
            gr.Markdown(
                "💡 **Note**: If you provide a reference audio, make sure it matches the language tag. "
                "For language transfer behavior set CFG weight to 0."
            )
            exaggeration = gr.Slider(0.25, 2.0, value=DEFAULT_EXAGGERATION, step=0.05, label="Exaggeration")
            cfg_weight = gr.Slider(0.0, 1.0, value=DEFAULT_CFG_WEIGHT, step=0.05, label="CFG / Pace Weight")
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5.0, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=4):
            audio_output = gr.Audio(label="Output Audio")
            # NOTE(review): this status box is never written to by any event
            # handler in this file — wire it up or remove it.
            status = gr.Textbox(label="Status / Logs", interactive=False)

    def on_language_change(lang):
        """Swap in the selected language's example prompt audio and text."""
        return get_default_audio(lang), get_default_text(lang)

    # Only the language code is needed; previously the handler also received
    # (and ignored) the current reference audio and text.
    language_id.change(fn=on_language_change, inputs=[language_id], outputs=[ref_wav, text])

    # Hooking the GPU-decorated function directly into the click handler lets
    # Hugging Face Spaces detect GPU usage at startup.
    run_btn.click(
        fn=gpu_generate_tts,
        inputs=[
            text,
            language_id,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )
# Expose demo for Spaces. In a Spaces environment, `demo.launch()` is optional but
# harmless; Spaces runs the app for you. Keep a minimal launch for local testing.
if __name__ == "__main__":
    # Spaces runs the app itself; launching here only matters for local runs.
    # The original guard contained only a commented-out statement, which is a
    # SyntaxError in Python — restore the launch call it documented.
    demo.launch(server_name="0.0.0.0", server_port=7860)