from __future__ import annotations

import logging
from pathlib import Path
import re
import threading
from dataclasses import dataclass

from tts_robust_normalizer_single_script import normalize_tts_text

# Voices that select English normalization when the text itself contains
# neither CJK ideographs nor ASCII letters (see resolve_text_normalization_language).
ENGLISH_VOICES = frozenset({"Trump", "Ava", "Bella", "Adam", "Nathan"})
# Cache directory for the Chinese WeText graphs. It is kept separate because the
# zh normalizer is constructed with custom flags in _ensure_normalizers_loaded.
CUSTOM_ZH_WETEXT_CACHE_DIR = Path(__file__).resolve().parent / ".cache" / "wetext_zh_no_erhua_keep_punct"
# Temporary placeholder that shields hyphens we want WeText to see verbatim from
# the hyphen-flattening rules in _rewrite_hyphens_before_zh_wetext.
_ZH_WETEXT_KEEP_HYPHEN = "___KEEP_HYPHEN_BEFORE_ZH_WETEXT___"


@dataclass(frozen=True)
class TextNormalizationSnapshot:
    """Immutable point-in-time view of the WeTextProcessing loader state."""

    # One of "pending", "running", "ready", "failed" (values set by WeTextProcessingManager).
    state: str
    # Human-readable status line.
    message: str
    # Exception text when the preload failed, otherwise None.
    error: str | None = None
    # True exactly when state == "ready".
    ready: bool = False
    # Mirrors WeTextProcessingManager._available.
    available: bool = False

    @property
    def failed(self) -> bool:
        """Return True when the preload ended in the "failed" state."""
        return self.state == "failed"


