# ttsStyleTTS2 / inference.py
# Last change by stephenhoang: "Update inference.py" (commit 33bb00d, verified)
import functools
import re
import sys

import librosa
import noisereduce as nr
import numpy as np
import phonemizer
import torch
import yaml
from munch import Munch

from meldataset import TextCleaner
from models import ProsodyPredictor, StyleEncoder, TextEncoder
from Modules.hifigan import Decoder
# -------------------------
# Windows-only espeak-ng loader
# -------------------------
# On Windows, point phonemizer's espeak wrapper at the library path provided by
# the ``espeakng_loader`` package. This is best-effort: a failure is reported
# (with context, on stderr) but never blocks importing this module; espeak_phn()
# raises later if espeak is actually unusable.
if sys.platform.startswith("win"):
    try:
        from phonemizer.backend.espeak.wrapper import EspeakWrapper
        import espeakng_loader

        EspeakWrapper.set_library(espeakng_loader.get_library_path())
    except Exception as e:  # deliberate best-effort: report, don't crash
        print(
            f"[WARN] could not configure the espeak-ng library on Windows: {e!r}",
            file=sys.stderr,
        )
_TOKEN_RE = re.compile(r"\S+")


def normalize_phonem_tokens(phonem: str) -> str:
    """Return ``phonem`` with every whitespace run collapsed to a single space.

    ``None`` or an empty string yields ``""``; leading and trailing whitespace
    is dropped.
    """
    if not phonem:
        return ""
    tokens = _TOKEN_RE.findall(phonem.strip())
    return " ".join(tokens)
@functools.lru_cache(maxsize=8)
def _get_espeak_backend(lang: str):
    """Return a cached espeak backend for ``lang``.

    Building an ``EspeakBackend`` is expensive, so one instance per language is
    reused across calls. A constructor failure is not cached, so the next call
    retries.
    """
    return phonemizer.backend.EspeakBackend(
        language=lang,
        preserve_punctuation=True,
        with_stress=True,
        language_switch="remove-flags",
    )


def espeak_phn(text: str, lang: str) -> str:
    """Phonemize ``text`` with espeak for language ``lang``.

    Fails loudly instead of degrading silently, so a missing espeak-ng /
    libespeak-ng1 installation or a missing voice (e.g. 'vi') is noticed
    immediately.

    Raises:
        RuntimeError: if phonemizer/espeak fails (original error chained as
            ``__cause__``) or returns an empty transcription.
    """
    try:
        backend = _get_espeak_backend(lang)
        out = backend.phonemize([text])[0]
    except Exception as e:
        raise RuntimeError(f"espeak/phonemizer failed (lang={lang}). Error: {e}") from e
    out = (out or "").strip()
    if not out:
        # Raised outside the try so it is not re-wrapped by the handler above.
        raise RuntimeError(f"phonemizer returned empty output for lang='{lang}', text='{text[:50]}'")
    return out
