danneauxs committed on
Commit
3aa3268
·
1 Parent(s): 131e31b

added random seed setting

Browse files
config/config.py CHANGED
@@ -22,7 +22,7 @@ MIN_CHUNK_WORDS = 4
22
  # ============================================================================
23
  # WORKER AND PERFORMANCE SETTINGS
24
  # ============================================================================
25
- MAX_WORKERS = 1
26
  TEST_MAX_WORKERS = 6 # For experimentation
27
  USE_DYNAMIC_WORKERS = False # Toggle for testing
28
  VRAM_SAFETY_THRESHOLD = 6.5 # GB
@@ -140,6 +140,7 @@ CYAN = "\033[96m"
140
  DEFAULT_EXAGGERATION = 0.5
141
  DEFAULT_CFG_WEIGHT = 0.5
142
  DEFAULT_TEMPERATURE = 0.85
 
143
 
144
  # Advanced Sampling Parameters (Min_P Sampler Support)
145
  DEFAULT_MIN_P = 0.05 # Min probability threshold (0.0 disables)
@@ -185,6 +186,46 @@ TTS_PARAM_MAX_TOP_P = 1.0 # MAX 1.0 disables top_p
185
  TTS_PARAM_MIN_REPETITION_PENALTY = 1.0 # 1.0 = no penalty
186
  TTS_PARAM_MAX_REPETITION_PENALTY = 2.0 # Higher values too restrictive MAX 2
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  # ============================================================================
189
  # BATCH PROCESSING SETTINGS
190
  # ============================================================================
 
22
  # ============================================================================
23
  # WORKER AND PERFORMANCE SETTINGS
24
  # ============================================================================
25
+ MAX_WORKERS = 2
26
  TEST_MAX_WORKERS = 6 # For experimentation
27
  USE_DYNAMIC_WORKERS = False # Toggle for testing
28
  VRAM_SAFETY_THRESHOLD = 6.5 # GB
 
140
  DEFAULT_EXAGGERATION = 0.5
141
  DEFAULT_CFG_WEIGHT = 0.5
142
  DEFAULT_TEMPERATURE = 0.85
143
+ DEFAULT_SEED = 0 # Random seed for generation. 0 means random.
144
 
145
  # Advanced Sampling Parameters (Min_P Sampler Support)
146
  DEFAULT_MIN_P = 0.05 # Min probability threshold (0.0 disables)
 
186
  TTS_PARAM_MIN_REPETITION_PENALTY = 1.0 # 1.0 = no penalty
187
  TTS_PARAM_MAX_REPETITION_PENALTY = 2.0 # Higher values too restrictive MAX 2
188
 
189
+ # ============================================================================
190
+ # TTS_PRESETS
191
+ # ============================================================================
192
# Named bundles of TTS generation settings, keyed by the preset's display name.
# Every preset carries the same set of keys so consumers can index them
# directly without per-key fallbacks.
TTS_PRESETS = {
    "Narration": {
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "temperature": 0.85,
        "min_p": 0.05,
        "top_p": 1.0,
        "repetition_penalty": 1.2,
        "vader_enabled": True,  # Default to VADER on for nuanced presets
        "sentiment_smoothing": True,
        "smoothing_window": 3,
        "smoothing_method": "rolling",
    },
    "Expressive": {
        "exaggeration": 0.65,
        "cfg_weight": 0.7,
        "temperature": 0.95,
        "min_p": 0.05,
        "top_p": 1.0,
        "repetition_penalty": 1.2,
        "vader_enabled": True,
        "sentiment_smoothing": True,
        "smoothing_window": 3,
        "smoothing_method": "rolling",
    },
    "Exposition": {
        "exaggeration": 0.4,
        "cfg_weight": 0.6,
        "temperature": 0.75,
        "min_p": 0.05,
        "top_p": 1.0,
        "repetition_penalty": 1.2,
        "vader_enabled": False,  # VADER off for consistent, clear delivery
        "sentiment_smoothing": False,
        # Smoothing is disabled for this preset; these two keys are listed
        # only for schema parity with the other presets so a direct
        # preset["smoothing_window"] lookup cannot raise KeyError.
        "smoothing_window": 3,
        "smoothing_method": "rolling",
    },
}
228
+
229
  # ============================================================================
