Spaces:
Build error
Build error
Commit
·
71b9145
1
Parent(s):
281d8d1
- app.py +101 -28
- chatterbox_dhivehi.py +1 -1
app.py
CHANGED
|
@@ -59,16 +59,18 @@ def download_model():
|
|
| 59 |
print(f"Warning: Could not download model files: {e}")
|
| 60 |
print("=" * 60)
|
| 61 |
|
| 62 |
-
def load_model(checkpoint=
|
| 63 |
"""Load the TTS model"""
|
| 64 |
global MODEL
|
| 65 |
try:
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
MODEL = ChatterboxTTS.from_dhivehi(
|
| 68 |
-
ckpt_dir=Path(
|
| 69 |
-
device=
|
| 70 |
)
|
| 71 |
-
print("Model loaded successfully!")
|
| 72 |
except Exception as e:
|
| 73 |
print(f"Error loading model: {e}")
|
| 74 |
raise e
|
|
@@ -82,14 +84,14 @@ def set_seed(seed: int):
|
|
| 82 |
random.seed(seed)
|
| 83 |
np.random.seed(seed)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
def
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
"""
|
| 93 |
global MODEL
|
| 94 |
|
| 95 |
# Clean the input text
|
|
@@ -161,6 +163,25 @@ def generate_speech(text,
|
|
| 161 |
print(error_msg)
|
| 162 |
return None, error_msg
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def clean_text(text):
|
| 165 |
"""Clean text by removing newlines at start/end, double spaces, and extra whitespace"""
|
| 166 |
import re
|
|
@@ -224,14 +245,15 @@ def split_sentences(text):
|
|
| 224 |
|
| 225 |
return final_sentences
|
| 226 |
|
| 227 |
-
|
| 228 |
-
def
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
| 235 |
global MODEL
|
| 236 |
|
| 237 |
# Clean the input text
|
|
@@ -251,7 +273,7 @@ def generate_speech_multi_sentence(text,
|
|
| 251 |
# If only one sentence or no periods, use regular method
|
| 252 |
if len(sentences) <= 1:
|
| 253 |
yield None, "Generating single sentence..."
|
| 254 |
-
result_audio, result_status = generate_speech(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
|
| 255 |
yield result_audio, result_status
|
| 256 |
return
|
| 257 |
|
|
@@ -360,12 +382,32 @@ def generate_speech_multi_sentence(text,
|
|
| 360 |
print(error_msg)
|
| 361 |
yield None, error_msg
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
def create_interface():
|
| 364 |
"""Create the Gradio interface"""
|
| 365 |
|
| 366 |
-
# Load the model
|
| 367 |
-
load_model()
|
| 368 |
-
|
| 369 |
# Sample texts in Dhivehi
|
| 370 |
sample_texts = [
|
| 371 |
"ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
|
|
@@ -456,6 +498,21 @@ The ministry handed over the land reclamation, replacement of the port canal and
|
|
| 456 |
label="Seed",
|
| 457 |
info="For reproducible results"
|
| 458 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
# Row 4: Generate button
|
| 461 |
generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
|
|
@@ -473,6 +530,15 @@ The ministry handed over the land reclamation, replacement of the port canal and
|
|
| 473 |
def set_reference_audio(audio_file):
|
| 474 |
return audio_file
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
|
| 477 |
sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
|
| 478 |
sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
|
|
@@ -483,17 +549,24 @@ The ministry handed over the land reclamation, replacement of the port canal and
|
|
| 483 |
ref_btn3.click(lambda: set_reference_audio("m1.wav"), outputs=[reference_audio])
|
| 484 |
ref_btn4.click(lambda: set_reference_audio("m2.wav"), outputs=[reference_audio])
|
| 485 |
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
"""Generate speech with streaming progress updates"""
|
|
|
|
| 488 |
# Use the streaming generator
|
| 489 |
for result_audio, result_status in generate_speech_multi_sentence(
|
| 490 |
-
text, reference_audio, exaggeration, temperature, cfg_weight, seed
|
| 491 |
):
|
| 492 |
yield result_audio, result_status
|
| 493 |
|
| 494 |
generate_btn.click(
|
| 495 |
fn=generate_with_progress,
|
| 496 |
-
inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed],
|
| 497 |
outputs=[output_audio, status_message]
|
| 498 |
)
|
| 499 |
|
|
|
|
| 59 |
print(f"Warning: Could not download model files: {e}")
|
| 60 |
print("=" * 60)
|
| 61 |
|
| 62 |
+
def load_model(checkpoint="kn_cbox", device="cuda"):
|
| 63 |
"""Load the TTS model"""
|
| 64 |
global MODEL
|
| 65 |
try:
|
| 66 |
+
checkpoint_path = f"{_target}/{checkpoint}"
|
| 67 |
+
print(f"Loading model with checkpoint: {checkpoint_path}")
|
| 68 |
+
print(f"Target device: {device}")
|
| 69 |
MODEL = ChatterboxTTS.from_dhivehi(
|
| 70 |
+
ckpt_dir=Path(checkpoint_path),
|
| 71 |
+
device=device
|
| 72 |
)
|
| 73 |
+
print(f"Model loaded successfully on {device}!")
|
| 74 |
except Exception as e:
|
| 75 |
print(f"Error loading model: {e}")
|
| 76 |
raise e
|
|
|
|
| 84 |
random.seed(seed)
|
| 85 |
np.random.seed(seed)
|
| 86 |
|
| 87 |
+
# Internal implementation without decorator
|
| 88 |
+
def _generate_speech_impl(text,
|
| 89 |
+
reference_audio,
|
| 90 |
+
exaggeration=0.5,
|
| 91 |
+
temperature=0.1,
|
| 92 |
+
cfg_weight=0.5,
|
| 93 |
+
seed=42):
|
| 94 |
+
"""Internal implementation of generate speech"""
|
| 95 |
global MODEL
|
| 96 |
|
| 97 |
# Clean the input text
|
|
|
|
| 163 |
print(error_msg)
|
| 164 |
return None, error_msg
|
| 165 |
|
| 166 |
+
# GPU version with decorator
|
| 167 |
+
@spaces.GPU(duration=60)
|
| 168 |
+
def _generate_speech_gpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
|
| 169 |
+
"""GPU version of generate speech"""
|
| 170 |
+
return _generate_speech_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
|
| 171 |
+
|
| 172 |
+
# CPU version without decorator
|
| 173 |
+
def _generate_speech_cpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
|
| 174 |
+
"""CPU version of generate speech"""
|
| 175 |
+
return _generate_speech_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
|
| 176 |
+
|
| 177 |
+
# Router function
|
| 178 |
+
def generate_speech(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42, use_gpu=True):
|
| 179 |
+
"""Generate speech from text using voice cloning"""
|
| 180 |
+
if use_gpu:
|
| 181 |
+
return _generate_speech_gpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
|
| 182 |
+
else:
|
| 183 |
+
return _generate_speech_cpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
|
| 184 |
+
|
| 185 |
def clean_text(text):
|
| 186 |
"""Clean text by removing newlines at start/end, double spaces, and extra whitespace"""
|
| 187 |
import re
|
|
|
|
| 245 |
|
| 246 |
return final_sentences
|
| 247 |
|
| 248 |
+
# Internal implementation without decorator
|
| 249 |
+
def _generate_speech_multi_sentence_impl(text,
|
| 250 |
+
reference_audio,
|
| 251 |
+
exaggeration=0.5,
|
| 252 |
+
temperature=0.1,
|
| 253 |
+
cfg_weight=0.5,
|
| 254 |
+
seed=42,
|
| 255 |
+
use_gpu=True):
|
| 256 |
+
"""Internal implementation of multi-sentence speech generation"""
|
| 257 |
global MODEL
|
| 258 |
|
| 259 |
# Clean the input text
|
|
|
|
| 273 |
# If only one sentence or no periods, use regular method
|
| 274 |
if len(sentences) <= 1:
|
| 275 |
yield None, "Generating single sentence..."
|
| 276 |
+
result_audio, result_status = generate_speech(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu)
|
| 277 |
yield result_audio, result_status
|
| 278 |
return
|
| 279 |
|
|
|
|
| 382 |
print(error_msg)
|
| 383 |
yield None, error_msg
|
| 384 |
|
| 385 |
+
# GPU version with decorator
|
| 386 |
+
@spaces.GPU
|
| 387 |
+
def _generate_speech_multi_sentence_gpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
|
| 388 |
+
"""GPU version of multi-sentence speech generation"""
|
| 389 |
+
for result in _generate_speech_multi_sentence_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu=True):
|
| 390 |
+
yield result
|
| 391 |
+
|
| 392 |
+
# CPU version without decorator
|
| 393 |
+
def _generate_speech_multi_sentence_cpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
|
| 394 |
+
"""CPU version of multi-sentence speech generation"""
|
| 395 |
+
for result in _generate_speech_multi_sentence_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu=False):
|
| 396 |
+
yield result
|
| 397 |
+
|
| 398 |
+
# Router function
|
| 399 |
+
def generate_speech_multi_sentence(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42, use_gpu=True):
|
| 400 |
+
"""Generate speech from text with multi-sentence support and progress tracking"""
|
| 401 |
+
if use_gpu:
|
| 402 |
+
for result in _generate_speech_multi_sentence_gpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
|
| 403 |
+
yield result
|
| 404 |
+
else:
|
| 405 |
+
for result in _generate_speech_multi_sentence_cpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
|
| 406 |
+
yield result
|
| 407 |
+
|
| 408 |
def create_interface():
|
| 409 |
"""Create the Gradio interface"""
|
| 410 |
|
|
|
|
|
|
|
|
|
|
| 411 |
# Sample texts in Dhivehi
|
| 412 |
sample_texts = [
|
| 413 |
"ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
|
|
|
|
| 498 |
label="Seed",
|
| 499 |
info="For reproducible results"
|
| 500 |
)
|
| 501 |
+
with gr.Row():
|
| 502 |
+
model_select = gr.Dropdown(
|
| 503 |
+
choices=["kn_cbox", "f01_cbox"],
|
| 504 |
+
value="kn_cbox",
|
| 505 |
+
label="Model",
|
| 506 |
+
info="Select TTS model"
|
| 507 |
+
)
|
| 508 |
+
device_select = gr.Dropdown(
|
| 509 |
+
choices=["GPU", "CPU"],
|
| 510 |
+
value="GPU",
|
| 511 |
+
label="Device",
|
| 512 |
+
info="Select computation device"
|
| 513 |
+
)
|
| 514 |
+
reload_btn = gr.Button("🔄 Reload Model", size="sm")
|
| 515 |
+
reload_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
|
| 516 |
|
| 517 |
# Row 4: Generate button
|
| 518 |
generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
|
|
|
|
| 530 |
def set_reference_audio(audio_file):
|
| 531 |
return audio_file
|
| 532 |
|
| 533 |
+
def reload_model_handler(model_name, device_name):
|
| 534 |
+
"""Reload model with selected checkpoint and device"""
|
| 535 |
+
try:
|
| 536 |
+
device = "cuda" if device_name == "GPU" else "cpu"
|
| 537 |
+
load_model(checkpoint=model_name, device=device)
|
| 538 |
+
return f"✅ Model '{model_name}' loaded successfully on {device_name}!"
|
| 539 |
+
except Exception as e:
|
| 540 |
+
return f"❌ Error loading model: {str(e)}"
|
| 541 |
+
|
| 542 |
sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
|
| 543 |
sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
|
| 544 |
sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
|
|
|
|
| 549 |
ref_btn3.click(lambda: set_reference_audio("m1.wav"), outputs=[reference_audio])
|
| 550 |
ref_btn4.click(lambda: set_reference_audio("m2.wav"), outputs=[reference_audio])
|
| 551 |
|
| 552 |
+
reload_btn.click(
|
| 553 |
+
fn=reload_model_handler,
|
| 554 |
+
inputs=[model_select, device_select],
|
| 555 |
+
outputs=[reload_status]
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
def generate_with_progress(text, reference_audio, exaggeration, temperature, cfg_weight, seed, device_name):
|
| 559 |
"""Generate speech with streaming progress updates"""
|
| 560 |
+
use_gpu = (device_name == "GPU")
|
| 561 |
# Use the streaming generator
|
| 562 |
for result_audio, result_status in generate_speech_multi_sentence(
|
| 563 |
+
text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu
|
| 564 |
):
|
| 565 |
yield result_audio, result_status
|
| 566 |
|
| 567 |
generate_btn.click(
|
| 568 |
fn=generate_with_progress,
|
| 569 |
+
inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed, device_select],
|
| 570 |
outputs=[output_audio, status_message]
|
| 571 |
)
|
| 572 |
|
chatterbox_dhivehi.py
CHANGED
|
@@ -156,7 +156,7 @@ def from_dhivehi(
|
|
| 156 |
*,
|
| 157 |
ckpt_dir: Union[str, Path],
|
| 158 |
device: str = "cpu",
|
| 159 |
-
force_vocab_size: int =
|
| 160 |
):
|
| 161 |
"""
|
| 162 |
Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.
|
|
|
|
| 156 |
*,
|
| 157 |
ckpt_dir: Union[str, Path],
|
| 158 |
device: str = "cpu",
|
| 159 |
+
force_vocab_size: int = 2000,
|
| 160 |
):
|
| 161 |
"""
|
| 162 |
Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.
|