class Preprocess:
    """Text splitting and audio feature helpers used by ``StyleTTS2``."""

    def __text_normalize(self, text):
        """Turn sentence-level punctuation into "." and collapse whitespace."""
        # Each listed mark becomes a "." sentence break. Note "(" is listed
        # while ")" is not; that mirrors the original list.
        breakers = ",、،;(.。…!–:?"
        table = str.maketrans({mark: "." for mark in breakers})
        collapsed = re.sub(r"\s+", " ", text.translate(table))
        return collapsed.strip()

    def __merge_fragments(self, texts, n):
        """Greedily join fragments with ", " until each holds at least ``n`` words.

        A short trailing chunk is appended to the chunk before it.
        """
        merged = []
        pos = 0
        count = len(texts)
        while pos < count:
            piece = texts[pos]
            pos += 1
            while pos < count and len(piece.split()) < n:
                piece = f"{piece}, {texts[pos]}"
                pos += 1
            merged.append(piece)
        if len(merged) > 1 and len(merged[-1].split()) < n:
            tail = merged.pop()
            merged[-1] = f"{merged[-1]}, {tail}"
        return merged

    def wave_preprocess(self, wave, sr=24000):
        """Convert a waveform into a normalized log-mel tensor.

        Returns a float tensor shaped (1, 80, T): 80 mel bands with
        hop_length=300, log-compressed then normalized with mean -4 / std 4.
        """
        signal = np.asarray(wave, dtype=np.float32).squeeze()
        spec = librosa.feature.melspectrogram(
            y=signal,
            sr=sr,
            n_fft=2048,
            win_length=1200,
            hop_length=300,
            n_mels=80,
            power=2.0,
        )
        mean, std = -4, 4
        log_spec = np.log(1e-5 + spec)
        normed = (log_spec - mean) / std
        return torch.from_numpy(normed).float().unsqueeze(0)

    def text_preprocess(self, text, n_merge=12):
        """Split ``text`` into sentence chunks, merged to ~``n_merge`` words each."""
        sentences = []
        for raw in self.__text_normalize(text).split("."):
            cleaned = raw.strip()
            if cleaned:
                sentences.append(cleaned)
        return self.__merge_fragments(sentences, n=n_merge)

    def length_to_mask(self, lengths):
        """Return a bool padding mask: True where position index >= length.

        ``lengths`` is a 1-D integer tensor; the result is (batch, max_len).
        """
        batch = lengths.shape[0]
        positions = torch.arange(lengths.max()).unsqueeze(0).expand(batch, -1).type_as(lengths)
        return (positions + 1) > lengths.unsqueeze(1)