230
  # BATCH PROCESSING SETTINGS
231
  # ============================================================================
gradio_tabs/tab1_convert_book.py CHANGED
@@ -355,6 +355,17 @@ def create_convert_book_tab():
355
  with gr.Column(scale=1):
356
  gr.Markdown("### ⚙️ Quick Settings")
357
 
 
 
 
 
 
 
 
 
 
 
 
358
  # VADER and ASR
359
  vader_enabled = gr.Checkbox(
360
  label="Use VADER sentiment analysis",
@@ -545,6 +556,14 @@ def create_convert_book_tab():
545
  value=16, # Default value
546
  info="Number of chunks to process simultaneously when VADER is disabled for speed."
547
  )
 
 
 
 
 
 
 
 
548
 
549
  # Action Buttons and Status
550
  with gr.Row():
@@ -792,7 +811,7 @@ def create_convert_book_tab():
792
  sentiment_smooth_val, smooth_window_val, smooth_method_val,
793
  mfcc_val, output_val, spectral_thresh_val, output_thresh_val,
794
  exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val,
795
- tts_batch_size_val):
796
  """Start the actual book conversion - file upload version"""
797
 
798
  # Validation
@@ -896,7 +915,8 @@ def create_convert_book_tab():
896
  'asr_enabled': asr_val,
897
  'asr_config': asr_config,
898
  'add_to_batch': add_to_batch_val,
899
- 'tts_batch_size': tts_batch_size_val
 
900
  }
901
 
902
  # Set conversion state
@@ -987,7 +1007,7 @@ def create_convert_book_tab():
987
  sentiment_smoothing, smoothing_window, smoothing_method,
988
  mfcc_validation, output_validation, spectral_threshold, output_threshold,
989
  exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty,
990
- tts_batch_size
991
  ],
992
  outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector]
993
  )
@@ -1019,6 +1039,44 @@ def create_convert_book_tab():
1019
  inputs=[m4b_file_selector, playback_speed],
1020
  outputs=[status_display, audio_player, audiobook_selector, m4b_file_selector]
1021
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1022
 
1023
  # Progress monitoring with file-based approach
1024
  def get_current_stats():
 
355
  with gr.Column(scale=1):
356
  gr.Markdown("### ⚙️ Quick Settings")
357
 
358
+ # NEW: Presets
359
+ with gr.Row():
360
+ preset_dropdown = gr.Dropdown(
361
+ label="Load Preset",
362
+ choices=list(TTS_PRESETS.keys()),
363
+ value="Narration",
364
+ interactive=True,
365
+ info="Apply predefined TTS parameter settings."
366
+ )
367
+ apply_preset_btn = gr.Button("Apply Preset", size="sm", variant="secondary")
368
+
369
  # VADER and ASR
370
  vader_enabled = gr.Checkbox(
371
  label="Use VADER sentiment analysis",
 
556
  value=16, # Default value
557
  info="Number of chunks to process simultaneously when VADER is disabled for speed."
558
  )
559
+
560
+ # NEW: Random Seed
561
+ seed = gr.Number(
562
+ label="Random Seed (0 for random)",
563
+ minimum=0, maximum=999999999, step=1,
564
+ value=0, # Default value
565
+ info="Set a seed for reproducible generation. 0 means random."
566
+ )
567
 
568
  # Action Buttons and Status
569
  with gr.Row():
 
811
  sentiment_smooth_val, smooth_window_val, smooth_method_val,
812
  mfcc_val, output_val, spectral_thresh_val, output_thresh_val,
