Spaces:
Running on Zero
Running on Zero
| import os | |
| import random | |
| import torch | |
| import torch.utils.data | |
| import librosa | |
| def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): | |
| hann_window = torch.hann_window(win_size).to(y.device) | |
| stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, | |
| center=center, pad_mode='reflect', normalized=False, return_complex=True) | |
| stft_spec = torch.view_as_real(stft_spec) | |
| mag = torch.sqrt(stft_spec.pow(2).sum(-1)+(1e-9)) | |
| pha = torch.atan2(stft_spec[:, :, :, 1]+(1e-10), stft_spec[:, :, :, 0]+(1e-5)) | |
| # Magnitude Compression | |
| mag = torch.pow(mag, compress_factor) | |
| com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) | |
| return mag, pha, com | |
| def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): | |
| # Magnitude Decompression | |
| mag = torch.pow(mag, (1.0/compress_factor)) | |
| com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) | |
| hann_window = torch.hann_window(win_size).to(com.device) | |
| wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) | |
| return wav | |
| def get_dataset_filelist(a): | |
| with open(a.input_training_file, 'r', encoding='utf-8') as fi: | |
| training_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] | |
| with open(a.input_validation_file, 'r', encoding='utf-8') as fi: | |
| validation_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] | |
| return training_indexes, validation_indexes | |
| class Dataset(torch.utils.data.Dataset): | |
| def __init__(self, training_indexes, clean_wavs_dir, noisy_wavs_dir, segment_size, | |
| sampling_rate, split=True, shuffle=True, n_cache_reuse=1, device=None): | |
| self.audio_indexes = training_indexes | |
| random.seed(1234) | |
| if shuffle: | |
| random.shuffle(self.audio_indexes) | |
| self.clean_wavs_dir = clean_wavs_dir | |
| self.noisy_wavs_dir = noisy_wavs_dir | |
| self.segment_size = segment_size | |
| self.sampling_rate = sampling_rate | |
| self.split = split | |
| self.cached_clean_wav = None | |
| self.cached_noisy_wav = None | |
| self.n_cache_reuse = n_cache_reuse | |
| self._cache_ref_count = 0 | |
| self.device = device | |
| def __getitem__(self, index): | |
| filename = self.audio_indexes[index] | |
| if self._cache_ref_count == 0: | |
| clean_audio, _ = librosa.load(os.path.join(self.clean_wavs_dir, filename + '.wav'), sr=self.sampling_rate) | |
| noisy_audio, _ = librosa.load(os.path.join(self.noisy_wavs_dir, filename + '.wav'), sr=self.sampling_rate) | |
| length = min(len(clean_audio), len(noisy_audio)) | |
| clean_audio, noisy_audio = clean_audio[: length], noisy_audio[: length] | |
| self.cached_clean_wav = clean_audio | |
| self.cached_noisy_wav = noisy_audio | |
| self._cache_ref_count = self.n_cache_reuse | |
| else: | |
| clean_audio = self.cached_clean_wav | |
| noisy_audio = self.cached_noisy_wav | |
| self._cache_ref_count -= 1 | |
| clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio) | |
| norm_factor = torch.sqrt(len(noisy_audio) / torch.sum(noisy_audio ** 2.0)) | |
| clean_audio = (clean_audio * norm_factor).unsqueeze(0) | |
| noisy_audio = (noisy_audio * norm_factor).unsqueeze(0) | |
| assert clean_audio.size(1) == noisy_audio.size(1) | |
| if self.split: | |
| if clean_audio.size(1) >= self.segment_size: | |
| max_audio_start = clean_audio.size(1) - self.segment_size | |
| audio_start = random.randint(0, max_audio_start) | |
| clean_audio = clean_audio[:, audio_start: audio_start+self.segment_size] | |
| noisy_audio = noisy_audio[:, audio_start: audio_start+self.segment_size] | |
| else: | |
| clean_audio = torch.nn.functional.pad(clean_audio, (0, self.segment_size - clean_audio.size(1)), 'constant') | |
| noisy_audio = torch.nn.functional.pad(noisy_audio, (0, self.segment_size - noisy_audio.size(1)), 'constant') | |
| return (clean_audio.squeeze(), noisy_audio.squeeze()) | |
| def __len__(self): | |
| return len(self.audio_indexes) | |