import functools
import re
import sys
from collections import OrderedDict

import yaml
import numpy as np
import librosa
import torch
import phonemizer
import noisereduce as nr
from munch import Munch

from meldataset import TextCleaner
from models import ProsodyPredictor, TextEncoder, StyleEncoder
from Modules.hifigan import Decoder

# -------------------------
# Windows-only espeak-ng loader
# -------------------------
if sys.platform.startswith("win"):
    try:
        from phonemizer.backend.espeak.wrapper import EspeakWrapper
        import espeakng_loader

        EspeakWrapper.set_library(espeakng_loader.get_library_path())
    except Exception as e:
        # Best effort only: phonemizer may still locate a system-wide espeak-ng.
        print(f"[WARN] espeak-ng library setup failed: {e}")


# Sample rate (Hz) used for loading reference audio, mel extraction and output padding.
SAMPLE_RATE = 24000

_TOKEN_RE = re.compile(r"\S+")
# Speaker tags such as "[id_3]"; the capture group keeps tags in re.split output.
_SPEAKER_TAG_RE = re.compile(r"(\[id_\d+\])")
# Punctuation treated as a sentence break; every mark is rewritten to ".".
_SENTENCE_BREAK_PUNCTUATION = (",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?")
_SENTENCE_BREAK_RE = re.compile(
    "[" + "".join(re.escape(p) for p in _SENTENCE_BREAK_PUNCTUATION) + "]"
)


def normalize_phonem_tokens(phonem: str) -> str:
    """Collapse every whitespace run in ``phonem`` to one space (``None`` -> "")."""
    return " ".join(_TOKEN_RE.findall(phonem or ""))


@functools.lru_cache(maxsize=16)
def _get_espeak_backend(lang: str):
    """Return a reusable espeak backend for ``lang``.

    Creating an ``EspeakBackend`` is comparatively expensive, so one instance
    is kept per language instead of building a new one for every sentence.
    Construction errors propagate and are not cached.
    """
    return phonemizer.backend.EspeakBackend(
        language=lang,
        preserve_punctuation=True,
        with_stress=True,
        language_switch="remove-flags",
    )


def espeak_phn(text: str, lang: str) -> str:
    """Phonemize ``text`` with espeak for language ``lang``.

    Errors are raised rather than hidden, so that a missing espeak-ng install,
    missing libespeak-ng1 or a missing voice (e.g. 'vi') is noticed immediately.

    Raises:
        RuntimeError: if the backend cannot be created, phonemization fails,
            or the phonemizer returns an empty string.
    """
    try:
        out = _get_espeak_backend(lang).phonemize([text])[0]
    except Exception as e:
        raise RuntimeError(f"espeak/phonemizer failed (lang={lang}). Error: {e}") from e

    out = (out or "").strip()
    if not out:
        raise RuntimeError(
            f"phonemizer returned empty output for lang='{lang}', text='{text[:50]}'"
        )
    return out


