ttsStyleTTS2 / meldataset.py
stephenhoang's picture
Update meldataset.py
a96c595 verified
# -*- coding: utf-8 -*-
import os.path as osp
import random
import numpy as np
import random
import soundfile as sf
import librosa
import torch
try:
import torchaudio
except ImportError:
torchaudio = None
import torch.utils.data
import torch.distributed as dist
from multiprocessing import Pool
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import pandas as pd
# class TextCleaner:
# def __init__(self, symbol_dict, debug=True):
# self.word_index_dictionary = symbol_dict
# self.debug = debug
# def __call__(self, text):
# indexes = []
# for char in text:
# try:
# indexes.append(self.word_index_dictionary[char])
# except KeyError as e:
# if self.debug:
# print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", char)
# print("To ignore set 'debug' to false in the config")
# continue
# return indexes
# STFT window settings.
SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300}

# Full parameter set for MelSpectrogram (so n_fft / win_length / hop_length
# are never missing when the transform is built).
MEL_PARAMS = {"n_mels": 80, **SPECT_PARAMS}

# Normalisation constants applied to the log-mel values in preprocess().
mean, std = -4, 4

# Cache of MelSpectrogram transforms, keyed by sample rate.
_MEL_CACHE = {}
def _require_torchaudio(context: str) -> None:
    """Raise ``RuntimeError`` when the optional ``torchaudio`` import failed.

    Args:
        context: Short description of the calling feature; it is interpolated
            into the error message.

    Raises:
        RuntimeError: If the module-level ``torchaudio`` is ``None``.
    """
    if torchaudio is not None:
        return
    message = (
        f"torchaudio is required for {context} but is not installed in this environment. "
        "For HF Spaces inference, you should not instantiate FilePathDataset / mel extraction."
    )
    raise RuntimeError(message)
def get_mel_transform(sample_rate: int = 16000):
    """Return the ``MelSpectrogram`` transform for ``sample_rate``, building it once.

    Transforms are stored in the module-level ``_MEL_CACHE`` dict so repeated
    calls with the same sample rate reuse the same object.

    Args:
        sample_rate: Sample rate passed to ``MelSpectrogram``; also the cache key.

    Raises:
        RuntimeError: If torchaudio is not installed.
    """
    _require_torchaudio("mel extraction")
    transform = _MEL_CACHE.get(sample_rate)
    if transform is None:
        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=MEL_PARAMS["n_mels"],
            n_fft=MEL_PARAMS["n_fft"],
            win_length=MEL_PARAMS["win_length"],
            hop_length=MEL_PARAMS["hop_length"],
        )
        _MEL_CACHE[sample_rate] = transform
    return transform
def preprocess(wave: np.ndarray, sample_rate: int = 16000):
    """Convert a waveform into a normalised log-mel spectrogram.

    Args:
        wave: Audio samples as a numpy float array. Input that is not already
            1-D is squeezed.
        sample_rate: Sample rate used to select the mel transform. Defaults to
            16000, matching ``get_mel_transform``.

    Returns:
        Mel tensor of shape ``(1, n_mels, T)``.

    Raises:
        RuntimeError: If torchaudio is not installed.
    """
    _require_torchaudio("preprocess()")

    # Convert first so that ``.ndim`` is available even for array-likes.
    wave = np.asarray(wave)
    if wave.ndim != 1:
        wave = wave.squeeze()

    wave_tensor = torch.from_numpy(wave).float()
    to_mel = get_mel_transform(sample_rate)
    mel = to_mel(wave_tensor)  # (n_mels, T)

    # Log-compress (1e-5 avoids log(0)), then normalise with the module constants.
    mel = (torch.log(mel + 1e-5) - mean) / std
    return mel.unsqueeze(0)  # (1, n_mels, T)
