# oicui's picture
# Update app.py
# 96c42e1 verified
import random
import re
import numpy as np
import torch
import torchaudio
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
import spaces
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"🚀 Running on device: {DEVICE}")

# Module-level slot for the TTS model; it starts out empty.
MODEL = None
# Per-language UI defaults, keyed by language code. Each entry carries the URL
# of a reference audio prompt ("audio") and a sample sentence ("text").
LANGUAGE_CONFIG = {
    "ar": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
        "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب.",
    },
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel.",
    },
    "fr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
        "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube.",
    },
    "hi": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
        "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।",
    },
    "tr": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
        "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık.",
    },
    "zh": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
        "text": "上个月,我们达到了一个新的里程碑。 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。",
    },
}
def default_audio_for_ui(lang: str) -> str | None:
    """Return the default reference-audio URL configured for ``lang``.

    Returns ``None`` when ``lang`` has no entry in ``LANGUAGE_CONFIG`` or the
    entry has no ``"audio"`` key.
    """
    if lang not in LANGUAGE_CONFIG:
        return None
    return LANGUAGE_CONFIG[lang].get("audio")
def default_text_for_ui(lang: str) -> str:
    """Return the sample sentence configured for ``lang``.

    Returns an empty string when ``lang`` has no entry in ``LANGUAGE_CONFIG``
    or the entry has no ``"text"`` key.
    """
    try:
        entry = LANGUAGE_CONFIG[lang]
    except KeyError:
        return ""
    return entry.get("text", "")
def get_supported_languages_display() -> str:
    """Build the Markdown summary of supported languages shown in the UI.

    Languages are listed in order of their code as ``**Name** (`code`)``,
    split into two bullet-separated rows of roughly equal length under a
    heading that states the total count.
    """
    entries = []
    for code, name in sorted(SUPPORTED_LANGUAGES.items()):
        entries.append(f"**{name}** (`{code}`)")

    half = len(entries) // 2
    first_row = " • ".join(entries[:half])
    second_row = " • ".join(entries[half:])
    heading = f"### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)\n"
    return f"{heading}{first_row}\n\n{second_row}"
def get_or_load_model():
    """Return the shared TTS model, loading it on first use.

    The model is created with ``ChatterboxMultilingualTTS.from_pretrained``
    for ``DEVICE`` and stored in the module-level ``MODEL`` so later calls
    reuse it.
    """
    global MODEL

    # Already loaded: hand back the cached instance.
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
    # Only objects exposing ``.to`` are moved explicitly to the target device.
    if hasattr(MODEL, "to"):
        MODEL.to(DEVICE)
    print(f"✅ Model loaded successfully on {DEVICE}")
    return MODEL
# Attempt to load the model once at import time. A failure is reported on
# stdout but deliberately not re-raised, so the rest of the module still runs.
try:
    get_or_load_model()
except Exception as load_error:
    print(f"CRITICAL: Failed to load model. Error: {load_error}")
def set_seed(seed: int) -> None:
    """Seed every random number generator used by the app.

    Seeds PyTorch (CPU and, when CUDA is available, the CUDA generators),
    Python's ``random`` module and NumPy's global generator.

    Args:
        seed: Desired seed; coerced with ``int()``. NumPy only accepts seeds
            in ``[0, 2**32 - 1]``, so NumPy receives ``seed % 2**32`` while
            the other generators receive the seed unchanged. Without this, a
            negative or larger-than-32-bit seed raised ``ValueError`` after
            the other generators had already been seeded.
    """
    seed = int(seed)
    torch.manual_seed(seed)
    # Ask PyTorch directly instead of relying on a module-level flag so the
    # function behaves the same wherever it is called from.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed % 2**32)
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
if provided_path and str(provided_path).strip():
return provided_path
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
# ============================
# SMART CHUNKING (OPTIMIZED)
# ============================
def _pack_units(units: list[str], max_chars: int) -> list[str]:
    """Greedily join consecutive ``units`` with single spaces into chunks.

    A unit is added to the pending chunk while the joined length plus one
    stays within ``max_chars``; otherwise the pending chunk is emitted and a
    new one is started. A unit is never split here.
    """
    packed: list[str] = []
    pending: list[str] = []
    pending_len = 0  # len(" ".join(pending)) + 1, i.e. including one trailing space
    for unit in units:
        if pending and pending_len + len(unit) + 1 > max_chars:
            packed.append(" ".join(pending))
            pending = []
            pending_len = 0
        pending.append(unit)
        pending_len += len(unit) + 1
    if pending:
        packed.append(" ".join(pending))
    return packed


