# NOTE(review): removed stray Hugging Face build-log residue ("Spaces:" / "Build error")
# that was pasted above the file header — it is not Python and would break parsing.
| # -*- coding: utf-8 -*- | |
| import os.path as osp | |
| import random | |
| import numpy as np | |
| import random | |
| import soundfile as sf | |
| import librosa | |
| import torch | |
| try: | |
| import torchaudio | |
| except ImportError: | |
| torchaudio = None | |
| import torch.utils.data | |
| import torch.distributed as dist | |
| from multiprocessing import Pool | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| logger.setLevel(logging.DEBUG) | |
| import pandas as pd | |
# NOTE(review): removed a commented-out legacy copy of TextCleaner; the live
# implementation is defined further down in this file.
# STFT parameters (sizes in samples; currently unused here but kept for callers).
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300,
}
# Full parameter set for MelSpectrogram (avoids relying on torchaudio's
# n_fft/win_length/hop_length defaults).
MEL_PARAMS = {
    "n_mels": 80,
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300,
}
# Global log-mel normalization statistics: mel = (log(mel + 1e-5) - mean) / std.
mean, std = -4, 4
# Cache of MelSpectrogram transforms keyed by sample_rate.
_MEL_CACHE = {}
def _require_torchaudio(context: str) -> None:
    """Fail fast with a clear message when `context` needs torchaudio but it is absent."""
    if torchaudio is not None:
        return
    raise RuntimeError(
        f"torchaudio is required for {context} but is not installed in this environment. "
        "For HF Spaces inference, you should not instantiate FilePathDataset / mel extraction."
    )
def get_mel_transform(sample_rate: int = 16000):
    """Return a MelSpectrogram transform for `sample_rate`, cached in _MEL_CACHE."""
    _require_torchaudio("mel extraction")
    transform = _MEL_CACHE.get(sample_rate)
    if transform is None:
        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=MEL_PARAMS["n_mels"],
            n_fft=MEL_PARAMS["n_fft"],
            win_length=MEL_PARAMS["win_length"],
            hop_length=MEL_PARAMS["hop_length"],
        )
        _MEL_CACHE[sample_rate] = transform
    return transform
def preprocess(wave: np.ndarray, sample_rate: int = 16000):
    """Convert a 1-D waveform into a normalized log-mel spectrogram.

    Args:
        wave: 1-D numpy float array (multi-dim input is squeezed).
        sample_rate: sampling rate of `wave`. Default fixed to 16000 —
            the original default was the typo ``1600016000``.

    Returns:
        Tensor of shape (1, n_mels, T), log-compressed and normalized with
        the module-level ``mean``/``std``.

    Raises:
        RuntimeError: if torchaudio is not installed.
    """
    _require_torchaudio("preprocess()")
    if wave.ndim != 1:
        wave = np.asarray(wave).squeeze()
    wave_tensor = torch.from_numpy(wave).float()
    to_mel = get_mel_transform(sample_rate)
    mel = to_mel(wave_tensor)  # (n_mels, T)
    mel = (torch.log(mel + 1e-5) - mean) / std
    return mel.unsqueeze(0)  # (1, n_mels, T)
class TextCleaner:
    """Map each character of a text to its index in `symbol_dict`.

    Characters absent from the dictionary are skipped; when `debug` is true,
    a sample of the unknown characters is printed.
    """

    def __init__(self, symbol_dict, debug=True):
        self.symbol_dict = symbol_dict
        self.debug = debug

    def __call__(self, text: str):
        table = self.symbol_dict
        indexes = [table[ch] for ch in text if ch in table]
        if self.debug:
            missing = [ch for ch in text if ch not in table]
            if missing:
                print(f"[TextCleaner] missing {len(missing)} symbols. sample={missing[:30]}")
        return indexes
