StyleTTS_dolly / verify_infer_mp3.py
hieuducle's picture
Initial upload from script
53ff274 verified
import torch
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import time
import yaml
from munch import Munch
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize
from pydub import AudioSegment # Thêm thư viện này
from models import *
from utils import *
from text_utils import TextCleaner
import soundfile as sf
import os
# Khởi tạo
textclenaer = TextCleaner()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4
def length_to_mask(lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
def preprocess(wave):
wave_tensor = torch.from_numpy(wave).float()
mel_tensor = to_mel(wave_tensor)
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
return mel_tensor
def compute_style(path):
wave, sr = librosa.load(path, sr=24000)
audio, index = librosa.effects.trim(wave, top_db=30)
if sr != 24000:
audio = librosa.resample(audio, sr, 24000)
mel_tensor = preprocess(audio).to(device)
with torch.no_grad():
ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
return torch.cat([ref_s, ref_p], dim=1)
# load phonemizer
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='vi', preserve_punctuation=True, with_stress=True, language_switch='remove-flags')
config = yaml.safe_load(open("/workspace/StyleTTS2/Configs/config_ft.yml"))
# load pretrained models
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)
model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]
# Load weights
params_whole = torch.load("/workspace/StyleTTS2/Models/Dolly/model_iter_00011000.pth", map_location='cpu')
params = params_whole['net']
for key in model:
if key in params:
try:
model[key].load_state_dict(params[key])
except:
from collections import OrderedDict
state_dict = params[key]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model[key].load_state_dict(new_state_dict, strict=False)
_ = [model[key].eval() for key in model]
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
sampler = DiffusionSampler(
model.diffusion.diffusion,
sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
clamp=False
)
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=10, embedding_scale=1):
text = text.strip()
ps = global_phonemizer.phonemize([text])
ps = word_tokenize(ps[0])
ps = ' '.join(ps)
ps = ps.replace('``', '"').replace("''", '"')
ps = ps.replace('t̪', '\uFFFF').replace('t', 'tʰ').replace('\uFFFF', 't')
tokens = textclenaer(ps)
tokens.insert(0, 0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
with torch.no_grad():
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
embedding=bert_dur,
embedding_scale=embedding_scale,
features=ref_s,
num_steps=diffusion_steps).squeeze(1)
if s_prev is not None:
s_pred = t * s_prev + (1 - t) * s_pred
s = s_pred[:, 128:]
ref = s_pred[:, :128]
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
s = beta * s + (1 - beta) * ref_s[:, 128:]
s_pred = torch.cat([ref, s], dim=-1)
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
c_frame += int(pred_dur[i].data)
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
if model_params.decoder.type == "hifigan":
asr_new = torch.zeros_like(en)
asr_new[:, :, 0] = en[:, :, 0]
asr_new[:, :, 1:] = en[:, :, 0:-1]
en = asr_new
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
if model_params.decoder.type == "hifigan":
asr_new = torch.zeros_like(asr)
asr_new[:, :, 0] = asr[:, :, 0]
asr_new[:, :, 1:] = asr[:, :, 0:-1]
asr = asr_new
out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
return out.squeeze().cpu().numpy()[..., :-100], s_pred
# --- BẮT ĐẦU QUÁ TRÌNH INFERENCE ---
passage = '''Huyền thoại SEA Games bật khóc khi chỉ giành HC đồng Kết thúc lượt bơi chung kết, Quah Jing Wen không kìm được cảm xúc. Nữ kình ngư Singapore ôm mặt bật khóc ngay khi nhìn lên bảng thành tích, nơi tên cô xếp sau nhóm dẫn đầu. Những giọt nước mắt của Jing Wen khiến các đối thủ lặng đi, bởi đây là kỳ SEA Games cô đặt nhiều kỳ vọng nhất, sau thời gian dài vật lộn với chấn thương và áp lực phải duy trì truyền thống huy hoàng của bơi Singapore. Khoảnh khắc ấy phản ánh phần nào sức ép các VĐV bơi Singapore đang gánh. Sau nhiều năm thống trị khu vực, họ vẫn duy trì vị thế dẫn đầu, nhưng không còn một mình một ngựa như trước. Một số nội dung nam vắng Joseph Schooling từ lâu khiến khoảng trống thành tích chưa được lấp đầy. Ở nhóm nữ, dàn VĐV trẻ như Letitia Sim vẫn mang về huy chương đều đặn nhưng gặp sự cạnh tranh quyết liệt từ Thái Lan và Việt Nam, theo CNA.'''
# passage = '''sắc, nặng, ngã, huyền, hỏi, mỹ, lã chã'''
path = "./audio_ref/hn_nguyet_nga.wav"
s_ref = compute_style(path)
sentences = passage.split('.')
wavs = []
s_prev = None
for text in sentences:
if text.strip() == "": continue
text += '.'
wav, s_prev = LFinference(text, s_prev, s_ref,
alpha = 0.3, beta = 0.7, t = 0.7,
diffusion_steps=10, embedding_scale=1.5)
wavs.append(wav)
# Ghép các đoạn audio
final_wav = np.concatenate(wavs)
# --- LƯU FILE MP3 ---
# 1. Chuyển đổi sang định dạng int16 (PCM 16-bit)
final_wav_int16 = (final_wav * 32767).astype(np.int16)
# 2. Tạo đối tượng AudioSegment
audio_segment = AudioSegment(
final_wav_int16.tobytes(),
frame_rate=24000,
sample_width=final_wav_int16.dtype.itemsize,
channels=1
)
# 3. Xuất file MP3
out_path_mp3 = "./audio_clone/hn_nguyet_nga_clone.mp3"
os.makedirs(os.path.dirname(out_path_mp3), exist_ok=True)
audio_segment.export(out_path_mp3, format="mp3", bitrate="192k")
print("--- THÀNH CÔNG ---")
print("Đã lưu file MP3 tại:", out_path_mp3)