class WeTextProcessingManager:
    """Load WeTextProcessing zh/en normalizers on a background thread and serve them.

    Locking:
      * ``_lock`` guards the status fields and thread startup. It is only held
        briefly, so ``snapshot()`` stays cheap even while graphs are loading.
      * ``_load_lock`` guards the one-time (slow) construction of the normalizers.
      * ``_normalize_lock`` serializes calls into the normalizers.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Separate from _lock so that status queries never wait on graph loading.
        self._load_lock = threading.Lock()
        # Serializes normalize() calls. Presumably the underlying normalizers
        # are not known to be thread-safe (TODO confirm).
        self._normalize_lock = threading.Lock()
        self._thread: threading.Thread | None = None
        self._started = False
        self._state = "pending"
        self._message = "Waiting for WeTextProcessing preload."
        self._error: str | None = None
        # Nothing in this module sets this to False; it gates the "unavailable" branch of _run.
        self._available = True
        self._normalizers: dict[str, object] | None = None

    def snapshot(self) -> TextNormalizationSnapshot:
        """Return a consistent copy of the current loader state."""
        with self._lock:
            return TextNormalizationSnapshot(
                state=self._state,
                message=self._message,
                error=self._error,
                ready=self._state == "ready",
                available=self._available,
            )

    def _set_state(self, *, state: str, message: str, error: str | None = None) -> None:
        """Atomically update the state, message and error fields."""
        with self._lock:
            self._state = state
            self._message = message
            self._error = error

    def _start_thread_locked(self) -> None:
        """Spawn the preload thread at most once. The caller must hold ``self._lock``."""
        if self._started:
            return
        self._started = True
        self._thread = threading.Thread(target=self._run, name="wetext-preload", daemon=True)
        self._thread.start()

    def start(self) -> None:
        """Begin preloading in the background (idempotent, non-blocking)."""
        with self._lock:
            self._start_thread_locked()

    def ensure_ready(self) -> TextNormalizationSnapshot:
        """Start the preload if needed, wait for it to finish, and return the final state."""
        with self._lock:
            self._start_thread_locked()
            thread = self._thread
        # Join outside the lock: the preload thread needs _lock to publish its state.
        if thread is not None and thread.is_alive():
            thread.join()
        return self.snapshot()

    def close(self) -> None:
        """No-op. There are no resources to release."""
        return

    def _run(self) -> None:
        """Preload thread body: build the normalizers and record the outcome in the state."""
        if not self._available:
            self._set_state(
                state="failed",
                message="WeTextProcessing unavailable.",
                error="installed WeTextProcessing modules are unavailable",
            )
            return
        try:
            self._set_state(state="running", message="Loading WeTextProcessing graphs.", error=None)
            self._ensure_normalizers_loaded()
            self._set_state(state="ready", message="WeTextProcessing ready. languages=zh,en", error=None)
        except Exception as exc:
            # Thread boundary: record the failure so snapshot()/normalize() can report it.
            logging.exception("WeTextProcessing preload failed")
            self._set_state(state="failed", message="WeTextProcessing preload failed.", error=str(exc))

    def _ensure_normalizers_loaded(self) -> dict[str, object]:
        """Build the ``{"zh": ..., "en": ...}`` normalizers once and return them.

        Guarded by ``_load_lock`` rather than ``_lock``, so the slow import and
        graph build do not block ``snapshot()``, ``start()`` or ``_set_state()``.
        """
        with self._load_lock:
            if self._normalizers is not None:
                return self._normalizers
            # Deferred import: WeTextProcessing is only touched once loading is requested.
            from tn.chinese.normalizer import Normalizer as ZhNormalizer
            from tn.english.normalizer import Normalizer as EnNormalizer

            # NOTE(review): this forces the root logger to INFO, presumably to undo a
            # level change made while importing tn. Confirm whether that is still needed.
            logging.getLogger().setLevel(logging.INFO)
            self._normalizers = {
                "zh": ZhNormalizer(
                    cache_dir=str(CUSTOM_ZH_WETEXT_CACHE_DIR),
                    overwrite_cache=False,
                    remove_interjections=False,
                    remove_erhua=False,
                    full_to_half=False,
                ),
                "en": EnNormalizer(overwrite_cache=False),
            }
            return self._normalizers

    def normalize(self, *, text: str, prompt_text: str, language: str) -> tuple[str, str]:
        """Normalize ``text`` and ``prompt_text`` with the normalizer for ``language``.

        Blocks until the preload has finished. Empty inputs are returned as "".

        Returns:
            ``(normalized_text, normalized_prompt_text)``.

        Raises:
            RuntimeError: the preload did not reach the "ready" state.
            ValueError: ``language`` has no loaded normalizer (only "zh" and "en" are built).
        """
        snapshot = self.ensure_ready()
        if not snapshot.ready:
            raise RuntimeError(snapshot.error or snapshot.message)
        with self._normalize_lock:
            normalizers = self._ensure_normalizers_loaded()
            if language not in normalizers:
                raise ValueError(f"Unsupported text normalization language: {language}")
            normalizer = normalizers[language]
            normalized_text = normalizer.normalize(text) if text else ""
            normalized_prompt_text = normalizer.normalize(prompt_text) if prompt_text else ""
            return normalized_text, normalized_prompt_text


def resolve_text_normalization_language(*, text: str, voice: str) -> str:
    """Choose the WeText language ("zh" or "en") for ``text``.

    Any CJK ideograph in U+3400..U+9FFF selects "zh". Otherwise any ASCII letter
    selects "en". Otherwise the voice decides: voices in ``ENGLISH_VOICES`` give
    "en", and everything else gives "zh".
    """
    if re.search(r"[\u3400-\u9fff]", text):
        return "zh"
    if re.search(r"[A-Za-z]", text):
        return "en"
    if voice in ENGLISH_VOICES:
        return "en"
    return "zh"


def _rewrite_hyphens_before_zh_wetext(text: str) -> str:
    """Avoid Chinese WeText reading non-numeric hyphens as '减' ("minus").

    Hyphens that precede or join digits (negatives, ranges, dates) are kept as
    "-". Spaces around them are dropped, so "2024 - 05" becomes "2024-05". A
    hyphen between two CJK ideographs becomes ",". Any other token-internal
    hyphen becomes a space.

    When at least one "-" is present, runs of 2+ spaces are also collapsed and
    the result is stripped. Text without "-" is returned unchanged, and None
    becomes "".
    """
    rewritten = str(text or "")
    if "-" not in rewritten:
        return rewritten
    # Preserve start-of-text negatives like `-2`.
    rewritten = re.sub(
        r"(^\s*)-\s*(?=\d)",
        rf"\1{_ZH_WETEXT_KEEP_HYPHEN}",
        rewritten,
    )
    # Preserve negatives after common delimiters like `x=-2` or `(-2)`.
    # NOTE(review): the class lists ",", "(", ":" and ";" twice. Fullwidth variants
    # (",", "(", ":", ";") may have been intended; confirm before changing.
    rewritten = re.sub(
        r"([=:+*/,(,::;;(【\[{])\s*-\s*(?=\d)",
        rf"\1{_ZH_WETEXT_KEEP_HYPHEN}",
        rewritten,
    )
    # Preserve Chinese-context negatives like `为-2` / `计算-2`.
    rewritten = re.sub(
        r"([\u3400-\u9fff])\s*-\s*(?=\d)",
        rf"\1{_ZH_WETEXT_KEEP_HYPHEN}",
        rewritten,
    )
    # Preserve numeric ranges/dates like `10-3` / `2024-05-01`.
    rewritten = re.sub(
        r"(\d)\s*-\s*(?=\d)",
        rf"\1{_ZH_WETEXT_KEEP_HYPHEN}",
        rewritten,
    )
    # Chinese compound phrases sound more natural with a pause boundary.
    rewritten = re.sub(
        r"([\u3400-\u9fff])\s*-\s*(?=[\u3400-\u9fff])",
        r"\1,",
        rewritten,
    )
    # Remaining token-internal hyphens are flattened to spaces for zh WeText.
    # The placeholder contains no "-", so the hyphens protected above are untouched.
    rewritten = re.sub(
        r"([^\s-])\s*-\s*(?=[^\s-])",
        r"\1 ",
        rewritten,
    )
    rewritten = re.sub(r" {2,}", " ", rewritten).strip()
    return rewritten.replace(_ZH_WETEXT_KEEP_HYPHEN, "-")


def prepare_tts_request_texts(
    *,
    text: str,
    prompt_text: str = "",
    voice: str = "",
    enable_wetext: bool,
    enable_normalize_tts_text: bool = True,
    text_normalizer_manager: WeTextProcessingManager | None,
) -> dict[str, object]:
    """Run the configured normalization pipeline over ``text`` and ``prompt_text``.

    Stages, in order (each is recorded in ``normalization_method``):
      1. ``robust_pre``: ``normalize_tts_text`` before WeText. Runs only when
         both flags are enabled.
      2. ``wetext:<lang>``: WeText normalization when ``enable_wetext`` is set.
         The language comes from ``resolve_text_normalization_language`` applied
         to the main text and is also used for the prompt text. For "zh",
         hyphens are rewritten first.
      3. ``robust_post`` (or ``robust`` when WeText is disabled):
         ``normalize_tts_text`` after WeText, when ``enable_normalize_tts_text``
         is set.

    Returns:
        A dict with the final ``text``/``prompt_text`` (duplicated under
        ``normalized_text``/``normalized_prompt_text``), the ``+``-joined stage
        list or "none", the language used ("" when WeText is disabled), and the
        effective flags.

    Raises:
        RuntimeError: ``enable_wetext`` is set but ``text_normalizer_manager`` is
            None. Errors raised by ``text_normalizer_manager.normalize`` are not
            caught and reach the caller unchanged.
    """
    raw_text = str(text or "")
    raw_prompt_text = str(prompt_text or "")
    normalization_stages: list[str] = []
    normalization_language = ""
    intermediate_text = raw_text
    intermediate_prompt_text = raw_prompt_text

    # Stage 1: robust pre-pass (only when WeText will also run).
    if enable_normalize_tts_text and enable_wetext:
        pre_robust_text = normalize_tts_text(raw_text)
        pre_robust_prompt_text = normalize_tts_text(raw_prompt_text) if raw_prompt_text else ""
        if pre_robust_text != raw_text:
            logging.info(
                "normalized text chars_before=%d chars_after=%d stage=robust_pre",
                len(raw_text),
                len(pre_robust_text),
            )
        if raw_prompt_text and pre_robust_prompt_text != raw_prompt_text:
            logging.info(
                "normalized prompt_text chars_before=%d chars_after=%d stage=robust_pre",
                len(raw_prompt_text),
                len(pre_robust_prompt_text),
            )
        intermediate_text = pre_robust_text
        intermediate_prompt_text = pre_robust_prompt_text
        normalization_stages.append("robust_pre")

    # Stage 2: WeText normalization.
    if enable_wetext:
        if text_normalizer_manager is None:
            raise RuntimeError("WeTextProcessing manager is unavailable.")
        wetext_input_text = intermediate_text
        wetext_input_prompt_text = intermediate_prompt_text
        normalization_language = resolve_text_normalization_language(text=wetext_input_text, voice=voice)
        if normalization_language == "zh":
            # Stop zh WeText from verbalizing stray hyphens as "minus".
            rewritten_wetext_input_text = _rewrite_hyphens_before_zh_wetext(wetext_input_text)
            rewritten_wetext_input_prompt_text = _rewrite_hyphens_before_zh_wetext(wetext_input_prompt_text)
            if rewritten_wetext_input_text != wetext_input_text:
                logging.info(
                    "rewrote zh wetext text hyphens chars_before=%d chars_after=%d stage=zh_wetext_hyphen_guard",
                    len(wetext_input_text),
                    len(rewritten_wetext_input_text),
                )
            if wetext_input_prompt_text and rewritten_wetext_input_prompt_text != wetext_input_prompt_text:
                logging.info(
                    "rewrote zh wetext prompt_text hyphens chars_before=%d chars_after=%d stage=zh_wetext_hyphen_guard",
                    len(wetext_input_prompt_text),
                    len(rewritten_wetext_input_prompt_text),
                )
            wetext_input_text = rewritten_wetext_input_text
            wetext_input_prompt_text = rewritten_wetext_input_prompt_text
        intermediate_text, intermediate_prompt_text = text_normalizer_manager.normalize(
            text=wetext_input_text,
            prompt_text=wetext_input_prompt_text,
            language=normalization_language,
        )
        if intermediate_text != wetext_input_text:
            logging.info(
                "normalized text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(wetext_input_text),
                len(intermediate_text),
                normalization_language,
            )
        if wetext_input_prompt_text and intermediate_prompt_text != wetext_input_prompt_text:
            logging.info(
                "normalized prompt_text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(wetext_input_prompt_text),
                len(intermediate_prompt_text),
                normalization_language,
            )
        normalization_stages.append(f"wetext:{normalization_language}" if normalization_language else "wetext")

    # Stage 3: robust post-pass (or the only pass when WeText is disabled).
    final_text = intermediate_text
    final_prompt_text = intermediate_prompt_text
    if enable_normalize_tts_text:
        final_text = normalize_tts_text(intermediate_text)
        final_prompt_text = normalize_tts_text(intermediate_prompt_text) if intermediate_prompt_text else ""
        robust_stage_name = "robust_post" if enable_wetext else "robust"
        if final_text != intermediate_text:
            logging.info(
                "normalized text chars_before=%d chars_after=%d stage=%s",
                len(intermediate_text),
                len(final_text),
                robust_stage_name,
            )
        if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
            logging.info(
                "normalized prompt_text chars_before=%d chars_after=%d stage=%s",
                len(intermediate_prompt_text),
                len(final_prompt_text),
                robust_stage_name,
            )
        normalization_stages.append(robust_stage_name)

    return {
        "text": final_text,
        "prompt_text": final_prompt_text,
        "normalized_text": final_text,
        "normalized_prompt_text": final_prompt_text,
        "normalization_method": "+".join(normalization_stages) if normalization_stages else "none",
        "text_normalization_language": normalization_language,
        "text_normalization_enabled": bool(enable_wetext or enable_normalize_tts_text),
        "wetext_processing_enabled": bool(enable_wetext),
        "normalize_tts_text_enabled": bool(enable_normalize_tts_text),
    }