813
  exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val,
814
+ tts_batch_size_val, seed_val):
815
  """Start the actual book conversion - file upload version"""
816
 
817
  # Validation
 
915
  'asr_enabled': asr_val,
916
  'asr_config': asr_config,
917
  'add_to_batch': add_to_batch_val,
918
+ 'tts_batch_size': tts_batch_size_val,
919
+ 'seed': seed_val
920
  }
921
 
922
  # Set conversion state
 
1007
  sentiment_smoothing, smoothing_window, smoothing_method,
1008
  mfcc_validation, output_validation, spectral_threshold, output_threshold,
1009
  exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty,
1010
+ tts_batch_size, seed
1011
  ],
1012
  outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector]
1013
  )
 
1039
  inputs=[m4b_file_selector, playback_speed],
1040
  outputs=[status_display, audio_player, audiobook_selector, m4b_file_selector]
1041
  )
1042
+
1043
+ # NEW: Apply Preset Function
1044
def apply_preset(preset_name):
    """
    Build the component updates that load a named TTS preset into the UI.

    Always returns one update per wired output, in this order: vader_enabled,
    sentiment_smoothing, smoothing_window, smoothing_method, exaggeration,
    cfg_weight, temperature, min_p, top_p, repetition_penalty.

    Args:
        preset_name: Key into TTS_PRESETS (the dropdown's current value).

    Returns:
        Tuple of ten gr.update() objects. For an unknown preset every entry
        is a no-op update, so the current widget values are left untouched.
    """
    # (preset key, fallback) pairs; the order must mirror the `outputs` list
    # of the click handler this function is connected to.
    fields = (
        ("vader_enabled", True),
        ("sentiment_smoothing", True),
        ("smoothing_window", 3),
        ("smoothing_method", "rolling"),
        ("exaggeration", DEFAULT_EXAGGERATION),
        ("cfg_weight", DEFAULT_CFG_WEIGHT),
        ("temperature", DEFAULT_TEMPERATURE),
        ("min_p", DEFAULT_MIN_P),
        ("top_p", DEFAULT_TOP_P),
        ("repetition_penalty", DEFAULT_REPETITION_PENALTY),
    )

    if preset_name not in TTS_PRESETS:
        # A handler wired to ten outputs must return ten values; a lone
        # gr.update() here would be reported as too few output values.
        return tuple(gr.update() for _ in fields)

    preset = TTS_PRESETS[preset_name]
    return tuple(
        gr.update(value=preset.get(key, fallback)) for key, fallback in fields
    )
1062
+
1063
+ # Connect apply_preset_btn
1064
+ apply_preset_btn.click(
1065
+ apply_preset,
1066
+ inputs=[preset_dropdown],
1067
+ outputs=[
1068
+ vader_enabled,
1069
+ sentiment_smoothing,
1070
+ smoothing_window,
1071
+ smoothing_method,
1072
+ exaggeration,
1073
+ cfg_weight,
1074
+ temperature,
1075
+ min_p,
1076
+ top_p,
1077
+ repetition_penalty,
1078
+ ]
1079
+ )
1080
 
1081
  # Progress monitoring with file-based approach
1082
  def get_current_stats():
modules/tts_engine.py CHANGED
@@ -64,6 +64,27 @@ YELLOW = '\033[93m'
64
  CYAN = '\033[96m'
65
  RESET = '\033[0m'
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # ============================================================================
68
  # MEMORY AND MODEL MANAGEMENT
69
  # ============================================================================
@@ -234,9 +255,11 @@ def process_batch(
234
  batch, text_chunks_dir, audio_chunks_dir,
235
  voice_path, tts_params, start_time, total_chunks,
236
  punc_norm, basename, log_run_func, log_path, device,
237
- model, asr_model, all_chunks,
238
  enable_asr=None
239
  ):
 
 