def smart_chunk_text(text: str, max_chars: int = 500) -> list[str]:
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    Strategy:
    - Whitespace is normalized to single spaces first.
    - Text is split into sentences at whitespace following one of
      ``. ! ? … ؟ ، : ؛ ।``.
    - Consecutive short sentences are merged into one chunk to reduce the
      number of model calls.
    - A sentence longer than ``max_chars`` is split on word boundaries; a
      single whitespace-free token longer than ``max_chars`` is hard-sliced
      so that no chunk exceeds the limit.

    Chunks are always returned in the order the text appears in the input.

    Args:
        text: Text to split.
        max_chars: Maximum length of each returned chunk; must be >= 1.

    Returns:
        The list of non-empty chunks (empty for blank input).

    Raises:
        ValueError: If ``max_chars`` is smaller than 1.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be a positive integer")

    # Normalize whitespace.
    text = re.sub(r"\s+", " ", text.strip())
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]

    # Multilingual sentence terminators: . ! ? … ؟ ، : ؛ ।
    # NOTE: a terminator only counts when whitespace follows it, so scripts
    # written without spaces end up in the hard-slice fallback below.
    sentences = re.split(r'(?<=[\.!\?…؟،:؛।])\s+', text)

    chunks: list[str] = []
    short_run: list[str] = []  # consecutive sentences that each fit in a chunk
    for sent in sentences:
        sent = sent.strip()
        if not sent:
            continue
        if len(sent) <= max_chars:
            short_run.append(sent)
            continue

        # Emit the buffered short sentences BEFORE the long one so the output
        # keeps the original reading order.
        chunks.extend(_pack_units(short_run, max_chars))
        short_run = []

        # The sentence itself is too long: split softly on words, slicing any
        # single token that is longer than the limit.
        pieces: list[str] = []
        for word in sent.split():
            pieces.extend(
                word[start:start + max_chars]
                for start in range(0, len(word), max_chars)
            )
        chunks.extend(_pack_units(pieces, max_chars))

    chunks.extend(_pack_units(short_run, max_chars))
    return chunks
def concat_audio_torch(chunks: list[torch.Tensor],
                       crossfade_ms: int = 10,
                       sr: int = 24000) -> torch.Tensor:
    """Join audio tensors along the last axis with a short linear crossfade.

    Adjacent chunks overlap by up to ``crossfade_ms`` milliseconds (converted
    to samples with ``sr``); the overlap blends a linear fade-out of the audio
    so far with a linear fade-in of the next chunk to avoid audible clicks.

    Returns an empty tensor for an empty list, and a plain concatenation when
    there is a single chunk or ``crossfade_ms`` is not positive.
    """
    if not chunks:
        return torch.empty(0)
    if len(chunks) == 1 or crossfade_ms <= 0:
        return torch.cat(chunks, dim=-1)

    max_overlap = int(crossfade_ms * sr / 1000)
    merged = chunks[0]
    for nxt in chunks[1:]:
        # The overlap can never be longer than either side of the join.
        overlap = min(max_overlap, merged.shape[-1], nxt.shape[-1])
        if overlap > 0:
            ramp_down = torch.linspace(
                1.0, 0.0, steps=overlap, device=merged.device, dtype=merged.dtype
            )
            ramp_up = torch.linspace(
                0.0, 1.0, steps=overlap, device=nxt.device, dtype=nxt.dtype
            )
            blended = merged[..., -overlap:] * ramp_down + nxt[..., :overlap] * ramp_up
            pieces = [merged[..., :-overlap], blended, nxt[..., overlap:]]
        else:
            pieces = [merged, nxt]
        merged = torch.cat(pieces, dim=-1)
    return merged
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: str | None = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
):
    """Synthesize speech for ``text_input`` and return it with the seed used.

    The text is split with ``smart_chunk_text``, each chunk is rendered with
    the model's ``generate`` method, and the pieces are joined with a short
    crossfade.

    Args:
        text_input: Text to synthesize.
        language_id: Language code passed to the model.
        audio_prompt_path_input: Optional reference audio path; when blank,
            the default prompt for ``language_id`` is used if one exists.
        exaggeration_input: Forwarded to the model as ``exaggeration``.
        temperature_input: Forwarded to the model as ``temperature``.
        seed_num_input: Seed; ``0`` (or an empty field) picks a random one.
            Non-integer values are truncated with ``int()``.
        cfgw_input: Forwarded to the model as ``cfg_weight``.

    Returns:
        ``((sr, waveform), seed)`` where ``sr`` is the model's ``sr``
        attribute, ``waveform`` is a NumPy array and ``seed`` is the seed
        actually used, as a string.

    Raises:
        gr.Error: If the text is empty or only whitespace.
        RuntimeError: If no model is available.
    """
    # Reject blank input up front; otherwise an empty waveform would be
    # returned and fail later with an unhelpful message.
    chunks = smart_chunk_text(text_input or "", max_chars=500)
    if not chunks:
        raise gr.Error("Please enter some text to synthesize.")

    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model not loaded.")

    # --- SEED LOGIC ---
    # The UI number field may deliver a float or None; normalise once so the
    # seed that is applied, logged and returned is the same integer.
    seed = int(seed_num_input or 0)
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)
        print(f"🌱 Random seed generated: {seed}")
    else:
        print(f"🌱 Using provided seed: {seed}")
    set_seed(seed)

    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    # A blank path falls back to the language default, if any.
    chosen_prompt = resolve_audio_prompt(language_id, audio_prompt_path_input)
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt

    print(f"📚 Total chunks: {len(chunks)}")
    all_audio: list[torch.Tensor] = []
    for idx, chunk in enumerate(chunks, start=1):
        print(f"🎧 Rendering chunk {idx}/{len(chunks)} (len={len(chunk)} chars)")
        wav = current_model.generate(chunk, language_id=language_id, **generate_kwargs)
        # Drop the leading size-1 axis (presumably batch/channel — confirm
        # against the model) and move to CPU before joining.
        all_audio.append(wav.squeeze(0).cpu())

    # Join the rendered chunks with a light crossfade.
    final_audio = concat_audio_torch(
        all_audio,
        crossfade_ms=12,
        sr=current_model.sr,
    )

    # Return audio + seed.
    return (current_model.sr, final_audio.numpy()), str(seed)