class Preprocess:
    """Text and audio preprocessing helpers used by :class:`StyleTTS2`."""

    def __text_normalize(self, text):
        """Rewrite sentence-break punctuation to "." and collapse whitespace."""
        text = _SENTENCE_BREAK_RE.sub(".", text)
        return re.sub(r"\s+", " ", text).strip()

    def __merge_fragments(self, texts, n):
        """Greedily join consecutive fragments with ", " until each has >= ``n`` words.

        A trailing chunk that is still shorter than ``n`` words is appended to
        the previous chunk.
        """
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j
        if len(merged) > 1 and len(merged[-1].split()) < n:
            merged[-2] = merged[-2] + ", " + merged[-1]
            del merged[-1]
        return merged

    def wave_preprocess(self, wave, sr=SAMPLE_RATE):
        """Convert a waveform into a normalized log-mel spectrogram tensor.

        Returns:
            Float tensor of shape (1, 80, T).
        """
        wave = np.asarray(wave, dtype=np.float32).squeeze()
        mel = librosa.feature.melspectrogram(
            y=wave,
            sr=sr,
            n_fft=2048,
            win_length=1200,
            hop_length=300,
            n_mels=80,
            power=2.0,
        )  # (80, T)
        mean, std = -4, 4
        mel = np.log(1e-5 + mel)
        mel = (mel - mean) / std
        return torch.from_numpy(mel).float().unsqueeze(0)  # (1, 80, T)

    def text_preprocess(self, text, n_merge=12):
        """Split ``text`` into sentence chunks of at least ``n_merge`` words where possible."""
        sentences = [s.strip() for s in self.__text_normalize(text).split(".")]
        sentences = [s for s in sentences if s]
        return self.__merge_fragments(sentences, n=n_merge)

    def length_to_mask(self, lengths):
        """Return a bool mask (batch, max_len) that is True at padding positions (index >= length)."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        return torch.gt(mask + 1, lengths.unsqueeze(1))


class StyleTTS2(torch.nn.Module):
    """StyleTTS2 inference wrapper with a per-speaker embedding added to the style vector."""

    def __init__(self, config_path, models_path):
        super().__init__()
        # Empty buffer whose ``.device`` follows the module across .to()/.cuda().
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()

        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        sym_cfg = config["symbol"]
        symbols = (
            list(sym_cfg["pad"])
            + list(sym_cfg["punctuation"])
            + list(sym_cfg["letters"])
            + list(sym_cfg["letters_ipa"])
            + list(sym_cfg["extend"])
        )
        symbol_dict = {s: i for i, s in enumerate(symbols)}
        n_token = len(symbol_dict) + 1
        print("\nFound:", n_token, "symbols")

        args = self.__recursive_munch(config["model_params"])
        args["n_token"] = n_token

        self.cleaner = TextCleaner(symbol_dict, debug=True)

        self.decoder = Decoder(
            dim_in=args.hidden_dim,
            style_dim=args.style_dim,
            dim_out=args.n_mels,
            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
            upsample_rates=args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
        )
        self.predictor = ProsodyPredictor(
            style_dim=args.style_dim,
            d_hid=args.hidden_dim,
            nlayers=args.n_layer,
            max_dur=args.max_dur,
            dropout=args.dropout,
        )
        self.text_encoder = TextEncoder(
            channels=args.hidden_dim,
            kernel_size=5,
            depth=args.n_layer,
            n_symbols=args.n_token,
        )
        self.style_encoder = StyleEncoder(
            dim_in=args.dim_in,
            style_dim=args.style_dim,
            max_conv_dim=args.hidden_dim,
        )

        n_speakers = config["data_params"]["n_speakers"]
        self.spk_emb = torch.nn.Embedding(n_speakers, args.style_dim)
        self.spk_ln = torch.nn.LayerNorm(args.style_dim)

        self.__load_models(models_path)

    @staticmethod
    def text_to_sequence_char_level(text, symbol_dict):
        """Map each non-space character of ``text`` to its id in ``symbol_dict``.

        Characters missing from ``symbol_dict`` are dropped with a warning.
        """
        seq = []
        for ch in text:
            if ch == " ":
                continue
            if ch in symbol_dict:
                seq.append(symbol_dict[ch])
            else:
                print("[WARN] dropped char:", repr(ch))
        return seq

    def __recursive_munch(self, d):
        """Recursively convert nested dicts (also inside lists) to ``Munch`` objects."""
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        return d

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        """Clamp values whose |z-score| exceeds ``threshold`` to mean ± threshold*std*factor."""
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / (std + 1e-8)
        outlier_mask = torch.abs(z) > threshold
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)
        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        """Load sub-module weights from the checkpoint's ``"net"`` entry.

        Each sub-module is first loaded strictly. If that fails, for example
        when the checkpoint was saved from a wrapper that prefixed keys with
        ``"module."``, the prefix is stripped where present. The load is then
        retried non-strictly, and any keys that still do not match are reported.

        Raises:
            KeyError: if the checkpoint has no weights for one of the sub-modules.
        """
        model = {
            "decoder": self.decoder,
            "predictor": self.predictor,
            "text_encoder": self.text_encoder,
            "style_encoder": self.style_encoder,
            "spk_emb": self.spk_emb,
            "spk_ln": self.spk_ln,
        }
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        params = torch.load(models_path, map_location="cpu")["net"]
        prefix = "module."

        for name, module in model.items():
            if name not in params:
                raise KeyError(f"Checkpoint '{models_path}' has no weights for '{name}'")
            state = params[name]
            try:
                module.load_state_dict(state)
            except RuntimeError:
                stripped = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state.items()
                )
                result = module.load_state_dict(stripped, strict=False)
                if result.missing_keys or result.unexpected_keys:
                    print(
                        f"[WARN] {name}: missing keys={result.missing_keys}, "
                        f"unexpected keys={result.unexpected_keys}"
                    )
            print(name, ":", sum(p.numel() for p in module.parameters()))

    def __compute_style(self, path, denoise, split_dur):
        """Compute a reference style vector from the audio file at ``path``.

        Args:
            path: Reference audio file. It is loaded at SAMPLE_RATE and
                silence-trimmed with top_db=30.
            denoise: Blend weight (capped at 1.0) of the noise-reduced signal.
                Values <= 0 disable denoising.
            split_dur: Segment length in seconds. When > 0 and the clip lasts
                at least 4 s, the style is averaged over segments of this
                length, skipping segments shorter than 1 s. Otherwise the
                whole clip is encoded at once.

        Returns:
            The style encoder output for the reference audio.
        """
        device = self.get_device.device
        denoise = min(float(denoise), 1.0)
        split_dur = int(split_dur) if split_dur else 0

        wave, sr = librosa.load(path, sr=SAMPLE_RATE)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(
                y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300
            )
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            ref_s = None
            if split_dur > 0 and len(audio) / sr >= 4:
                jump = sr * split_dur
                style_sum = None
                count = 0
                for start in range(0, len(audio), jump):
                    seg = audio[start:start + jump]
                    if len(seg) < sr:  # skip segments shorter than 1 s
                        continue
                    mel = self.preprocess.wave_preprocess(seg).to(device)
                    s = self.style_encoder(mel.unsqueeze(1))
                    style_sum = s if style_sum is None else style_sum + s
                    count += 1
                if style_sum is not None:
                    ref_s = style_sum / count
            if ref_s is None:
                # No usable segments (or no splitting requested): encode the whole clip.
                mel = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel.unsqueeze(1))
        return ref_s

    def __inference(self, phonem, ref_s, speed=1.0, prev_d_mean=0.0, t=0.1):
        """Synthesize one phoneme chunk with style ``ref_s``.

        Uses ``self.current_speaker_id`` (set by :meth:`generate`) to select
        the speaker embedding that is added to the style vector.

        Args:
            phonem: Phoneme string fed to ``self.cleaner``.
            ref_s: Style vector from :meth:`get_styles`.
            speed: Predicted durations are divided by this factor.
            prev_d_mean: Mean duration of the previous chunk. When non-zero,
                it centres the duration noise; otherwise this chunk's own mean
                is used.
            t: Weight of the Gaussian duration noise blended into the prediction.

        Returns:
            Tuple of (waveform as numpy array, mean predicted duration after
            smoothing and speed scaling).

        Raises:
            RuntimeError: if the cleaner produces no tokens.
        """
        device = self.get_device.device

        tokens = self.cleaner(phonem)
        print("[DBG] token_len =", len(tokens))
        print("[DBG] phonem_head =", phonem[:80])
        if len(tokens) == 0:
            raise RuntimeError("Token sequence is empty!")

        # Surround the sequence with token 0 on both sides.
        tokens = torch.LongTensor([0] + tokens + [0]).unsqueeze(0).to(device)
        print("\n========== TOKEN DEBUG ==========")
        print("Max token id:", tokens.max().item())
        print("Min token id:", tokens.min().item())
        print("First 50 tokens:", tokens[0][:50].tolist())
        print("=================================\n")

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)
            t_en = self.text_encoder(tokens, input_lengths, text_mask)

            s = ref_s.to(device)
            if hasattr(self, "spk_emb"):
                spk_id = torch.LongTensor([self.current_speaker_id]).to(device)
                spk_vec = self.spk_ln(self.spk_emb(spk_id))  # [1, style_dim]
                # Merge the speaker embedding into the style vector.
                s = s + spk_vec
                print("🎤 Speaker embedding injected:", self.current_speaker_id)
                print("🎤 Style vec mean/std:", s.mean().item(), s.std().item())

            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(dim=-1)

            # Blend predictions with Gaussian noise centred on the previous chunk's
            # mean (or this chunk's mean for the first chunk) to keep pacing steady.
            center = prev_d_mean if prev_d_mean != 0 else duration.mean()
            dur_stats = torch.empty_like(duration).normal_(mean=center, std=duration.std() + 1e-8)
            duration = duration * (1 - t) + dur_stats * t
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])
            duration = duration / speed

            pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
            n_tokens = int(input_lengths.item())
            n_frames = int(pred_dur.sum().item())

            # Hard alignment: token i owns the next pred_dur[i] frames (one-hot per frame).
            frame_owner = torch.repeat_interleave(
                torch.arange(n_tokens, device=device), pred_dur.long()
            )
            pred_aln_trg = torch.zeros((n_tokens, n_frames), device=device)
            pred_aln_trg[frame_owner, torch.arange(n_frames, device=device)] = 1
            alignment = pred_aln_trg.unsqueeze(0)

            en = d.transpose(-1, -2) @ alignment
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = t_en @ alignment
            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), float(duration.mean().item())

    def get_styles(self, speakers, denoise=0.3, avg_style=True):
        """Compute a style vector for each reference speaker.

        Args:
            speakers: Mapping ``speaker_id -> {"path", "lang", "speed"}``.
            denoise: Noise-reduction blend weight passed to style extraction.
            avg_style: If True, average the style over 2-second segments.

        Returns:
            Mapping ``speaker_id -> {"style", "path", "lang", "speed"}``.
        """
        split_dur = 2 if avg_style else 0
        styles = {}
        for sid, meta in speakers.items():
            ref_s = self.__compute_style(meta["path"], denoise=denoise, split_dur=split_dur)
            styles[sid] = {
                "style": ref_s,
                "path": meta["path"],
                "lang": meta["lang"],
                "speed": meta["speed"],
            }
        return styles

    def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker="[id_1]"):
        """Synthesize ``text``, switching voices at ``[id_N]`` speaker tags.

        Text before the first tag (or all text if there is no tag) is spoken
        by ``default_speaker``. Each speaker's text is split into sentence
        chunks, phonemized with that speaker's ``lang`` and synthesized with
        its style and ``speed``.

        Args:
            text: Input text, optionally containing ``[id_N]`` tags.
            styles: Mapping ``"id_N" -> entry`` as returned by :meth:`get_styles`.
            stabilize: If True, blend durations with noise around the previous
                chunk's mean (weight 0.2) for a steadier speaking rate.
            n_merge: Minimum number of words per synthesized chunk.
            default_speaker: Tag used for text that precedes any speaker tag.

        Returns:
            float32 numpy waveform at SAMPLE_RATE with 50 ms of silence on each
            end, or 2400 zero samples if nothing was synthesized.

        Raises:
            RuntimeError: if a tagged speaker is missing from ``styles`` or its
                id is outside the model's speaker embedding table.
        """
        smooth_value = 0.2 if stabilize else 0.0
        list_wav = []
        prev_d_mean = 0.0
        # Replace (not delete) control whitespace so words on separate lines stay apart.
        text = re.sub(r"[\n\r\t\f\v]", " ", text)

        # Split by speaker tags; tags are kept as separate parts.
        parts = _SPEAKER_TAG_RE.split(text)
        # re.split yields a leading "" when text starts with a tag. Drop
        # whitespace-only leading text so it does not force the default speaker.
        if parts and not parts[0].strip():
            parts = parts[1:]
        if not parts or _SPEAKER_TAG_RE.match(parts[0]) is None:
            parts.insert(0, default_speaker)

        speaker_tag = None  # e.g. "id_1"
        speaker_id = None  # e.g. 1
        current_ref_s = None
        speed = 1.0
        trim = int(0.05 * SAMPLE_RATE)  # 50 ms trimmed from each chunk edge
        spk_emb = getattr(self, "spk_emb", None)

        for p in parts:
            # -----------------------------
            # Parse speaker tag
            # -----------------------------
            if _SPEAKER_TAG_RE.match(p):
                speaker_tag = p.strip("[]")
                speaker_id = int(speaker_tag.replace("id_", ""))
                if speaker_tag not in styles:
                    raise RuntimeError(f"Speaker {speaker_tag} not found in styles!")
                if spk_emb is not None and speaker_id >= spk_emb.num_embeddings:
                    raise RuntimeError(
                        f"Speaker {speaker_tag} exceeds the {spk_emb.num_embeddings} "
                        f"speakers supported by the model!"
                    )
                # Exposed for __inference, which reads it to pick the speaker embedding.
                self.current_speaker_id = speaker_id
                current_ref_s = styles[speaker_tag]["style"]
                speed = styles[speaker_tag]["speed"]
                print("🎤 Active speaker:", speaker_tag, "->", speaker_id)
                continue

            if not p.strip():
                continue
            if speaker_id is None:
                # Defensive: "Speaker ID has not been set! Check the [id_x] tag."
                raise RuntimeError("Speaker ID chưa được set! Kiểm tra tag [id_x].")

            # -----------------------------
            # Text -> phoneme -> waveform
            # -----------------------------
            lang = styles[speaker_tag]["lang"]
            for sentence in self.preprocess.text_preprocess(p, n_merge=n_merge):
                phonem = espeak_phn(sentence, lang)
                wav, prev_d_mean = self.__inference(
                    phonem,
                    current_ref_s,
                    speed=speed,
                    prev_d_mean=prev_d_mean,
                    t=smooth_value,
                )
                if wav.size == 0:
                    continue

                print("[DBG] wav shape:", wav.shape)
                print("[DBG] wav min/max:", wav.min().item(), wav.max().item())
                print("[DBG] wav mean abs:", np.abs(wav).mean())

                # Only trim chunks that are long enough to survive it.
                if wav.shape[0] > 4 * trim:
                    wav = wav[trim:-trim]
                list_wav.append(wav)

        # -----------------------------
        # Merge all chunks
        # -----------------------------
        if not list_wav:
            print("⚠️ No audio generated → return silence")
            return np.zeros((2400,), dtype=np.float32)

        # Pad head and tail with 50 ms of silence.
        pad = np.zeros((int(0.05 * SAMPLE_RATE),), dtype=np.float32)
        return np.concatenate([pad, *list_wav, pad])