class StyleTTS2(torch.nn.Module):
    """StyleTTS2 inference model with an additive per-speaker embedding.

    Builds the decoder, prosody predictor, text encoder and style encoder from
    a YAML config, loads their weights (plus ``spk_emb`` / ``spk_ln``) from a
    checkpoint, and synthesizes speaker-tagged text via :meth:`generate`.
    """

    def __init__(self, config_path, models_path):
        """Build every sub-module from ``config_path`` and load ``models_path``.

        Args:
            config_path: YAML file with ``symbol``, ``model_params`` and
                ``data_params.n_speakers`` sections.
            models_path: checkpoint loadable by ``torch.load`` whose ``"net"``
                entry maps sub-module names to state dicts.
        """
        super().__init__()
        # Empty buffer whose ``.device`` follows ``.to()``; the methods below
        # use it to find the device the model currently lives on.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()
        # Active speaker index read by __inference; generate() sets it from
        # each ``[id_N]`` tag. None until a speaker has been selected.
        self.current_speaker_id = None

        # Context manager so the config file handle is closed promptly.
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        symbols = (
            list(config["symbol"]["pad"])
            + list(config["symbol"]["punctuation"])
            + list(config["symbol"]["letters"])
            + list(config["symbol"]["letters_ipa"])
            + list(config["symbol"]["extend"])
        )
        symbol_dict = {s: i for i, s in enumerate(symbols)}
        n_token = len(symbol_dict) + 1
        print("\nFound:", n_token, "symbols")

        args = self.__recursive_munch(config["model_params"])
        args["n_token"] = n_token
        self.cleaner = TextCleaner(symbol_dict, debug=True)

        self.decoder = Decoder(
            dim_in=args.hidden_dim,
            style_dim=args.style_dim,
            dim_out=args.n_mels,
            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
            upsample_rates=args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
        )
        self.predictor = ProsodyPredictor(
            style_dim=args.style_dim,
            d_hid=args.hidden_dim,
            nlayers=args.n_layer,
            max_dur=args.max_dur,
            dropout=args.dropout,
        )
        self.text_encoder = TextEncoder(
            channels=args.hidden_dim,
            kernel_size=5,
            depth=args.n_layer,
            n_symbols=args.n_token,
        )
        self.style_encoder = StyleEncoder(
            dim_in=args.dim_in,
            style_dim=args.style_dim,
            max_conv_dim=args.hidden_dim,
        )

        # One learned style-sized vector per speaker; LayerNorm-ed and added
        # to the reference style in __inference.
        n_speakers = config["data_params"]["n_speakers"]
        self.spk_emb = torch.nn.Embedding(n_speakers, args.style_dim)
        self.spk_ln = torch.nn.LayerNorm(args.style_dim)

        self.__load_models(models_path)

    @staticmethod
    def text_to_sequence_char_level(text, symbol_dict):
        """Map each non-space character of ``text`` to its id in ``symbol_dict``.

        Characters missing from ``symbol_dict`` are dropped with a warning.
        This is a staticmethod because it takes no ``self``. Previously,
        calling it on an instance passed the instance as ``text``.
        """
        seq = []
        for ch in text:
            if ch == " ":
                continue
            if ch in symbol_dict:
                seq.append(symbol_dict[ch])
            else:
                print("[WARN] dropped char:", repr(ch))
        return seq

    def __recursive_munch(self, d):
        """Recursively convert nested dicts (also inside lists) into ``Munch``."""
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        return d

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        """Pull values with |z-score| > ``threshold`` back toward the mean.

        Outliers are replaced by ``mean ± threshold * std * factor`` (sign of
        the original deviation). Returns a new tensor; the input is untouched.
        """
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / (std + 1e-8)
        outlier_mask = torch.abs(z) > threshold
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)
        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        """Load every sub-module's weights from ``checkpoint["net"]``.

        A state dict whose keys carry a ``module.`` prefix (as saved from a
        ``DataParallel``-style wrapper) is retried with that prefix removed
        and ``strict=False``. Any other load failure is re-raised.

        Raises:
            KeyError: if the checkpoint has no entry for a sub-module.
            RuntimeError: if a state dict does not fit its sub-module.
        """
        model = {
            "decoder": self.decoder,
            "predictor": self.predictor,
            "text_encoder": self.text_encoder,
            "style_encoder": self.style_encoder,
            "spk_emb": self.spk_emb,
            "spk_ln": self.spk_ln,
        }
        params_whole = torch.load(models_path, map_location="cpu")
        params = params_whole["net"]
        prefix = "module."
        for name, module in model.items():
            if name not in params:
                raise KeyError(f"checkpoint '{models_path}' has no weights for '{name}'")
            state_dict = params[name]
            try:
                module.load_state_dict(state_dict)
            except RuntimeError:
                # Only a wrapper prefix justifies a retry; blindly chopping 7
                # characters off every key would hide real mismatches.
                if not any(key.startswith(prefix) for key in state_dict):
                    raise
                stripped = {
                    (key[len(prefix):] if key.startswith(prefix) else key): value
                    for key, value in state_dict.items()
                }
                result = module.load_state_dict(stripped, strict=False)
                if result.missing_keys:
                    print(f"[WARN] {name}: missing keys after stripping '{prefix}':", result.missing_keys)
            print(name, ":", sum(p.numel() for p in module.parameters()))

    def __compute_style(self, path, denoise, split_dur):
        """Compute a reference style vector from the audio file at ``path``.

        The audio is loaded at 24 kHz and silence-trimmed (top_db=30). It is
        optionally blended with a noise-reduced copy, weighted by ``denoise``
        (capped at 1.0). When ``split_dur`` > 0 and the clip is at least 4 s
        long, styles of ``split_dur``-second segments are averaged; segments
        shorter than 1 s are skipped. Otherwise the whole clip is encoded
        once.
        """
        device = self.get_device.device
        denoise = min(float(denoise), 1.0)
        split_dur = int(split_dur) if split_dur else 0
        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(
                y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300
            )
            audio = audio * (1 - denoise) + audio_denoise * denoise
        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:
                jump = sr * split_dur
                total_len = len(audio)
                ref_s = None
                count = 0
                for i in range(0, total_len, jump):
                    seg = audio[i : min(i + jump, total_len)]
                    if len(seg) < sr:  # skip segments shorter than 1 s
                        continue
                    mel = self.preprocess.wave_preprocess(seg).to(device)
                    s = self.style_encoder(mel.unsqueeze(1))
                    ref_s = s if ref_s is None else (ref_s + s)
                    count += 1
                if ref_s is None:
                    # No segment was long enough: fall back to the whole clip.
                    mel = self.preprocess.wave_preprocess(audio).to(device)
                    ref_s = self.style_encoder(mel.unsqueeze(1))
                else:
                    ref_s = ref_s / count
            else:
                mel = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel.unsqueeze(1))
        return ref_s

    def __inference(self, phonem, ref_s, speed=1.0, prev_d_mean=0.0, t=0.1):
        """Synthesize one phoneme string with style ``ref_s``.

        Args:
            phonem: phoneme string fed to ``self.cleaner``.
            ref_s: reference style tensor (speaker embedding is added to it).
            speed: speaking-rate factor; durations are divided by it. Must be > 0.
            prev_d_mean: mean duration returned for the previous chunk (0 = none);
                used to keep the speaking rate stable across chunks.
            t: weight of the random duration prior mixed into the prediction.

        Returns:
            (waveform as a numpy array, mean predicted duration before speed
            scaling, to be passed back as ``prev_d_mean``).

        Raises:
            ValueError: if ``speed`` is not positive.
            RuntimeError: on an empty token sequence or when no speaker is set.
        """
        if speed <= 0:
            raise ValueError(f"speed must be > 0, got {speed}")
        device = self.get_device.device
        tokens = self.cleaner(phonem)
        print("[DBG] token_len =", len(tokens))
        print("[DBG] phonem_head =", phonem[:80])
        if len(tokens) == 0:
            raise RuntimeError("Token sequence is empty!")
        # Pad token 0 at both ends.
        tokens = [0] + tokens + [0]
        tokens = torch.LongTensor(tokens).unsqueeze(0).to(device)
        print("\n========== TOKEN DEBUG ==========")
        print("Max token id:", tokens.max().item())
        print("Min token id:", tokens.min().item())
        print("First 50 tokens:", tokens[0][:50].tolist())
        print("=================================\n")

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)
            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            if hasattr(self, "spk_emb"):
                speaker_id = getattr(self, "current_speaker_id", None)
                if speaker_id is None:
                    raise RuntimeError("No active speaker id: generate() sets it from an [id_N] tag.")
                spk_id = torch.LongTensor([speaker_id]).to(device)
                spk_vec = self.spk_emb(spk_id)  # [1, style_dim]
                spk_vec = self.spk_ln(spk_vec)
                # Merge the speaker embedding into the style vector.
                s = s + spk_vec
                print("🎤 Speaker embedding injected:", speaker_id)
                print("🎤 Style vec mean/std:", s.mean().item(), s.std().item())

            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(dim=-1)

            # Mix in a Gaussian prior centred on the previous chunk's mean (or
            # this chunk's own mean) to stabilise speaking rate across chunks.
            if prev_d_mean != 0:
                dur_stats = torch.empty_like(duration).normal_(mean=prev_d_mean, std=duration.std() + 1e-8).to(device)
            else:
                dur_stats = torch.empty_like(duration).normal_(mean=duration.mean(), std=duration.std() + 1e-8).to(device)
            duration = duration * (1 - t) + dur_stats * t
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])

            # Mean in unscaled units: the next chunk mixes it with its raw
            # prediction above, before any speed scaling is applied.
            raw_d_mean = float(duration.mean().item())
            duration = duration / speed

            pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
            L = int(input_lengths.item())
            T = int(pred_dur.sum().item())
            # Hard monotonic alignment: token i covers pred_dur[i] frames.
            pred_aln_trg = torch.zeros((L, T), device=device)
            c = 0
            for i in range(L):
                di = int(pred_dur[i].item())
                pred_aln_trg[i, c : c + di] = 1
                c += di
            alignment = pred_aln_trg.unsqueeze(0)

            en = d.transpose(-1, -2) @ alignment
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = t_en @ pred_aln_trg.unsqueeze(0)
            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), raw_d_mean

    def get_styles(self, speakers, denoise=0.3, avg_style=True):
        """Compute a style entry for every speaker.

        ``speakers`` maps an id to a dict with ``"path"``, ``"lang"`` and
        ``"speed"``. Returns the same ids mapped to those fields plus
        ``"style"``. ``avg_style`` averages styles over 2-second segments.
        """
        split_dur = 2 if avg_style else 0
        styles = {}
        for sid, meta in speakers.items():
            ref_s = self.__compute_style(meta["path"], denoise=denoise, split_dur=split_dur)
            styles[sid] = {
                "style": ref_s,
                "path": meta["path"],
                "lang": meta["lang"],
                "speed": meta["speed"],
            }
        return styles

    @staticmethod
    def _split_speaker_parts(text, default_speaker="[id_1]"):
        """Split ``text`` on ``[id_N]`` tags, keeping the tags as separate items.

        Control whitespace (newline, tab, ...) becomes a space so words on
        adjacent lines are not glued together. ``default_speaker`` is
        prepended only when non-blank text precedes the first tag, so text
        that opens with its own tag does not require the default speaker.
        """
        text = re.sub(r"[\n\r\t\f\v]", " ", text)
        parts = re.split(r"(\[id_\d+\])", text)
        if parts[0].strip():
            parts.insert(0, default_speaker)
        return parts

    def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker="[id_1]"):
        """Synthesize speaker-tagged ``text`` into a single waveform.

        ``text`` may contain ``[id_N]`` tags; text after a tag is spoken by
        speaker N, and untagged leading text uses ``default_speaker``.
        ``styles`` maps tag names without brackets (e.g. ``"id_1"``) to dicts
        with ``"style"``, ``"lang"`` and ``"speed"`` (see :meth:`get_styles`).

        Each chunk is trimmed by 50 ms at both ends (24 kHz assumed) when long
        enough. The concatenated result is padded with 50 ms of silence on
        each side. If nothing was generated, 2400 float32 zeros are returned.

        Raises:
            RuntimeError: if a tag refers to a speaker missing from ``styles``.
        """
        smooth_value = 0.2 if stabilize else 0.0
        list_wav = []
        prev_d_mean = 0.0

        parts = self._split_speaker_parts(text, default_speaker)

        speaker_tag = None  # e.g. "id_1"
        speaker_id = None  # e.g. 1
        current_ref_s = None
        speed = 1.0

        for p in parts:
            # -----------------------------
            # Parse speaker tag
            # -----------------------------
            if re.match(r"(\[id_\d+\])", p):
                speaker_tag = p.strip("[]")
                speaker_id = int(speaker_tag.replace("id_", ""))
                # Expose the speaker id to __inference.
                self.current_speaker_id = speaker_id
                if speaker_tag not in styles:
                    raise RuntimeError(f"Speaker {speaker_tag} not found in styles!")
                current_ref_s = styles[speaker_tag]["style"]
                speed = styles[speaker_tag]["speed"]
                print("🎤 Active speaker:", speaker_tag, "->", speaker_id)
                continue

            if not p.strip():
                continue
            if speaker_id is None:
                raise RuntimeError("Speaker ID chưa được set! Kiểm tra tag [id_x].")

            # -----------------------------
            # Text -> phoneme -> waveform
            # -----------------------------
            for sentence in self.preprocess.text_preprocess(p, n_merge=n_merge):
                phonem = espeak_phn(sentence, styles[speaker_tag]["lang"])
                wav, prev_d_mean = self.__inference(
                    phonem,
                    current_ref_s,
                    speed=speed,
                    prev_d_mean=prev_d_mean,
                    t=smooth_value,
                )
                print("[DBG] wav shape:", wav.shape)
                print("[DBG] wav min/max:", wav.min().item(), wav.max().item())
                print("[DBG] wav mean abs:", np.abs(wav).mean())

                # Safe trim: drop 50 ms at both ends only if the chunk is long enough.
                trim = int(0.05 * 24000)
                if wav.shape[0] > 4 * trim:
                    wav = wav[trim:-trim]
                if wav.size > 0:
                    list_wav.append(wav)

        # -----------------------------
        # Merge all chunks
        # -----------------------------
        if len(list_wav) == 0:
            print("⚠️ No audio generated → return silence")
            return np.zeros((2400,), dtype=np.float32)
        final_wav = np.concatenate(list_wav)
        pad = int(0.05 * 24000)
        final_wav = np.concatenate(
            [np.zeros((pad,), dtype=np.float32), final_wav, np.zeros((pad,), dtype=np.float32)]
        )
        return final_wav