import re

import librosa
import nltk
import noisereduce as nr
import numpy as np
import torch
import torchaudio
import yaml
from munch import Munch
from nltk.tokenize import word_tokenize

from .meldataset import TextCleaner
from .models import ProsodyPredictor, TextEncoder, StyleEncoder

# word_tokenize needs the punkt_tab resource; fetch it once at import time.
nltk.download('punkt_tab')

class Preprocess:

    def __text_normalize(self, text):
        # Map every supported punctuation mark to "." so the text can be
        # split into sentence-like fragments afterwards.
        punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
        map_to = "."
        punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")

        text = punctuation_pattern.sub(map_to, text)

        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def __merge_fragments(self, texts, n):
        # Greedily join consecutive fragments until each merged fragment
        # contains at least n words.
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j
        # If the last fragment is still too short, fold it into the previous one.
        if len(merged[-1].split()) < n and len(merged) > 1:
            merged[-2] = merged[-2] + ", " + merged[-1]
            del merged[-1]
        return merged

    def wave_preprocess(self, wave):
        to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
        mean, std = -4, 4
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = to_mel(wave_tensor)
        # Log-compress and normalize with the fixed statistics used in training.
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
        return mel_tensor

    def text_preprocess(self, text, n_merge=12):
        # Normalize punctuation, split into fragments, drop empties, then
        # merge fragments until each has at least n_merge words.
        text_norm = self.__text_normalize(text).split(".")
        text_norm = [s.strip() for s in text_norm]
        text_norm = list(filter(lambda x: x != '', text_norm))
        text_norm = self.__merge_fragments(text_norm, n=n_merge)
        return text_norm

    def length_to_mask(self, lengths):
        # Boolean mask that is True at padded positions beyond each length.
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
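
# Quick illustration of the text pipeline (hypothetical input, shown only as
# documentation): punctuation is mapped to ".", the text is split on ".",
# empty fragments are dropped, and short fragments are merged with ", ":
#
#   Preprocess().text_preprocess("Hello, world! This is a short test.")
#   # -> ['Hello, world, This is a short test']
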
class StyleTTS2(torch.nn.Module):
    def __init__(self, config_path, models_path):
        super().__init__()
        # Empty buffer used only to track which device the module lives on.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()
        self.ref_s = None

        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        try:
            symbols = (
                list(config['symbol']['pad']) +
                list(config['symbol']['punctuation']) +
                list(config['symbol']['letters']) +
                list(config['symbol']['letters_ipa']) +
                list(config['symbol']['extend'])
            )
            symbol_dict = {symbol: i for i, symbol in enumerate(symbols)}

            n_token = len(symbol_dict) + 1
            print("\nFound:", n_token, "symbols")
        except Exception as e:
            print(f"\nERROR: Cannot find {e} in the config file!\nYour config file is likely outdated; please download an updated version from the repository.")
            raise SystemExit(1)

        args = self.__recursive_munch(config['model_params'])
        args['n_token'] = n_token

        self.cleaner = TextCleaner(symbol_dict, debug=False)

        assert args.decoder.type in ['istftnet', 'hifigan', 'vocos'], 'Decoder type unknown'

        if args.decoder.type == "istftnet":
            from .Modules.istftnet import Decoder
            self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                                   resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                                   upsample_rates=args.decoder.upsample_rates,
                                   upsample_initial_channel=args.decoder.upsample_initial_channel,
                                   resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                                   upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                                   gen_istft_n_fft=args.decoder.gen_istft_n_fft,
                                   gen_istft_hop_size=args.decoder.gen_istft_hop_size)
        elif args.decoder.type == "hifigan":
            from .Modules.hifigan import Decoder
            self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                                   resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                                   upsample_rates=args.decoder.upsample_rates,
                                   upsample_initial_channel=args.decoder.upsample_initial_channel,
                                   resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                                   upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
        elif args.decoder.type == "vocos":
            from .Modules.vocos import Decoder
            self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                                   intermediate_dim=args.decoder.intermediate_dim,
                                   num_layers=args.decoder.num_layers,
                                   gen_istft_n_fft=args.decoder.gen_istft_n_fft,
                                   gen_istft_hop_size=args.decoder.gen_istft_hop_size)

        self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
        self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
        self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)

        self.__load_models(models_path)

    def __recursive_munch(self, d):
        # Recursively convert nested dicts/lists into Munch objects so that
        # config entries can be accessed as attributes.
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        elif isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        else:
            return d
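
    # Illustrative example (not part of the model logic):
    #   __recursive_munch({'decoder': {'type': 'hifigan'}}).decoder.type -> 'hifigan'
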
    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        # Soft-clamp outliers: any value whose z-score exceeds `threshold` is
        # pulled back to just inside the threshold boundary.
        mean = tensor.mean()
        std = tensor.std()
        if std == 0:
            # Constant input has no outliers (and would divide by zero below).
            return tensor
        z = (tensor - mean) / std

        outlier_mask = torch.abs(z) > threshold

        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)

        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]

        return result
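
    # Illustration of the soft clamp (hypothetical numbers): with mean 0,
    # std 1, threshold 3.0 and factor 0.95, a value with z-score 5 would be
    # replaced by 0 + 1 * (3.0 * 1 * 0.95) = 2.85, just inside the boundary.
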
    def __load_models(self, models_path):
        module_params = []
        model = {'decoder': self.decoder, 'predictor': self.predictor, 'text_encoder': self.text_encoder, 'style_encoder': self.style_encoder}

        params_whole = torch.load(models_path, map_location='cpu')
        params = params_whole['net']
        params = {key: value for key, value in params.items() if key in model.keys()}

        for key in model:
            try:
                model[key].load_state_dict(params[key])
            except Exception:
                # Checkpoints saved from DataParallel/DDP prefix every key with
                # "module."; strip the prefix and retry with strict=False.
                from collections import OrderedDict
                state_dict = params[key]
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # drop the "module." prefix
                    new_state_dict[name] = v
                model[key].load_state_dict(new_state_dict, strict=False)

            total_params = sum(p.numel() for p in model[key].parameters())
            print(key, ":", total_params)
            module_params.append(total_params)

        print('\nTotal', ":", sum(module_params))

    def __compute_style(self, path, denoise, split_dur):
        device = self.get_device.device
        denoise = min(max(denoise, 0.0), 1.0)  # blend factor must stay in [0, 1]
        if split_dur != 0: split_dur = max(int(split_dur), 1)
        max_samples = 24000 * 20  # cap the reference at 20 seconds of 24 kHz audio
        print("Computing the style for:", path)

        # librosa.load already resamples to 24 kHz, so sr is always 24000 here.
        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        if denoise > 0.0:
            # Blend the denoised signal with the original by the denoise factor.
            audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:
                # Split the audio into split_dur-second chunks, encode each one,
                # and average the resulting style vectors.
                count = 0
                ref_s = None
                jump = sr * split_dur
                total_len = len(audio)

                mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
                count += 1
                for i in range(jump, total_len, jump):
                    if i + jump >= total_len:
                        # Only keep the trailing chunk if it is at least 1 s long.
                        left_dur = (total_len - i) / sr
                        if left_dur >= 1:
                            mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
                            ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                            count += 1
                        continue
                    mel_tensor = self.preprocess.wave_preprocess(audio[i:i + jump]).to(device)
                    ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                    count += 1
                ref_s /= count
            else:
                mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))

        return ref_s

    def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
        device = self.get_device.device
        speed = min(max(speed, 0.0001), 2)  # keep the speed factor in a sane range

        # Tokenize the phoneme string and pad it with the boundary token (index 0).
        phonem = ' '.join(word_tokenize(phonem))
        tokens = self.cleaner(phonem)
        tokens.insert(0, 0)
        tokens.append(0)
        tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)

            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            # Predict per-token durations with the prosody predictor.
            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)

            # Blend the predicted durations with sampled statistics, centered on
            # the previous fragment's mean when available, to stabilize pacing
            # across fragments.
            if prev_d_mean != 0:
                dur_stats = torch.empty(duration.shape).normal_(mean=prev_d_mean, std=duration.std()).to(device)
            else:
                dur_stats = torch.empty(duration.shape).normal_(mean=duration.mean(), std=duration.std()).to(device)
            duration = duration * (1 - t) + dur_stats * t
            # Soft-clamp duration outliers, excluding the boundary tokens.
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])

            duration /= speed

            # Build the hard alignment matrix from the rounded durations.
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().item()))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].item())] = 1
                c_frame += int(pred_dur[i].item())
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            # Expand the encodings to frame level, predict F0/energy, and decode.
            en = (d.transpose(-1, -2) @ alignment)
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = (t_en @ alignment)

            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), duration.mean()

    def get_styles(self, speaker, denoise=0.3, avg_style=True, load_styles=False):
        if not load_styles:
            # avg_style averages the style over 3-second windows of the
            # reference audio instead of encoding it in a single pass.
            split_dur = 3 if avg_style else 0
            self.ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)
        else:
            if self.ref_s is None:
                raise Exception("Have to compute or load the styles first!")
        style = {
            'style': self.ref_s,
            'path': speaker['path'],
            'speed': speaker['speed'],
        }
        return style

    def save_styles(self, save_dir):
        if self.ref_s is not None:
            torch.save(self.ref_s, save_dir)
            print("Saved styles!")
        else:
            raise Exception("Have to compute the styles before saving them.")

    def load_styles(self, save_dir):
        try:
            self.ref_s = torch.load(save_dir)
            print("Loaded styles!")
        except Exception as e:
            print(e)

    def generate(self, phonem, style, stabilize=True, n_merge=16):
        # When stabilizing, blend 20% of a sampled duration statistic into the
        # predicted durations (see __inference).
        smooth_value = 0.2 if stabilize else 0

        list_wav = []
        prev_d_mean = 0

        print("Generating Audio...")
        text_norm = self.preprocess.text_preprocess(phonem, n_merge=n_merge)
        for sentence in text_norm:
            wav, prev_d_mean = self.__inference(sentence, style['style'], speed=style['speed'], prev_d_mean=prev_d_mean, t=smooth_value)
            wav = wav[4000:-4000]  # drop 4000 samples (~167 ms) from each end of the fragment
            list_wav.append(wav)

        final_wav = np.concatenate(list_wav)
        # Pad the final waveform with 4000 samples of silence on each side.
        final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0)
        return final_wav
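

# Minimal usage sketch (illustrative only: the paths, speaker dict, and the
# phonemized input below are hypothetical placeholders; phonemization happens
# outside this module):
#
#   model = StyleTTS2("config.yaml", "checkpoint.pth")
#   styles = model.get_styles({'path': "reference.wav", 'speed': 1.0})
#   wav = model.generate("həˈloʊ wɜːld.", styles)  # 24 kHz float waveform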