# ============================
# GRADIO UI
# ============================
with gr.Blocks() as demo:
    # Page header followed by the list of supported languages.
    gr.Markdown(
        "\n"
        "# 🎙️ Multi Language Realistic Voice Cloner\n"
        "Generate long-form multilingual speech with reference audio styling and smart chunking (crossfaded).\n"
    )
    gr.Markdown(get_supported_languages_display())

    with gr.Row():
        # INPUT COLUMN: text, language, reference audio and sampling controls.
        with gr.Column():
            initial_lang = "en"
            text = gr.Textbox(
                label="Text to synthesize",
                lines=8,
                value=default_text_for_ui(initial_lang),
            )
            language_id = gr.Dropdown(
                label="Language",
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
            )
            ref_wav = gr.Audio(
                label="Reference Audio (Optional)",
                sources=["upload", "microphone"],
                type="filepath",
                value=default_audio_for_ui(initial_lang),
            )
            exaggeration = gr.Slider(
                minimum=0.25, maximum=2, value=0.5, step=0.05, label="Exaggeration"
            )
            cfg_weight = gr.Slider(
                minimum=0.2, maximum=1, value=0.5, step=0.05, label="CFG Weight"
            )
            with gr.Accordion("Advanced", open=False):
                seed_num = gr.Number(value=0, label="Random Seed (0=random)")
                temp = gr.Slider(
                    minimum=0.05, maximum=5, value=0.8, step=0.05, label="Temperature"
                )
            run_btn = gr.Button("Generate", variant="primary")

        # OUTPUT COLUMN
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
            seed_output = gr.Textbox(label="Seed Used", interactive=False)

    def on_lang_change(lang, current_ref, current_text):
        """Return the default reference audio and sample text for ``lang``.

        ``current_ref`` and ``current_text`` are received from the UI but are
        not used; both fields are replaced with the language defaults.
        """
        new_ref = default_audio_for_ui(lang)
        new_text = default_text_for_ui(lang)
        return new_ref, new_text

    # Refresh the reference audio and sample text when the language changes.
    language_id.change(
        fn=on_lang_change,
        inputs=[language_id, ref_wav, text],
        outputs=[ref_wav, text],
        show_progress=False,
    )

    # CONNECT BUTTON: the input order must match generate_tts_audio's parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output, seed_output],
    )

demo.launch(mcp_server=True, share=True)