import os
import random

import librosa
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.utils.data import Dataset
from torchaudio.transforms import Resample
from tqdm import tqdm

from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.storage_config import MODELS_DIR


class CodecAlignerDataset(Dataset):
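    """
    Dataset for aligner training: pairs articulatory text features with codec
    codebook indexes and speaker embeddings. On the first run (or when
    rebuild_cache is set) the cache is built by multiple worker processes and
    saved to cache_dir; afterwards it is simply loaded from disk.
    """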
    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes,
                 device,
                 min_len_in_seconds=1,
                 max_len_in_seconds=15,
                 rebuild_cache=False,
                 verbose=False,
                 phone_input=False,
                 allow_unknown_symbols=False,
                 gpu_count=1,
                 rank=0):
        self.gpu_count = gpu_count
        self.rank = rank
        if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
            self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
                                      cache_dir=cache_dir,
                                      lang=lang,
                                      loading_processes=loading_processes,
                                      device=device,
                                      min_len_in_seconds=min_len_in_seconds,
                                      max_len_in_seconds=max_len_in_seconds,
                                      verbose=verbose,
                                      phone_input=phone_input,
                                      allow_unknown_symbols=allow_unknown_symbols,
                                      gpu_count=gpu_count,
                                      rank=rank)
        self.lang = lang
        self.device = device
        self.cache_dir = cache_dir
        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang)
        cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
        self.speaker_embeddings = cache[2]
        self.datapoints = cache[0]
        if self.gpu_count > 1:
            # we only keep one shard of the dataset in memory per process to avoid redundancy; the rank determines which shard this is.
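            # For illustration (hypothetical numbers): with 10000 datapoints and
            # gpu_count=4, rank 2 keeps datapoints[5000:7500].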
            while len(self.datapoints) % self.gpu_count != 0:
                self.datapoints.pop(-1)  # a bit unfortunate, but if you're using multiple GPUs, you probably have a ton of datapoints anyway.
            chunksize = int(len(self.datapoints) / self.gpu_count)
            self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)]
            self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)]
| print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.") | |

    def _build_dataset_cache(self,
                             path_to_transcript_dict,
                             cache_dir,
                             lang,
                             loading_processes,
                             device,
                             min_len_in_seconds=1,
                             max_len_in_seconds=15,
                             verbose=False,
                             phone_input=False,
                             allow_unknown_symbols=False,
                             gpu_count=1,
                             rank=0):
        if gpu_count != 1:
            import sys
            print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
            sys.exit()
        os.makedirs(cache_dir, exist_ok=True)
        if not isinstance(path_to_transcript_dict, dict):
            path_to_transcript_dict = path_to_transcript_dict()  # a function can be passed instead of the dict, so that the (potentially expensive) dict is only built when it is actually needed.
        torch.multiprocessing.set_start_method('spawn', force=True)
        resource_manager = Manager()
        self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
        key_list = list(self.path_to_transcript_dict.keys())
        with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
            files_used_note.write(str(key_list))
        fisher_yates_shuffle(key_list)
        # build cache
        print("... building dataset cache ...")
        self.result_pool = resource_manager.list()
        # make processes
        key_splits = list()
        process_list = list()
        for i in range(loading_processes):
            key_splits.append(
                key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
        for key_split in key_splits:
            process_list.append(
                Process(target=self._cache_builder_process,
                        args=(key_split,
                              lang,
                              min_len_in_seconds,
                              max_len_in_seconds,
                              verbose,
                              device,
                              phone_input,
                              allow_unknown_symbols),
                        daemon=True))
            process_list[-1].start()
        for process in process_list:
            process.join()
| print("pooling results...") | |
| pooled_datapoints = list() | |
| for chunk in self.result_pool: | |
| for datapoint in chunk: | |
| pooled_datapoints.append(datapoint) # unpack into a joint list | |
| self.result_pool = pooled_datapoints | |
| del pooled_datapoints | |
| print("converting text to tensors...") | |
| text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool] # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues) | |
| print("converting speech to tensors...") | |
| speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool] | |
| print("converting waves to tensors...") | |
| norm_waves = [torch.Tensor(x[2]) for x in self.result_pool] | |
| print("unpacking file list...") | |
| filepaths = [x[3] for x in self.result_pool] | |
| del self.result_pool | |
| self.datapoints = list(zip(text_tensors, speech_tensors)) | |
| del text_tensors | |
| del speech_tensors | |
| print("done!") | |
        # add speaker embeddings
        self.speaker_embeddings = list()
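        # ECAPA-TDNN speaker encoder trained on VoxCeleb; encode_batch produces one
        # fixed-size embedding vector per utterance (192 dimensions for this model).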
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
        with torch.inference_mode():
            for wave in tqdm(norm_waves):
                self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
        # save to cache
        if len(self.datapoints) == 0:
            raise RuntimeError("No datapoints survived preprocessing, so there is nothing to cache. Check the data and the log output above.")
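        # cache layout: (datapoints, placeholder, speaker_embeddings, filepaths);
        # the None placeholder presumably preserves the tuple layout of an older cache format.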
        torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
                   os.path.join(cache_dir, "aligner_train_cache.pt"))

    def _cache_builder_process(self,
                               path_list,
                               lang,
                               min_len,
                               max_len,
                               verbose,
                               device,
                               phone_input,
                               allow_unknown_symbols):
        process_internal_dataset_chunk = list()
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                             model='silero_vad',
                                             force_reload=False,
                                             onnx=False,
                                             verbose=False)
        (get_speech_timestamps,
         save_audio,
         read_audio,
         VADIterator,
         collect_chunks) = utils
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets this to False
        # globally during model loading, rather than using inference mode or no_grad locally
        silero_model = silero_model.to(device)
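        # a quarter of a second of silence, at the 16 kHz rate that everything is resampled to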
        silence = torch.zeros([16000 // 4], device=device)
        tf = ArticulatoryCombinedTextFrontend(language=lang)
        assumed_sr = sf.info(path_list[0]).samplerate  # reading just the header is enough to get the sampling rate
        ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
        resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
        for path in tqdm(path_list):
            if self.path_to_transcript_dict[path].strip() == "":
                continue
            try:
                wave, sr = sf.read(path)
            except Exception:
                print(f"Problem with an audio file: {path}")
                continue
            wave = librosa.to_mono(wave.T)  # sf.read returns (frames, channels), but to_mono expects the channel axis first

            if sr != assumed_sr:
                assumed_sr = sr
                ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
                resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
                print(f"{path} has a different sampling rate --> adapting the codec processor")

            try:
                norm_wave = resample(torch.tensor(wave).float().to(device))
            except ValueError:
                continue
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            # remove silence from the front and the back, then pad with a constant quarter second of silence on each side
            with torch.no_grad():
                speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
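            # get_speech_timestamps returns a list of {'start': ..., 'end': ...} dicts measured in samples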
            try:
                result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            except IndexError:
                print("Audio might be too short to cut silences from front and back.")
                continue
            wave = torch.cat([silence, result, silence])
            # raw audio preprocessing is done
            transcript = self.path_to_transcript_dict[path]

            try:
                try:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                except KeyError:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                    if not allow_unknown_symbols:
                        continue  # we skip sentences with unknown symbols
            except ValueError:
                # this can happen for Mandarin Chinese when the syllabification of pinyin fails; in that case we just skip the sample.
                continue
            except KeyError:
                # this can happen for Mandarin Chinese when the syllabification of pinyin fails; in that case we just skip the sample.
                continue

            cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_speech,
                                                   result.cpu().detach().numpy(),
                                                   path])
        self.result_pool.append(process_internal_dataset_chunk)

    def __getitem__(self, index):
        text_vector = self.datapoints[index][0]
        tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
        tokens = torch.LongTensor(tokens)
        token_len = torch.LongTensor([len(tokens)])
        codes = self.datapoints[index][1]
        if codes.size()[0] != 24:  # no clue why this is sometimes the case
            codes = codes.transpose(0, 1)

        return tokens, \
               token_len, \
               codes, \
               None, \
               self.speaker_embeddings[index]

    def __len__(self):
        return len(self.datapoints)


def fisher_yates_shuffle(lst):
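    # shuffles the list in place; functionally equivalent to random.shuffle(lst)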
    for i in range(len(lst) - 1, 0, -1):
        j = random.randint(0, i)
        lst[i], lst[j] = lst[j], lst[i]
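

# Minimal usage sketch (not part of the original module): the file paths, the
# transcripts, and the language code "eng" are hypothetical placeholders here;
# a real setup would take them from a corpus-specific recipe.
if __name__ == '__main__':
    example_path_to_transcript = {"/data/audio/sample_0001.wav": "Hello world.",
                                  "/data/audio/sample_0002.wav": "A second transcript."}
    dataset = CodecAlignerDataset(path_to_transcript_dict=example_path_to_transcript,
                                  cache_dir="Corpora/demo",
                                  lang="eng",
                                  loading_processes=2,
                                  device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    tokens, token_len, codes, _, speaker_embedding = dataset[0]
    print(tokens.shape, token_len, codes.shape, speaker_embedding.shape)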