alakxender committed on
Commit
71b9145
·
1 Parent(s): 281d8d1
Files changed (2) hide show
  1. app.py +101 -28
  2. chatterbox_dhivehi.py +1 -1
app.py CHANGED
@@ -59,16 +59,18 @@ def download_model():
59
  print(f"Warning: Could not download model files: {e}")
60
  print("=" * 60)
61
 
62
- def load_model(checkpoint=f"{_target}/kn_cbox"):
63
  """Load the TTS model"""
64
  global MODEL
65
  try:
66
- print(f"Loading model with checkpoint: {checkpoint}")
 
 
67
  MODEL = ChatterboxTTS.from_dhivehi(
68
- ckpt_dir=Path(checkpoint),
69
- device="cuda" if torch.cuda.is_available() else "cpu"
70
  )
71
- print("Model loaded successfully!")
72
  except Exception as e:
73
  print(f"Error loading model: {e}")
74
  raise e
@@ -82,14 +84,14 @@ def set_seed(seed: int):
82
  random.seed(seed)
83
  np.random.seed(seed)
84
 
85
- @spaces.GPU(duration=60)
86
- def generate_speech(text,
87
- reference_audio,
88
- exaggeration=0.5,
89
- temperature=0.1,
90
- cfg_weight=0.5,
91
- seed=42):
92
- """Generate speech from text using voice cloning"""
93
  global MODEL
94
 
95
  # Clean the input text