class FilePathDataset(torch.utils.data.Dataset):
    """Training dataset mapping "wav_path|text" lines to (mel, text_ids, path, wave).

    Args:
        data_list: iterable of "|"-separated lines: ``wav_path|text`` (optionally
            followed by extra fields such as a speaker id).
        root_path: directory that wav paths are resolved against.
        symbol_dict: char -> index mapping consumed by TextCleaner.
        sr: target sample rate; audio at other rates is resampled.
        data_augmentation: enable augmentation (forced off during validation).
            NOTE: fixes the original parameter typo ``data_augxmentation``,
            which made ``__init__`` raise NameError on the line assigning
            ``self.data_augmentation``.
        validation: validation-mode flag.
        debug: make TextCleaner print unknown symbols.
    """

    def __init__(
        self,
        data_list,
        root_path,
        symbol_dict,
        sr=16000,
        data_augmentation=False,
        validation=False,
        debug=True,
    ):
        _require_torchaudio("FilePathDataset (training dataloader)")
        _data_list = [l.strip().split("|") for l in data_list]
        self.data_list = _data_list  # [wav_path, text] (optionally + speaker_id)
        self.text_cleaner = TextCleaner(symbol_dict, debug)
        self.sr = sr
        self.df = pd.DataFrame(self.data_list)
        # training-only: mel transform (cached per sample rate)
        self.to_melspec = get_mel_transform(self.sr)
        self.mean, self.std = -4, 4
        # augmentation never applies in validation mode
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192
        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return (mel (n_mels, T), text LongTensor, wav path, raw wave)."""
        data = self.data_list[idx]
        path = data[0]
        wave, text_tensor = self._load_tensor(data)
        mel_tensor = preprocess(wave, sample_rate=self.sr).squeeze()  # (n_mels, T)
        acoustic_feature = mel_tensor
        # trim to an even number of frames
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]
        return acoustic_feature, text_tensor, path, wave

    def _load_tensor(self, data):
        """Load one entry; `data` is [wave_path, text] or [wave_path, text, speaker_id].

        Returns (wave ndarray padded with 0.5 s of silence per side,
        LongTensor of symbol ids with BOS/EOS).
        """
        wave_path = data[0]
        text = data[1]
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        # keep only the left channel of stereo files
        if isinstance(wave, np.ndarray) and wave.ndim == 2 and wave.shape[-1] == 2:
            wave = wave[:, 0].squeeze()
        if sr != self.sr:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)
        # pad 0.5 s of silence on each side. Generalized from the hard-coded
        # 8000 (= 0.5 * 16000); the original comment claimed 12000 @ 24 kHz,
        # which did not match the code.
        pad = int(0.5 * self.sr)
        wave = np.concatenate([np.zeros([pad]), wave, np.zeros([pad])], axis=0)
        text_ids = self.text_cleaner(text)
        # BOS/EOS tokens are index 0, as in the original pipeline
        text_ids.insert(0, 0)
        text_ids.append(0)
        text_tensor = torch.LongTensor(text_ids)
        return wave, text_tensor

    def _load_data(self, data):
        """Return a mel spectrogram cropped to at most max_mel_length frames
        (random crop start when longer)."""
        wave, text_tensor = self._load_tensor(data)
        mel_tensor = preprocess(wave, sample_rate=self.sr).squeeze()
        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            random_start = np.random.randint(0, mel_length - self.max_mel_length)
            mel_tensor = mel_tensor[:, random_start : random_start + self.max_mel_length]
        return mel_tensor
class Collater(object):
    """Collate (mel, text, path, wave) samples into padded batch tensors.

    Samples are reordered longest-mel-first; mels and texts are zero-padded
    to the batch maxima.

    Args:
        return_wave (bool): stored flag (raw waves are always returned).
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        n = len(batch)
        # reorder by mel length, longest first (argsort kept for parity)
        mel_lengths = [item[0].shape[1] for item in batch]
        order = np.argsort(mel_lengths)[::-1]
        ordered = [batch[i] for i in order]

        n_mels = ordered[0][0].size(0)
        max_mel = max(item[0].shape[1] for item in ordered)
        max_text = max(item[1].shape[0] for item in ordered)

        mels = torch.zeros((n, n_mels, max_mel)).float()
        texts = torch.zeros((n, max_text)).long()
        input_lengths = torch.zeros(n).long()
        output_lengths = torch.zeros(n).long()
        paths = ['' for _ in range(n)]
        waves = [None for _ in range(n)]

        for i, (mel, text, path, wave) in enumerate(ordered):
            mel_size = mel.size(1)
            text_size = text.size(0)
            mels[i, :, :mel_size] = mel
            texts[i, :text_size] = text
            input_lengths[i] = text_size
            output_lengths[i] = mel_size
            paths[i] = path
            waves[i] = wave

        return waves, texts, input_lengths, mels, output_lengths
def get_length(wave_path, root_path):
    """Return the clip length in samples, rescaled to a 16 kHz timeline."""
    meta = sf.info(osp.join(root_path, wave_path))
    return meta.frames * (16000 / meta.samplerate)