class TextCleaner:
    """Map the characters of a string to their integer ids in a symbol table."""

    def __init__(self, symbol_dict, debug=True):
        """Store the symbol table and the debug flag.

        Args:
            symbol_dict: Mapping from a single character to its integer id.
            debug: When true, a summary of unknown characters is printed.
        """
        self.symbol_dict = symbol_dict
        self.debug = debug

    def __call__(self, text: str):
        """Return the ids of all known characters in ``text``, in order.

        Characters absent from the symbol table are skipped; with ``debug``
        enabled, their count and up to 30 of them are printed.
        """
        token_ids = []
        unknown = []
        for symbol in text:
            if symbol not in self.symbol_dict:
                unknown.append(symbol)
                continue
            token_ids.append(self.symbol_dict[symbol])

        if self.debug and unknown:
            print(f"[TextCleaner] missing {len(unknown)} symbols. sample={unknown[:30]}")
        return token_ids
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset over ``wav_path|text`` manifest lines.

    Indexing returns ``(mel, text_tensor, path, wave)`` where ``mel`` is the
    normalised log-mel spectrogram trimmed to an even number of frames.

    Args:
        data_list: Iterable of manifest lines with ``|``-separated fields. The
            first field is the wav path (relative to ``root_path``) and the
            second is the text; further fields are kept but not used here.
        root_path: Directory joined in front of every wav path.
        symbol_dict: Character-to-id mapping handed to ``TextCleaner``.
        sr: Target sample rate; audio at another rate is resampled to it.
        data_augmentation: Stored as ``self.data_augmentation``; forced to
            ``False`` when ``validation`` is true.
        validation: Whether this dataset is used for validation.
        debug: Forwarded to ``TextCleaner``.

    Raises:
        RuntimeError: If torchaudio is not installed.
    """

    def __init__(
        self,
        data_list,
        root_path,
        symbol_dict,
        sr=16000,
        data_augmentation=False,
        validation=False,
        debug=True,
    ):
        _require_torchaudio("FilePathDataset (training dataloader)")

        # Each entry is [wav_path, text] (optionally followed by e.g. a speaker id).
        self.data_list = [line.strip().split("|") for line in data_list]
        self.text_cleaner = TextCleaner(symbol_dict, debug)
        self.sr = sr
        self.df = pd.DataFrame(self.data_list)

        # Training-only: mel transform.
        self.to_melspec = get_mel_transform(self.sr)
        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192
        self.root_path = root_path

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load entry ``idx`` and return ``(mel, text_tensor, path, wave)``."""
        data = self.data_list[idx]
        path = data[0]
        wave, text_tensor = self._load_tensor(data)

        # preprocess() returns (1, n_mels, T); drop only the leading axis.
        acoustic_feature = preprocess(wave, sample_rate=self.sr).squeeze(0)

        # Trim to an even number of frames.
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]
        return acoustic_feature, text_tensor, path, wave

    def _load_tensor(self, data):
        """Read one manifest entry and return ``(wave, text_tensor)``.

        ``data`` may be ``[wave_path, text]`` or ``[wave_path, text, speaker_id]``;
        only the first two fields are read.
        """
        wave_path = data[0]
        text = data[1]

        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        # Multi-channel audio comes back as (frames, channels); keep the first
        # channel so the 1-D zero padding below always concatenates cleanly.
        if isinstance(wave, np.ndarray) and wave.ndim == 2:
            wave = wave[:, 0]
        if sr != self.sr:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)

        # Pad 8000 zero samples on each side (0.5 s at 16 kHz). The count is
        # fixed and does not scale with ``self.sr``.
        wave = np.concatenate([np.zeros([8000]), wave, np.zeros([8000])], axis=0)

        # Index 0 is used as both the BOS and the EOS token.
        text_ids = [0, *self.text_cleaner(text), 0]
        text_tensor = torch.LongTensor(text_ids)
        return wave, text_tensor

    def _load_data(self, data):
        """Return the mel of one entry, randomly cropped to ``max_mel_length`` frames."""
        wave, _text_tensor = self._load_tensor(data)
        mel_tensor = preprocess(wave, sample_rate=self.sr).squeeze(0)

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            random_start = np.random.randint(0, mel_length - self.max_mel_length)
            mel_tensor = mel_tensor[:, random_start : random_start + self.max_mel_length]
        return mel_tensor
class Collater(object):
    """Pad a list of dataset items into batch tensors.

    Args:
        return_wave (bool): stored on the instance as ``return_wave``; the
            ``__call__`` below does not consult it.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Collate ``(mel, text, path, wave)`` items, longest mel first.

        Returns:
            ``(waves, texts, input_lengths, mels, output_lengths)`` where
            ``texts`` and ``mels`` are zero-padded to the longest item and the
            two length tensors hold the unpadded text and mel sizes.
        """
        batch_size = len(batch)

        # Reorder the items by descending mel length.
        order = np.argsort([item[0].shape[1] for item in batch])[::-1]
        batch = [batch[i] for i in order]

        n_mels = batch[0][0].size(0)
        longest_mel = max(item[0].shape[1] for item in batch)
        longest_text = max(item[1].shape[0] for item in batch)

        mels = torch.zeros((batch_size, n_mels, longest_mel)).float()
        texts = torch.zeros((batch_size, longest_text)).long()
        input_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        waves = [None] * batch_size

        for row, (mel, text, _path, wave) in enumerate(batch):
            n_frames = mel.size(1)
            n_tokens = text.size(0)
            mels[row, :, :n_frames] = mel
            texts[row, :n_tokens] = text
            input_lengths[row] = n_tokens
            output_lengths[row] = n_frames
            waves[row] = wave

        return waves, texts, input_lengths, mels, output_lengths
def get_length(wave_path, root_path):
    """Return the length of an audio file in samples, rescaled to 16 kHz.

    The frame count and sample rate come from ``sf.info``; the frame count is
    multiplied by ``16000 / samplerate``, so the result is a float.

    Args:
        wave_path: Path of the audio file relative to ``root_path``.
        root_path: Directory joined in front of ``wave_path``.
    """
    full_path = osp.join(root_path, wave_path)
    info = sf.info(full_path)
    rate_ratio = 16000 / info.samplerate
    return info.frames * rate_ratio