@@ -161,6 +163,25 @@ def generate_speech(text,
161
  print(error_msg)
162
  return None, error_msg
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  def clean_text(text):
165
  """Clean text by removing newlines at start/end, double spaces, and extra whitespace"""
166
  import re
@@ -224,14 +245,15 @@ def split_sentences(text):
224
 
225
  return final_sentences
226
 
227
- @spaces.GPU
228
- def generate_speech_multi_sentence(text,
229
- reference_audio,
230
- exaggeration=0.5,
231
- temperature=0.1,
232
- cfg_weight=0.5,
233
- seed=42):
234
- """Generate speech from text with multi-sentence support and progress tracking"""
 
235
  global MODEL
236
 
237
  # Clean the input text
@@ -251,7 +273,7 @@ def generate_speech_multi_sentence(text,
251
  # If only one sentence or no periods, use regular method
252
  if len(sentences) <= 1:
253
  yield None, "Generating single sentence..."
254
- result_audio, result_status = generate_speech(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
255
  yield result_audio, result_status
256
  return
257
 
@@ -360,12 +382,32 @@ def generate_speech_multi_sentence(text,
360
  print(error_msg)
361
  yield None, error_msg
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def create_interface():
364
  """Create the Gradio interface"""
365
 
366
- # Load the model
367
- load_model()
368
-
369
  # Sample texts in Dhivehi
370
  sample_texts = [
371
  "ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
@@ -456,6 +498,21 @@ The ministry handed over the land reclamation, replacement of the port canal and
456
  label="Seed",
457
  info="For reproducible results"
458
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Row 4: Generate button
461
  generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
@@ -473,6 +530,15 @@ The ministry handed over the land reclamation, replacement of the port canal and
473
  def set_reference_audio(audio_file):
474
  return audio_file
475
 
 
 
 
 
 
 
 
 
 
476
  sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
477
  sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
478
  sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
@@ -483,17 +549,24 @@ The ministry handed over the land reclamation, replacement of the port canal and
483
  ref_btn3.click(lambda: set_reference_audio("m1.wav"), outputs=[reference_audio])
484
  ref_btn4.click(lambda: set_reference_audio("m2.wav"), outputs=[reference_audio])
485
 
486
- def generate_with_progress(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
 
 
 
 
 
 
487
  """Generate speech with streaming progress updates"""
 
488
  # Use the streaming generator
489
  for result_audio, result_status in generate_speech_multi_sentence(
490
- text, reference_audio, exaggeration, temperature, cfg_weight, seed
491
  ):
492
  yield result_audio, result_status
493
 
494
  generate_btn.click(
495
  fn=generate_with_progress,
496
- inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed],
497
  outputs=[output_audio, status_message]
498
  )
499
 
 
59
  print(f"Warning: Could not download model files: {e}")
60
  print("=" * 60)
61
 
62
+ def load_model(checkpoint="kn_cbox", device="cuda"):
63
  """Load the TTS model"""
64
  global MODEL
65
  try:
66
+ checkpoint_path = f"{_target}/{checkpoint}"
67
+ print(f"Loading model with checkpoint: {checkpoint_path}")
68
+ print(f"Target device: {device}")
69
  MODEL = ChatterboxTTS.from_dhivehi(
70
+ ckpt_dir=Path(checkpoint_path),
71
+ device=device
72
  )
73
+ print(f"Model loaded successfully on {device}!")
74
  except Exception as e:
75
  print(f"Error loading model: {e}")
76
  raise e
 
84
  random.seed(seed)
85
  np.random.seed(seed)
86
 
87
+ # Internal implementation without decorator
88
+ def _generate_speech_impl(text,
89
+ reference_audio,
90
+ exaggeration=0.5,
91
+ temperature=0.1,
92
+ cfg_weight=0.5,
93
+ seed=42):
94
+ """Internal implementation of generate speech"""
95
  global MODEL
96
 
97
  # Clean the input text
 
163
  print(error_msg)
164
  return None, error_msg
165
 
166
+ # GPU version with decorator
167
+ @spaces.GPU(duration=60)
168
+ def _generate_speech_gpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
169
+ """GPU version of generate speech"""
170
+ return _generate_speech_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
171
+
172
+ # CPU version without decorator
173
+ def _generate_speech_cpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
174
+ """CPU version of generate speech"""
175
+ return _generate_speech_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
176
+
177
+ # Router function
178
+ def generate_speech(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42, use_gpu=True):
179
+ """Generate speech from text using voice cloning"""
180
+ if use_gpu:
181
+ return _generate_speech_gpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
182
+ else:
183
+ return _generate_speech_cpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
184
+
185
  def clean_text(text):
186
  """Clean text by removing newlines at start/end, double spaces, and extra whitespace"""
187
  import re
 
245
 
246
  return final_sentences
247
 
248
+ # Internal implementation without decorator
249
+ def _generate_speech_multi_sentence_impl(text,
250
+ reference_audio,
251
+ exaggeration=0.5,
252
+ temperature=0.1,
253
+ cfg_weight=0.5,
254
+ seed=42,
255
+ use_gpu=True):
256
+ """Internal implementation of multi-sentence speech generation"""
257
  global MODEL
258
 
259
  # Clean the input text
 
273
  # If only one sentence or no periods, use regular method
274
  if len(sentences) <= 1:
275
  yield None, "Generating single sentence..."
276
+ result_audio, result_status = generate_speech(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu)
277
  yield result_audio, result_status
278
  return
279
 
 
382
  print(error_msg)
383
  yield None, error_msg
384
 
385
+ # GPU version with decorator
386
+ @spaces.GPU
387
+ def _generate_speech_multi_sentence_gpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
388
+ """GPU version of multi-sentence speech generation"""
389
+ for result in _generate_speech_multi_sentence_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu=True):
390
+ yield result
391
+
392
+ # CPU version without decorator
393
+ def _generate_speech_multi_sentence_cpu(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42):
394
+ """CPU version of multi-sentence speech generation"""
395
+ for result in _generate_speech_multi_sentence_impl(text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu=False):
396
+ yield result
397
+
398
+ # Router function
399
+ def generate_speech_multi_sentence(text, reference_audio, exaggeration=0.5, temperature=0.1, cfg_weight=0.5, seed=42, use_gpu=True):
400
+ """Generate speech from text with multi-sentence support and progress tracking"""
401
+ if use_gpu:
402
+ for result in _generate_speech_multi_sentence_gpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
403
+ yield result
404
+ else:
405
+ for result in _generate_speech_multi_sentence_cpu(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
406
+ yield result
407
+
408
  def create_interface():
409
  """Create the Gradio interface"""
410
 
 
 
 
411
  # Sample texts in Dhivehi
412
  sample_texts = [
413
  "ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
 
498
  label="Seed",
499
  info="For reproducible results"
500
  )
501
+ with gr.Row():
502
+ model_select = gr.Dropdown(
503
+ choices=["kn_cbox", "f01_cbox"],
504
+ value="kn_cbox",
505
+ label="Model",
506
+ info="Select TTS model"
507
+ )
508
+ device_select = gr.Dropdown(
509
+ choices=["GPU", "CPU"],
510
+ value="GPU",
511
+ label="Device",
512
+ info="Select computation device"
513
+ )
514
+ reload_btn = gr.Button("🔄 Reload Model", size="sm")
515
+ reload_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
516
 
517
  # Row 4: Generate button
518
  generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
 
530
  def set_reference_audio(audio_file):
531
  return audio_file
532
 
533
+ def reload_model_handler(model_name, device_name):
534
+ """Reload model with selected checkpoint and device"""
535
+ try:
536
+ device = "cuda" if device_name == "GPU" else "cpu"
537
+ load_model(checkpoint=model_name, device=device)
538
+ return f"✅ Model '{model_name}' loaded successfully on {device_name}!"
539
+ except Exception as e:
540
+ return f"❌ Error loading model: {str(e)}"
541
+
542
  sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
543
  sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
544
  sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
 
549
  ref_btn3.click(lambda: set_reference_audio("m1.wav"), outputs=[reference_audio])
550
  ref_btn4.click(lambda: set_reference_audio("m2.wav"), outputs=[reference_audio])
551
 
552
+ reload_btn.click(
553
+ fn=reload_model_handler,
554
+ inputs=[model_select, device_select],
555
+ outputs=[reload_status]
556
+ )
557
+
558
+ def generate_with_progress(text, reference_audio, exaggeration, temperature, cfg_weight, seed, device_name):
559
  """Generate speech with streaming progress updates"""
560
+ use_gpu = (device_name == "GPU")
561
  # Use the streaming generator
562
  for result_audio, result_status in generate_speech_multi_sentence(
563
+ text, reference_audio, exaggeration, temperature, cfg_weight, seed, use_gpu
564
  ):
565
  yield result_audio, result_status
566
 
567
  generate_btn.click(
568
  fn=generate_with_progress,
569
+ inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed, device_select],
570
  outputs=[output_audio, status_message]
571
  )
572
 
chatterbox_dhivehi.py CHANGED
@@ -156,7 +156,7 @@ def from_dhivehi(
156
  *,
157
  ckpt_dir: Union[str, Path],
158
  device: str = "cpu",
159
- force_vocab_size: int = 2500,
160
  ):
161
  """
162
  Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.
 
156
  *,
157
  ckpt_dir: Union[str, Path],
158
  device: str = "cpu",
159
+ force_vocab_size: int = 2000,
160
  ):
161
  """
162
  Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.