Kuangwei Chen committed on
Commit
ebd7ee5
·
1 Parent(s): 2d87bb9

Add normalization toggles to Space UI

Browse files
Files changed (3) hide show
  1. app.py +23 -4
  2. requirements.txt +1 -0
  3. text_normalization_pipeline.py +32 -16
app.py CHANGED
@@ -26,7 +26,7 @@ except ImportError:
26
  spaces = _SpacesFallback()
27
 
28
  from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService, build_default_voice_presets
29
- from text_normalization_pipeline import prepare_tts_request_texts
30
 
31
  APP_DIR = Path(__file__).resolve().parent
32
  CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
@@ -162,6 +162,11 @@ def get_runtime_tts_service() -> NanoTTSService:
162
  return get_tts_service(bool(torch.cuda.is_available()))
163
 
164
 
 
 
 
 
 
165
  def preload_service() -> None:
166
  started_at = time.monotonic()
167
  service = get_runtime_tts_service()
@@ -346,6 +351,8 @@ def run_inference(
346
  voice: str,
347
  prompt_audio_path: str | None,
348
  selected_demo_audio_path: str | None,
 
 
349
  max_new_frames: int,
350
  voice_clone_max_text_tokens: int,
351
  do_sample: bool,
@@ -361,6 +368,7 @@ def run_inference(
361
  generated_audio_path: str | None = None
362
  try:
363
  service = get_runtime_tts_service()
 
364
  effective_prompt_audio_path = resolve_effective_prompt_audio_path(
365
  voice=voice,
366
  prompt_audio_path=prompt_audio_path,
@@ -374,8 +382,9 @@ def run_inference(
374
  text=normalized_text,
375
  prompt_text="",
376
  voice=voice,
377
- enable_wetext=False,
378
- text_normalizer_manager=None,
 
379
  )
380
  prompt_source = build_prompt_source_text(
381
  voice=voice,
@@ -471,10 +480,18 @@ def build_demo():
471
  )
472
 
473
  gr.Markdown(
474
- "Robust text normalization is always on. Runtime device and backbone are fixed by the Space and are not user-configurable. Uploaded reference audio overrides the selected example case."
475
  )
476
 
477
  with gr.Accordion("Advanced Parameters", open=False):
 
 
 
 
 
 
 
 
478
  max_new_frames = gr.Slider(
479
  minimum=64,
480
  maximum=512,
@@ -589,6 +606,8 @@ def build_demo():
589
  voice,
590
  prompt_audio,
591
  selected_demo_audio_path,
 
 
592
  max_new_frames,
593
  voice_clone_max_text_tokens,
594
  do_sample,
 
26
  spaces = _SpacesFallback()
27
 
28
  from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService, build_default_voice_presets
29
+ from text_normalization_pipeline import WeTextProcessingManager, prepare_tts_request_texts
30
 
31
  APP_DIR = Path(__file__).resolve().parent
32
  CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
 
162
  return get_tts_service(bool(torch.cuda.is_available()))
163
 
164
 
165
+ @functools.lru_cache(maxsize=1)
166
+ def get_text_normalizer_manager() -> WeTextProcessingManager:
167
+ return WeTextProcessingManager()
168
+
169
+
170
  def preload_service() -> None:
171
  started_at = time.monotonic()
172
  service = get_runtime_tts_service()
 
351
  voice: str,
352
  prompt_audio_path: str | None,
353
  selected_demo_audio_path: str | None,
354
+ enable_wetext_processing: bool,
355
+ enable_normalize_tts_text: bool,
356
  max_new_frames: int,
357
  voice_clone_max_text_tokens: int,
358
  do_sample: bool,
 
368
  generated_audio_path: str | None = None
369
  try:
370
  service = get_runtime_tts_service()
371
+ text_normalizer_manager = get_text_normalizer_manager() if enable_wetext_processing else None
372
  effective_prompt_audio_path = resolve_effective_prompt_audio_path(
373
  voice=voice,
374
  prompt_audio_path=prompt_audio_path,
 
382
  text=normalized_text,
383
  prompt_text="",
384
  voice=voice,
385
+ enable_wetext=bool(enable_wetext_processing),
386
+ enable_normalize_tts_text=bool(enable_normalize_tts_text),
387
+ text_normalizer_manager=text_normalizer_manager,
388
  )
389
  prompt_source = build_prompt_source_text(
390
  voice=voice,
 
480
  )
481
 
482
  gr.Markdown(
483
+ "Runtime device and backbone are fixed by the Space and are not user-configurable. Uploaded reference audio overrides the selected example case."
484
  )
485
 
486
  with gr.Accordion("Advanced Parameters", open=False):
487
+ enable_wetext_processing = gr.Checkbox(
488
+ value=False,
489
+ label="Enable WeTextProcessing",
490
+ )
491
+ enable_normalize_tts_text = gr.Checkbox(
492
+ value=True,
493
+ label="Enable normalize_tts_text",
494
+ )
495
  max_new_frames = gr.Slider(
496
  minimum=64,
497
  maximum=512,
 
606
  voice,
607
  prompt_audio,
608
  selected_demo_audio_path,
609
+ enable_wetext_processing,
610
+ enable_normalize_tts_text,
611
  max_new_frames,
612
  voice_clone_max_text_tokens,
613
  do_sample,
requirements.txt CHANGED
@@ -7,3 +7,4 @@ safetensors>=0.4.3
7
  soundfile>=0.13.1
8
  gradio==6.5.1
9
  spaces
 
 
7
  soundfile>=0.13.1
8
  gradio==6.5.1
9
  spaces
10
+ WeTextProcessing>=1.0.4.1
text_normalization_pipeline.py CHANGED
@@ -135,6 +135,7 @@ def prepare_tts_request_texts(
135
  prompt_text: str,
136
  voice: str,
137
  enable_wetext: bool,
 
138
  text_normalizer_manager: WeTextProcessingManager | None,
139
  ) -> dict[str, object]:
140
  raw_text = str(text or "")
@@ -168,28 +169,43 @@ def prepare_tts_request_texts(
168
  normalization_language,
169
  )
170
 
171
- final_text = normalize_tts_text(intermediate_text)
172
- final_prompt_text = normalize_tts_text(intermediate_prompt_text) if intermediate_prompt_text else ""
 
173
 
174
- if final_text != intermediate_text:
175
- logging.info(
176
- "normalized text chars_before=%d chars_after=%d stage=robust_final",
177
- len(intermediate_text),
178
- len(final_text),
179
- )
180
- if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
181
- logging.info(
182
- "normalized prompt_text chars_before=%d chars_after=%d stage=robust_final",
183
- len(intermediate_prompt_text),
184
- len(final_prompt_text),
185
- )
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  return {
188
  "text": final_text,
189
  "prompt_text": final_prompt_text,
190
  "normalized_text": final_text,
191
  "normalized_prompt_text": final_prompt_text,
192
- "normalization_method": (f"wetext:{normalization_language}+robust" if enable_wetext else "robust"),
193
  "text_normalization_language": normalization_language,
194
- "text_normalization_enabled": bool(enable_wetext),
 
 
195
  }
 
135
  prompt_text: str,
136
  voice: str,
137
  enable_wetext: bool,
138
+ enable_normalize_tts_text: bool,
139
  text_normalizer_manager: WeTextProcessingManager | None,
140
  ) -> dict[str, object]:
141
  raw_text = str(text or "")
 
169
  normalization_language,
170
  )
171
 
172
+ if enable_normalize_tts_text:
173
+ final_text = normalize_tts_text(intermediate_text)
174
+ final_prompt_text = normalize_tts_text(intermediate_prompt_text) if intermediate_prompt_text else ""
175
 
176
+ if final_text != intermediate_text:
177
+ logging.info(
178
+ "normalized text chars_before=%d chars_after=%d stage=robust_final",
179
+ len(intermediate_text),
180
+ len(final_text),
181
+ )
182
+ if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
183
+ logging.info(
184
+ "normalized prompt_text chars_before=%d chars_after=%d stage=robust_final",
185
+ len(intermediate_prompt_text),
186
+ len(final_prompt_text),
187
+ )
188
+ else:
189
+ final_text = intermediate_text
190
+ final_prompt_text = intermediate_prompt_text
191
+
192
+ if enable_wetext and enable_normalize_tts_text:
193
+ normalization_method = f"wetext:{normalization_language}+robust"
194
+ elif enable_wetext:
195
+ normalization_method = f"wetext:{normalization_language}"
196
+ elif enable_normalize_tts_text:
197
+ normalization_method = "robust"
198
+ else:
199
+ normalization_method = "none"
200
 
201
  return {
202
  "text": final_text,
203
  "prompt_text": final_prompt_text,
204
  "normalized_text": final_text,
205
  "normalized_prompt_text": final_prompt_text,
206
+ "normalization_method": normalization_method,
207
  "text_normalization_language": normalization_language,
208
+ "text_normalization_enabled": bool(enable_wetext or enable_normalize_tts_text),
209
+ "wetext_enabled": bool(enable_wetext),
210
+ "normalize_tts_text_enabled": bool(enable_normalize_tts_text),
211
  }