240
  """
241
  Process a batch of chunks using the batch-enabled TTS model.
242
  """
@@ -309,9 +332,11 @@ def process_one_chunk(
309
  i, chunk, text_chunks_dir, audio_chunks_dir,
310
  voice_path, tts_params, start_time, total_chunks,
311
  punc_norm, basename, log_run_func, log_path, device,
312
- model, asr_model, boundary_type="none",
313
  enable_asr=None
314
  ):
 
 
315
  """Enhanced chunk processing with quality control, contextual silence, and deep cleanup"""
316
  import difflib
317
  from pydub import AudioSegment
 
64
  CYAN = '\033[96m'
65
  RESET = '\033[0m'
66
 
67
+ import random
68
+ import numpy as np
69
+ import torch
70
+
71
def set_seed(seed_value: int):
    """
    Seed the torch, random, and numpy generators for reproducible output.

    This is called if a non-zero seed is provided for generation.

    Args:
        seed_value: Seed to apply. Integral floats (for example 42.0, as a UI
            number field may deliver) are accepted and converted to int.
    """
    # Normalize once up front: np.random.seed() rejects float seeds, and a
    # single int keeps every generator below on the same value.
    seed_value = int(seed_value)

    torch.manual_seed(seed_value)

    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current one
        torch.cuda.manual_seed_all(seed_value)

    # Availability is probed through torch.backends.mps only;
    # torch.mps.is_available() is absent from some torch releases and would
    # raise AttributeError there. getattr guards builds without MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        mps_module = getattr(torch, "mps", None)
        if mps_module is not None and hasattr(mps_module, "manual_seed"):
            mps_module.manual_seed(seed_value)

    random.seed(seed_value)

    # Legacy NumPy seeding only accepts 0 .. 2**32 - 1; wrap instead of raising
    np.random.seed(seed_value % (2 ** 32))

    logging.info("Global seed set to: %s", seed_value)
87
+
88
  # ============================================================================
89
  # MEMORY AND MODEL MANAGEMENT
90
  # ============================================================================
 
255
  batch, text_chunks_dir, audio_chunks_dir,
256
  voice_path, tts_params, start_time, total_chunks,
257
  punc_norm, basename, log_run_func, log_path, device,
258
+ model, asr_model, seed=0,
259
  enable_asr=None
260
  ):
261
+ if seed != 0:
262
+ set_seed(seed)
263
  """
264
  Process a batch of chunks using the batch-enabled TTS model.
265
  """
 
332
  i, chunk, text_chunks_dir, audio_chunks_dir,
333
  voice_path, tts_params, start_time, total_chunks,
334
  punc_norm, basename, log_run_func, log_path, device,
335
+ model, asr_model, seed=0, boundary_type="none",
336
  enable_asr=None
337
  ):
338
+ if seed != 0:
339
+ set_seed(seed)
340
  """Enhanced chunk processing with quality control, contextual silence, and deep cleanup"""
341
  import difflib
342
  from pydub import AudioSegment
src/chatterbox/tts.py CHANGED
@@ -294,6 +294,7 @@ class ChatterboxTTS:
294
  min_p=0.05,
295
  top_p=0.8,
296
  repetition_penalty=2.0,
 
297
  ):
298
  if audio_prompt_path:
299
  self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
@@ -360,6 +361,7 @@ class ChatterboxTTS:
360
  min_p=0.05,
361
  top_p=0.8,
362
  repetition_penalty=2.0,
 
363
  ):
364
  if audio_prompt_path:
365
  self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
 
294
  min_p=0.05,
295
  top_p=0.8,
296
  repetition_penalty=2.0,
297
+ seed=0,
298
  ):
299
  if audio_prompt_path:
300
  self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
 
361
  min_p=0.05,
362
  top_p=0.8,
363
  repetition_penalty=2.0,
364
+ seed=0,
365
  ):
366
  if audio_prompt_path:
367
  self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)