def build_dataloader(path_list,
                     root_path,
                     symbol_dict,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a length-bucketed DataLoader over FilePathDataset.

    Args:
        path_list: iterable of "wav_path|text" lines.
        root_path: directory wav paths are relative to.
        symbol_dict: char -> index mapping for TextCleaner.
        validation: disables shuffling/drop_last and augmentation.
        batch_size: per-replica batch size for BatchSampler.
        num_workers: DataLoader workers; also sizes the length-probing pool.
        device: pin host memory when not "cpu".
        collate_config: extra kwargs for Collater (None -> {}).
        dataset_config: extra kwargs for FilePathDataset (None -> {}).

    Returns:
        torch.utils.data.DataLoader yielding Collater batches.
    """
    # None-sentinels instead of mutable {} defaults (shared across calls)
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config
    dataset = FilePathDataset(path_list, root_path, symbol_dict, validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)
    print("Getting sample lengths...")
    num_processes = num_workers * 2
    if num_processes != 0:
        # probe clip lengths in parallel for bucketing
        list_of_tuples = [(d[0], root_path) for d in dataset.data_list]
        with Pool(processes=num_processes) as pool:
            sample_lengths = pool.starmap(get_length, list_of_tuples, chunksize=16)
    else:
        sample_lengths = [get_length(d[0], root_path) for d in dataset.data_list]
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=BatchSampler(
            sample_lengths,
            batch_size,
            shuffle=(not validation),
            drop_last=(not validation),
            num_replicas=1,
            rank=0,
        ),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )
    return data_loader
| #https://github.com/duerig/StyleTTS2/ | |
class BatchSampler(torch.utils.data.Sampler):
    """Length-bucketed, optionally distributed batch sampler.

    Samples are grouped into duration bins (see `get_time_bin`) and each
    yielded batch is drawn from a single bin, so batch members have similar
    lengths.

    Args:
        sample_lengths: per-sample lengths (samples on a 16 kHz timeline).
        batch_sizes: per-replica batch size (a single int; plural name kept
            for backward compatibility).
        num_replicas: world size; read from torch.distributed when None.
        rank: this process's rank; read from torch.distributed when None.
        shuffle: shuffle bin order and within-bin order each epoch.
        drop_last: drop incomplete batches.
    """

    def __init__(
        self,
        sample_lengths,
        batch_sizes,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
    ):
        self.batch_sizes = batch_sizes
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.time_bins = {}
        self.epoch = 0
        self.total_len = 0
        self.last_bin = None
        for i, length in enumerate(sample_lengths):
            bin_num = self.get_time_bin(length)
            if bin_num != -1:
                self.time_bins.setdefault(bin_num, []).append(i)
        # BUG FIX: the original computed `self.batch_sizes * num_replicas`
        # using the raw argument, which is None on the torch.distributed
        # path and raised TypeError; use the resolved self.num_replicas.
        total_batch = self.batch_sizes * self.num_replicas
        for val in self.time_bins.values():
            self.total_len += len(val) // total_batch
            if not self.drop_last and len(val) % total_batch != 0:
                self.total_len += 1

    def __iter__(self):
        """Yield lists of dataset indices, one bin at a time."""
        sampler_order = list(self.time_bins.keys())
        if self.shuffle:
            sampler_indices = torch.randperm(len(sampler_order)).tolist()
        else:
            sampler_indices = list(range(len(sampler_order)))
        for index in sampler_indices:
            key = sampler_order[index]
            current_bin = self.time_bins[key]
            # renamed from `dist` to stop shadowing the module-level
            # torch.distributed import
            dist_sampler = torch.utils.data.distributed.DistributedSampler(
                current_bin,
                num_replicas=self.num_replicas,
                rank=self.rank,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            dist_sampler.set_epoch(self.epoch)
            batch_sampler = torch.utils.data.sampler.BatchSampler(
                dist_sampler, self.batch_sizes, self.drop_last
            )
            for item_list in batch_sampler:
                self.last_bin = key
                yield [current_bin[i] for i in item_list]

    def __len__(self):
        return self.total_len

    def set_epoch(self, epoch):
        """Set the epoch used to seed per-epoch shuffling."""
        self.epoch = epoch

    def get_time_bin(self, sample_count):
        """Map a sample count to a bin index (20 frames per bin at hop 300);
        returns -1 for clips shorter than 20 frames."""
        result = -1
        frames = sample_count // 300
        if frames >= 20:
            result = (frames - 20) // 20
        return result