# NOTE(review): the lines below were recovered from a web scrape (Hugging Face
# Spaces "Runtime error" page residue); formatting has been reconstructed.
| import os | |
| import random | |
| import warnings | |
| import soundfile as sf | |
| import torch | |
| from numpy import trim_zeros | |
| from speechbrain.pretrained import EncoderClassifier | |
| from torch.multiprocessing import Manager | |
| from torch.multiprocessing import Process | |
| from torch.multiprocessing import set_start_method | |
| from torch.utils.data import Dataset | |
| from tqdm import tqdm | |
| from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend | |
| from Preprocessing.AudioPreprocessor import AudioPreprocessor | |
class AlignerDataset(Dataset):
    """
    Dataset for aligner training.

    Each datapoint is a list of four tensors:
    [articulatory text features, text length, mel spectrogram, spectrogram length],
    with a matching entry in ``self.speaker_embeddings``.

    On first use the cache is built by parallel worker processes and saved to
    ``<cache_dir>/aligner_train_cache.pt``; later instantiations just load it.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes=30,  # careful with the amount of processes if you use silence removal, only as many processes as you have cores
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silences=False,
                 rebuild_cache=False,
                 verbose=False,
                 device="cpu"):
        """
        Args:
            path_to_transcript_dict: mapping from audio file path to its transcript string
            cache_dir: directory where the preprocessed cache is written/read
            lang: language code passed to the text frontend
            loading_processes: number of parallel cache-builder processes
            min_len_in_seconds / max_len_in_seconds: duration filter for clips
            cut_silences: whether to apply silero-VAD based silence removal
            rebuild_cache: force re-preprocessing even if a cache exists
            verbose: print a message for every excluded file
            device: device for silence removal and speaker-embedding extraction
        """
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
            if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
                try:
                    set_start_method('spawn')  # in order to be able to make use of cuda in multiprocessing
                except RuntimeError:
                    pass  # start method can only be set once per interpreter; keep whatever is active
            elif cut_silences:
                torch.set_num_threads(1)
            if cut_silences:
                torch.hub.load(repo_or_dir='snakers4/silero-vad',
                               model='silero_vad',
                               force_reload=False,
                               onnx=False,
                               verbose=False)  # download and cache for it to be loaded and used later
            # NOTE(review): silero model loading appears to disable autograd globally,
            # so it is re-enabled here unconditionally (True is also the default) — confirm
            torch.set_grad_enabled(True)
            resource_manager = Manager()
            self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
            key_list = list(self.path_to_transcript_dict.keys())
            # keep a record of which files went into this cache
            with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
                files_used_note.write(str(key_list))
            random.shuffle(key_list)
            # build cache
            print("... building dataset cache ...")
            self.datapoints = resource_manager.list()
            # split the keys into one roughly equal chunk per worker process
            key_splits = list()
            process_list = list()
            for i in range(loading_processes):
                key_splits.append(key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
            for key_split in key_splits:
                process_list.append(
                    Process(target=self.cache_builder_process,
                            args=(key_split,
                                  lang,
                                  min_len_in_seconds,
                                  max_len_in_seconds,
                                  cut_silences,
                                  verbose,
                                  device),
                            daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
            self.datapoints = list(self.datapoints)
            tensored_datapoints = list()
            # we had to turn all of the tensors to numpy arrays to avoid shared memory
            # issues. Now that the multi-processing is over, we can convert them back
            # to tensors to save on conversions in the future.
            print("Converting into convenient format...")
            norm_waves = list()
            for datapoint in tqdm(self.datapoints):
                tensored_datapoints.append([torch.Tensor(datapoint[0]),
                                            torch.LongTensor(datapoint[1]),
                                            torch.Tensor(datapoint[2]),
                                            torch.LongTensor(datapoint[3])])
                norm_waves.append(torch.Tensor(datapoint[-1]))
            self.datapoints = tensored_datapoints
            # drop datapoints whose text features do not have the expected 66 dimensions
            pop_indexes = list()
            for index, el in enumerate(self.datapoints):
                try:
                    if len(el[0][0]) != 66:
                        pop_indexes.append(index)
                except TypeError:
                    pop_indexes.append(index)
            for pop_index in sorted(pop_indexes, reverse=True):
                print(f"There seems to be a problem in the transcriptions. Deleting datapoint {pop_index}.")
                self.datapoints.pop(pop_index)
            # add speaker embeddings
            self.speaker_embeddings = self._compute_speaker_embeddings(norm_waves, device)
            # save to cache
            torch.save((self.datapoints, norm_waves, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            if len(self.datapoints) == 2:
                # speaker embeddings are still missing, have to add them here
                wave_datapoints = self.datapoints[1]
                self.datapoints = self.datapoints[0]
                self.speaker_embeddings = self._compute_speaker_embeddings(wave_datapoints, device)
                torch.save((self.datapoints, wave_datapoints, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
            else:
                self.speaker_embeddings = self.datapoints[2]
                self.datapoints = self.datapoints[0]
        self.tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=True)
        print(f"Prepared an Aligner dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    def _compute_speaker_embeddings(self, waves, device):
        """
        Compute one ECAPA-TDNN speaker embedding per waveform.

        Runs the speechbrain encoder on ``device`` and returns a list of
        embeddings moved back to CPU. Factored out because the exact same code
        appeared twice in __init__ (fresh build and embedding-backfill paths).
        """
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
        embeddings = list()
        with torch.no_grad():
            for wave in tqdm(waves):
                embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
        return embeddings

    def cache_builder_process(self,
                              path_list,
                              lang,
                              min_len,
                              max_len,
                              cut_silences,
                              verbose,
                              device):
        """
        Worker run in a separate process: preprocess a chunk of audio files and
        append the results to the shared datapoint list. Results are stored as
        numpy arrays to avoid shared-memory issues with tensors across processes.
        """
        if not path_list:
            # happens when loading_processes exceeds the number of files; without
            # this guard, path_list[0] below raises an IndexError
            return
        process_internal_dataset_chunk = list()
        tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
        # assumes every file in this chunk shares the sampling rate of the first
        # one, since the preprocessor is built once — TODO confirm with callers
        _, sr = sf.read(path_list[0])
        ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=cut_silences, device=device)
        for path in tqdm(path_list):
            if self.path_to_transcript_dict[path].strip() == "":
                continue  # skip files with empty transcripts
            wave, sr = sf.read(path)
            dur_in_seconds = len(wave) / sr
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")  # otherwise we get tons of warnings about an RNN not being in contiguous chunks
                    norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
            except ValueError:
                continue  # preprocessing rejected the audio; skip the file
            # re-check the duration after resampling to 16kHz (and optional silence removal)
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
            # raw audio preprocessing is done
            transcript = self.path_to_transcript_dict[path]
            try:
                cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
            except KeyError:
                # run again with handle_missing=True, presumably so the frontend
                # reports which symbol is unknown — result intentionally discarded
                tf.string_to_tensor(transcript, handle_missing=True).squeeze(0).cpu().numpy()
                continue  # we skip sentences with unknown symbols
            try:
                if len(cached_text[0]) != 66:
                    print(f"There seems to be a problem with the following transcription: {transcript}")
                    continue
            except TypeError:
                print(f"There seems to be a problem with the following transcription: {transcript}")
                continue
            cached_text_len = torch.LongTensor([len(cached_text)]).numpy()
            cached_speech = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            cached_speech_len = torch.LongTensor([len(cached_speech)]).numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_text_len,
                                                   cached_speech,
                                                   cached_speech_len,
                                                   norm_wave.cpu().detach().numpy()])
        self.datapoints += process_internal_dataset_chunk

    def __getitem__(self, index):
        """
        Return (token_ids, text_length, spectrogram, spectrogram_length,
        speaker_embedding) for the datapoint at ``index``.

        Token ids are recovered by matching each articulatory feature vector
        against the frontend's phone-to-vector table.
        """
        text_vector = self.datapoints[index][0]
        tokens = list()
        for vector in text_vector:
            for phone in self.tf.phone_to_vector:
                if vector.numpy().tolist() == self.tf.phone_to_vector[phone]:
                    tokens.append(self.tf.phone_to_id[phone])
                    # this is terribly inefficient, but it's good enough for testing for now.
        tokens = torch.LongTensor(tokens)
        return tokens, \
            self.datapoints[index][1], \
            self.datapoints[index][2], \
            self.datapoints[index][3], \
            self.speaker_embeddings[index]

    def __len__(self):
        return len(self.datapoints)