def build_dataloader(path_list,
                     root_path,
                     symbol_dict,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a ``DataLoader`` whose batches group samples of similar length.

    Args:
        path_list: Manifest lines passed to ``FilePathDataset``.
        root_path: Directory that the wav paths are relative to.
        symbol_dict: Character-to-id mapping passed to ``FilePathDataset``.
        validation: When true, shuffling and ``drop_last`` are disabled.
        batch_size: Batch size given to ``BatchSampler``.
        num_workers: Number of DataLoader workers; twice this many processes
            are used to measure sample lengths (0 measures them in-process).
        device: Memory is pinned unless this equals ``"cpu"``.
        collate_config: Optional keyword arguments for ``Collater``.
        dataset_config: Optional keyword arguments for ``FilePathDataset``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # None instead of ``{}`` as the default avoids sharing one mutable dict
    # between calls.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(path_list, root_path, symbol_dict, validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)

    print("Getting sample lengths...")
    length_args = [(d[0], root_path) for d in dataset.data_list]
    num_processes = num_workers * 2
    if num_processes != 0:
        with Pool(processes=num_processes) as pool:
            sample_lengths = pool.starmap(get_length, length_args, chunksize=16)
    else:
        sample_lengths = [get_length(wave_path, root) for wave_path, root in length_args]

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=BatchSampler(
            sample_lengths,
            batch_size,
            shuffle=(not validation),
            drop_last=(not validation),
            num_replicas=1,
            rank=0,
        ),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )
    return data_loader
#https://github.com/duerig/StyleTTS2/
class BatchSampler(torch.utils.data.Sampler):
    """Batch sampler that only batches together samples from the same length bin.

    Samples are assigned to bins by ``get_time_bin``; samples that are too
    short (bin ``-1``) are dropped. Every bin is split across replicas with a
    ``DistributedSampler`` and then chunked into batches.

    Args:
        sample_lengths: Per-sample lengths, in audio samples.
        batch_sizes: Batch size used for every bin.
        num_replicas: Number of distributed replicas; ``None`` reads it from
            ``torch.distributed``.
        rank: Rank of this process; ``None`` reads it from ``torch.distributed``.
        shuffle: Shuffle both the bin order and the samples inside each bin.
        drop_last: Drop incomplete batches.
    """

    def __init__(
        self,
        sample_lengths,
        batch_sizes,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
    ):
        self.batch_sizes = batch_sizes
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.time_bins = {}
        self.epoch = 0
        self.total_len = 0
        self.last_bin = None

        # Group dataset indices by length bin, skipping too-short samples.
        for index, length in enumerate(sample_lengths):
            bin_num = self.get_time_bin(length)
            if bin_num != -1:
                self.time_bins.setdefault(bin_num, []).append(index)

        # Use the resolved replica count: the raw ``num_replicas`` argument is
        # None when the caller relies on the distributed default.
        total_batch = self.batch_sizes * self.num_replicas
        for indices in self.time_bins.values():
            full_batches, remainder = divmod(len(indices), total_batch)
            self.total_len += full_batches
            if not self.drop_last and remainder != 0:
                self.total_len += 1

    def __iter__(self):
        """Yield lists of dataset indices, one bin at a time."""
        sampler_order = list(self.time_bins.keys())
        if self.shuffle:
            sampler_indices = torch.randperm(len(sampler_order)).tolist()
        else:
            sampler_indices = list(range(len(sampler_order)))

        for index in sampler_indices:
            key = sampler_order[index]
            current_bin = self.time_bins[key]
            # Named ``dist_sampler`` so it does not shadow the module-level
            # ``dist`` alias for torch.distributed.
            dist_sampler = torch.utils.data.distributed.DistributedSampler(
                current_bin,
                num_replicas=self.num_replicas,
                rank=self.rank,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            dist_sampler.set_epoch(self.epoch)
            sampler = torch.utils.data.sampler.BatchSampler(
                dist_sampler, self.batch_sizes, self.drop_last
            )
            for item_list in sampler:
                self.last_bin = key
                # The sampler yields positions inside the bin; map them back
                # to dataset indices.
                yield [current_bin[i] for i in item_list]

    def __len__(self):
        """Return the number of batches computed at construction time."""
        return self.total_len

    def set_epoch(self, epoch):
        """Set the epoch passed to each per-bin ``DistributedSampler``."""
        self.epoch = epoch

    def get_time_bin(self, sample_count):
        """Return the bin number for a sample length, or ``-1`` if too short.

        The length is converted to frames of 300 samples; anything under 20
        frames is rejected and the rest is bucketed in steps of 20 frames.
        """
        result = -1
        frames = sample_count // 300
        if frames >= 20:
            result = (frames - 20) // 20
        return result