import re
import sys
import yaml
import numpy as np
import librosa
import torch
import phonemizer
import noisereduce as nr
from munch import Munch

from meldataset import TextCleaner
from models import ProsodyPredictor, TextEncoder, StyleEncoder
from Modules.hifigan import Decoder

# -------------------------
# Windows-only espeak-ng loader
# -------------------------
if sys.platform.startswith("win"):
    try:
        from phonemizer.backend.espeak.wrapper import EspeakWrapper
        import espeakng_loader

        EspeakWrapper.set_library(espeakng_loader.get_library_path())
    except Exception as e:
        print(e)

_TOKEN_RE = re.compile(r"\S+")


def normalize_phonem_tokens(phonem: str) -> str:
    """Collapse all whitespace runs in a phoneme string to single spaces."""
    return " ".join(_TOKEN_RE.findall((phonem or "").strip()))

def espeak_phn(text: str, lang: str) -> str:
    """
    Phonemize `text` with espeak. If phonemizer/espeak fails, raise immediately
    so you know right away that espeak-ng / libespeak-ng1 / the requested voice
    (e.g. 'vi') is missing.
    """
    try:
        backend = phonemizer.backend.EspeakBackend(
            language=lang,
            preserve_punctuation=True,
            with_stress=True,
            language_switch="remove-flags",
        )
        out = backend.phonemize([text])[0]
        out = (out or "").strip()
        if len(out) == 0:
            raise RuntimeError(f"phonemizer returned empty output for lang='{lang}', text='{text[:50]}'")
        return out
    except Exception as e:
        raise RuntimeError(f"espeak/phonemizer failed (lang={lang}). Error: {e}")

class Preprocess:
    def __text_normalize(self, text):
        # Map sentence-breaking punctuation to '.' and collapse whitespace.
        punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
        map_to = "."
        punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
        text = punctuation_pattern.sub(map_to, text)
        text = re.sub(r"\s+", " ", text).strip()
        return text
    def __merge_fragments(self, texts, n):
        # Greedily join consecutive fragments until each has at least n words.
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j
        # Fold a too-short trailing fragment back into the previous one.
        if len(merged) > 1 and len(merged[-1].split()) < n:
            merged[-2] = merged[-2] + ", " + merged[-1]
            del merged[-1]
        return merged
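    # Example with n=3:
    #   __merge_fragments(["one two", "three", "four five six"], 3)
    #   -> ["one two, three", "four five six"]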
    def wave_preprocess(self, wave, sr=24000):
        wave = np.asarray(wave, dtype=np.float32).squeeze()
        mel = librosa.feature.melspectrogram(
            y=wave,
            sr=sr,
            n_fft=2048,
            win_length=1200,
            hop_length=300,
            n_mels=80,
            power=2.0,
        )  # (80, T)
        # Log-compress and normalize with the same constants used at training time.
        mean, std = -4, 4
        mel = np.log(1e-5 + mel)
        mel = (mel - mean) / std
        return torch.from_numpy(mel).float().unsqueeze(0)  # (1, 80, T)
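    # Shape check: 1 s of 24 kHz audio with hop_length=300 yields a mel of
    # shape (1, 80, 81) (librosa pads by default, hence 24000 // 300 + 1 frames).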
    def text_preprocess(self, text, n_merge=12):
        text_norm = self.__text_normalize(text).split(".")
        text_norm = [s.strip() for s in text_norm if s.strip()]
        return self.__merge_fragments(text_norm, n=n_merge)

    def length_to_mask(self, lengths):
        # Boolean mask: True at padded positions beyond each sequence's length.
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        return torch.gt(mask + 1, lengths.unsqueeze(1))
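    # Illustration of length_to_mask (True marks padded positions):
    #   Preprocess().length_to_mask(torch.LongTensor([2, 4]))
    #   -> tensor([[False, False,  True,  True],
    #              [False, False, False, False]])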

class StyleTTS2(torch.nn.Module):
    def __init__(self, config_path, models_path):
        super().__init__()
        # Dummy buffer whose .device tracks wherever the module is moved.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()

        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        symbols = (
            list(config["symbol"]["pad"])
            + list(config["symbol"]["punctuation"])
            + list(config["symbol"]["letters"])
            + list(config["symbol"]["letters_ipa"])
            + list(config["symbol"]["extend"])
        )
        symbol_dict = {s: i for i, s in enumerate(symbols)}
        n_token = len(symbol_dict) + 1
        print("\nFound:", n_token, "symbols")

        args = self.__recursive_munch(config["model_params"])
        args["n_token"] = n_token
        self.cleaner = TextCleaner(symbol_dict, debug=True)

        self.decoder = Decoder(
            dim_in=args.hidden_dim,
            style_dim=args.style_dim,
            dim_out=args.n_mels,
            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
            upsample_rates=args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
        )
        self.predictor = ProsodyPredictor(
            style_dim=args.style_dim,
            d_hid=args.hidden_dim,
            nlayers=args.n_layer,
            max_dur=args.max_dur,
            dropout=args.dropout,
        )
        self.text_encoder = TextEncoder(
            channels=args.hidden_dim,
            kernel_size=5,
            depth=args.n_layer,
            n_symbols=args.n_token,
        )
        self.style_encoder = StyleEncoder(
            dim_in=args.dim_in,
            style_dim=args.style_dim,
            max_conv_dim=args.hidden_dim,
        )

        n_speakers = config["data_params"]["n_speakers"]
        self.spk_emb = torch.nn.Embedding(n_speakers, args.style_dim)
        self.spk_ln = torch.nn.LayerNorm(args.style_dim)

        self.__load_models(models_path)
    @staticmethod
    def text_to_sequence_char_level(text, symbol_dict):
        # Map each character to its symbol id, skipping spaces and unknown chars.
        seq = []
        for ch in text:
            if ch == " ":
                continue
            if ch in symbol_dict:
                seq.append(symbol_dict[ch])
            else:
                print("[WARN] dropped char:", repr(ch))
        return seq
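    # Example: text_to_sequence_char_level("ab c", {"a": 1, "b": 2, "c": 3})
    # returns [1, 2, 3]; spaces are skipped and unknown chars are dropped
    # with a warning.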
    def __recursive_munch(self, d):
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        return d

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        # Pull values with |z-score| > threshold back toward the mean.
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / (std + 1e-8)
        outlier_mask = torch.abs(z) > threshold
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)
        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result
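    # Note: with threshold=3.0 an element 4 sigma above the mean is pulled
    # back to mean + 3 * std * 0.95; elements within 3 sigma are untouched.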

    def __load_models(self, models_path):
        model = {
            "decoder": self.decoder,
            "predictor": self.predictor,
            "text_encoder": self.text_encoder,
            "style_encoder": self.style_encoder,
            "spk_emb": self.spk_emb,
            "spk_ln": self.spk_ln,
        }
        params_whole = torch.load(models_path, map_location="cpu")
        params = params_whole["net"]
        params = {k: v for k, v in params.items() if k in model}
        for k in model:
            if k not in params:
                print(f"[WARN] no weights found for '{k}', skipping")
                continue
            try:
                model[k].load_state_dict(params[k])
            except Exception:
                # Checkpoints saved from DataParallel prefix keys with "module.".
                from collections import OrderedDict

                new_state_dict = OrderedDict()
                for kk, vv in params[k].items():
                    new_state_dict[kk[7:]] = vv  # strip "module."
                model[k].load_state_dict(new_state_dict, strict=False)
            print(k, ":", sum(p.numel() for p in model[k].parameters()))

    def __compute_style(self, path, denoise, split_dur):
        device = self.get_device.device
        denoise = min(float(denoise), 1.0)
        split_dur = int(split_dur) if split_dur else 0

        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(
                y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300
            )
            # Blend original and denoised signals by the denoise strength.
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:
                # Average the style embedding over fixed-length windows.
                jump = sr * split_dur
                total_len = len(audio)
                ref_s = None
                count = 0
                for i in range(0, total_len, jump):
                    seg = audio[i : min(i + jump, total_len)]
                    if len(seg) < sr:  # skip segments shorter than 1 s
                        continue
                    mel = self.preprocess.wave_preprocess(seg).to(device)
                    s = self.style_encoder(mel.unsqueeze(1))
                    ref_s = s if ref_s is None else (ref_s + s)
                    count += 1
                if ref_s is None:
                    mel = self.preprocess.wave_preprocess(audio).to(device)
                    ref_s = self.style_encoder(mel.unsqueeze(1))
                else:
                    ref_s = ref_s / count
            else:
                mel = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel.unsqueeze(1))
        return ref_s

    def __inference(self, phonem, ref_s, speed=1.0, prev_d_mean=0.0, t=0.1):
        device = self.get_device.device

        tokens = self.cleaner(phonem)
        print("[DBG] token_len =", len(tokens))
        print("[DBG] phonem_head =", phonem[:80])
        if len(tokens) == 0:
            raise RuntimeError("Token sequence is empty!")
        tokens = [0] + tokens + [0]
        tokens = torch.LongTensor(tokens).unsqueeze(0).to(device)

        print("\n========== TOKEN DEBUG ==========")
        print("Max token id:", tokens.max().item())
        print("Min token id:", tokens.min().item())
        print("First 50 tokens:", tokens[0][:50].tolist())
        print("=================================\n")

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)
            t_en = self.text_encoder(tokens, input_lengths, text_mask)

            s = ref_s.to(device)
            if hasattr(self, "spk_emb") and hasattr(self, "current_speaker_id"):
                spk_id = torch.LongTensor([self.current_speaker_id]).to(device)
                spk_vec = self.spk_emb(spk_id)  # [1, style_dim]
                spk_vec = self.spk_ln(spk_vec)
                # Merge speaker embedding into style.
                s = s + spk_vec
                print("🎤 Speaker embedding injected:", self.current_speaker_id)
                print("🎤 Style vec mean/std:", s.mean().item(), s.std().item())

            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(dim=-1)

            # Smooth durations toward the previous chunk's mean to stabilize pacing.
            if prev_d_mean != 0:
                dur_stats = torch.empty_like(duration).normal_(mean=prev_d_mean, std=duration.std() + 1e-8).to(device)
            else:
                dur_stats = torch.empty_like(duration).normal_(mean=duration.mean(), std=duration.std() + 1e-8).to(device)
            duration = duration * (1 - t) + dur_stats * t
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])
            duration = duration / speed

            pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)

            # Build a hard monotonic alignment: token i occupies pred_dur[i] frames.
            L = int(input_lengths.item())
            T = int(pred_dur.sum().item())
            pred_aln_trg = torch.zeros((L, T), device=device)
            c = 0
            for i in range(L):
                di = int(pred_dur[i].item())
                pred_aln_trg[i, c : c + di] = 1
                c += di
            alignment = pred_aln_trg.unsqueeze(0)

            en = d.transpose(-1, -2) @ alignment
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = t_en @ alignment
            out = self.decoder(asr, F0_pred, N_pred, s)
        return out.squeeze().cpu().numpy(), float(duration.mean().item())
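    # Illustration of the hard alignment built in __inference: with
    # pred_dur = [2, 3] the (L x T) matrix is
    #   [[1, 1, 0, 0, 0],
    #    [0, 0, 1, 1, 1]]
    # i.e. token i owns pred_dur[i] consecutive output frames.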

    def get_styles(self, speakers, denoise=0.3, avg_style=True):
        split_dur = 2 if avg_style else 0
        styles = {}
        for sid, meta in speakers.items():
            ref_s = self.__compute_style(meta["path"], denoise=denoise, split_dur=split_dur)
            styles[sid] = {
                "style": ref_s,
                "path": meta["path"],
                "lang": meta["lang"],
                "speed": meta["speed"],
            }
        return styles
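    # Expected layout of the `speakers` argument to get_styles() (the keys
    # must match the [id_x] tags consumed by generate(); the paths and
    # languages here are illustrative):
    #   {"id_1": {"path": "ref1.wav", "lang": "vi", "speed": 1.0}}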

    def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker="[id_1]"):
        smooth_value = 0.2 if stabilize else 0.0
        list_wav = []
        prev_d_mean = 0.0

        # Replace control whitespace with spaces so words on adjacent lines
        # do not get glued together.
        text = re.sub(r"[\n\r\t\f\v]+", " ", text)

        # Split by speaker tags, keeping the tags as their own list items.
        parts = re.split(r"(\[id_\d+\])", text)
        if len(parts) <= 1 or re.match(r"(\[id_\d+\])", parts[0]) is None:
            parts.insert(0, default_speaker)

        speaker_tag = None  # e.g. "id_1"
        speaker_id = None  # int
        current_ref_s = None
        speed = 1.0
        for p in parts:
            # -----------------------------
            # Parse speaker tag
            # -----------------------------
            if re.match(r"(\[id_\d+\])", p):
                speaker_tag = p.strip("[]")  # "id_1"
                speaker_id = int(speaker_tag.replace("id_", ""))
                # Expose speaker id for inference.
                self.current_speaker_id = speaker_id
                if speaker_tag not in styles:
                    raise RuntimeError(f"Speaker {speaker_tag} not found in styles!")
                current_ref_s = styles[speaker_tag]["style"]
                speed = styles[speaker_tag]["speed"]
                print("🎤 Active speaker:", speaker_tag, "->", speaker_id)
                continue

            if not p.strip():
                continue
            if speaker_id is None:
                raise RuntimeError("Speaker ID has not been set! Check the [id_x] tags.")

            # -----------------------------
            # Text → phoneme → waveform
            # -----------------------------
            for sentence in self.preprocess.text_preprocess(p, n_merge=n_merge):
                phonem = espeak_phn(sentence, styles[speaker_tag]["lang"])
                wav, prev_d_mean = self.__inference(
                    phonem,
                    current_ref_s,
                    speed=speed,
                    prev_d_mean=prev_d_mean,
                    t=smooth_value,
                )

                # -----------------------------
                # Debug
                # -----------------------------
                print("[DBG] wav shape:", wav.shape)
                print("[DBG] wav min/max:", wav.min(), wav.max())
                print("[DBG] wav mean abs:", np.abs(wav).mean())

                # -----------------------------
                # Safe trim
                # -----------------------------
                trim = int(0.05 * 24000)  # 50 ms
                if wav.shape[0] > 4 * trim:
                    wav = wav[trim:-trim]
                if wav.size > 0:
                    list_wav.append(wav)

        # -----------------------------
        # Merge all chunks
        # -----------------------------
        if len(list_wav) == 0:
            print("⚠️ No audio generated → returning silence")
            return np.zeros((2400,), dtype=np.float32)

        final_wav = np.concatenate(list_wav)
        # Pad head & tail with 50 ms of silence.
        pad = int(0.05 * 24000)
        final_wav = np.concatenate(
            [np.zeros((pad,), dtype=np.float32), final_wav, np.zeros((pad,), dtype=np.float32)]
        )
        return final_wav
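

# -------------------------
# Example usage (a minimal sketch: the config/checkpoint paths, the
# reference wav, and the output filename below are assumptions for
# illustration, not files shipped with this code)
# -------------------------
if __name__ == "__main__":
    import soundfile as sf  # assumed available for writing the result

    tts = StyleTTS2(config_path="config.yaml", models_path="model.pth")  # hypothetical paths
    speakers = {
        "id_1": {"path": "ref_id_1.wav", "lang": "vi", "speed": 1.0},  # hypothetical reference clip
    }
    styles = tts.get_styles(speakers, denoise=0.3, avg_style=True)
    wav = tts.generate("[id_1] Xin chào!", styles, stabilize=True)
    sf.write("output.wav", wav, 24000)  # the pipeline runs at 24 kHz throughout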