import os
import random

import torch
import torch.utils.data

try:
    import librosa
except ImportError:
    # Only Dataset needs librosa (to read audio from disk). The STFT helpers
    # and the file-list reader stay importable without it; Dataset raises a
    # clear ImportError the first time it actually has to load a file.
    librosa = None


def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Compute compressed magnitude, phase and complex spectrum of a waveform.

    Args:
        y: Waveform tensor accepted by ``torch.stft`` (1-D, or 2-D with a
            leading batch dimension).
        n_fft: FFT size passed to ``torch.stft``.
        hop_size: Hop length between frames.
        win_size: Length of the Hann analysis window.
        compress_factor: Exponent applied to the magnitude (1.0 = no
            compression).
        center: Forwarded to ``torch.stft``; reflect padding is used.

    Returns:
        Tuple ``(mag, pha, com)``: the compressed magnitude, the phase, and
        the compressed spectrum with real and imaginary parts stacked on a
        trailing dimension of size 2.
    """
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True,
    )
    stft_spec = torch.view_as_real(stft_spec)
    # Ellipsis indexing works for both batched and unbatched spectra; fixed
    # three-slice indexing would fail on a 1-D waveform.
    real_part = stft_spec[..., 0]
    imag_part = stft_spec[..., 1]
    # The small offsets keep sqrt and atan2 away from exactly-zero inputs.
    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + (1e-9))
    pha = torch.atan2(imag_part + (1e-10), real_part + (1e-5))
    # Magnitude Compression
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)

    return mag, pha, com


def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Rebuild a waveform from a compressed magnitude and a phase spectrum.

    Args:
        mag: Compressed magnitude, as produced by ``mag_pha_stft``.
        pha: Phase spectrum with the same shape as ``mag``.
        n_fft: FFT size passed to ``torch.istft``.
        hop_size: Hop length between frames.
        win_size: Length of the Hann synthesis window.
        compress_factor: Exponent that was applied to the magnitude; it is
            undone by raising ``mag`` to ``1.0 / compress_factor``.
        center: Forwarded to ``torch.istft``.

    Returns:
        The waveform tensor returned by ``torch.istft``.
    """
    # Magnitude Decompression
    mag = torch.pow(mag, (1.0 / compress_factor))
    com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
    hann_window = torch.hann_window(win_size).to(com.device)
    wav = torch.istft(
        com,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
    )

    return wav


def _read_index_file(path):
    """Return the first ``|``-separated field of every non-empty line of *path*."""
    with open(path, 'r', encoding='utf-8') as fi:
        return [line.split('|')[0] for line in fi.read().split('\n') if len(line) > 0]


def get_dataset_filelist(a):
    """Read the training and validation index files named on *a*.

    Args:
        a: Object exposing ``input_training_file`` and
            ``input_validation_file`` path attributes.

    Returns:
        Tuple ``(training_indexes, validation_indexes)`` of lists of strings,
        one entry per non-empty line (text before the first ``|``).
    """
    training_indexes = _read_index_file(a.input_training_file)
    validation_indexes = _read_index_file(a.input_validation_file)

    return training_indexes, validation_indexes


class Dataset(torch.utils.data.Dataset):
    """Paired clean/noisy waveform dataset with optional fixed-size cropping.

    Each item is a ``(clean, noisy)`` pair of waveforms. Both are truncated to
    a common length and scaled by the same factor, chosen so the noisy
    waveform has unit mean square. With ``split=True`` every item is randomly
    cropped, or zero-padded on the right, to ``segment_size`` samples.
    """

    def __init__(self, training_indexes, clean_wavs_dir, noisy_wavs_dir, segment_size,
                 sampling_rate, split=True, shuffle=True, n_cache_reuse=1, device=None):
        """Store the configuration and optionally shuffle the index list.

        Args:
            training_indexes: File stems (without ``.wav``) found in both
                directories.
            clean_wavs_dir: Directory holding the clean recordings.
            noisy_wavs_dir: Directory holding the noisy recordings.
            segment_size: Number of samples per item when ``split`` is true.
            sampling_rate: Sampling rate requested from ``librosa.load``.
            split: Crop/pad every item to ``segment_size`` when true.
            shuffle: Shuffle the index order once at construction time.
            n_cache_reuse: How many subsequent ``__getitem__`` calls reuse the
                most recently loaded pair before a new file is read.
            device: Stored on the instance; not used by this class itself.
        """
        # Work on a copy so shuffling never reorders the caller's own list.
        self.audio_indexes = list(training_indexes)
        # NOTE: this seeds the global `random` generator, even when
        # shuffle=False, so the order and later crops are reproducible.
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_indexes)
        self.clean_wavs_dir = clean_wavs_dir
        self.noisy_wavs_dir = noisy_wavs_dir
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.cached_clean_wav = None
        self.cached_noisy_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device

    def _load_pair(self, filename):
        """Load the clean/noisy recordings for *filename*, cut to a common length."""
        if librosa is None:
            raise ImportError('librosa is required to load audio files for Dataset')
        clean_audio, _ = librosa.load(
            os.path.join(self.clean_wavs_dir, filename + '.wav'), sr=self.sampling_rate)
        noisy_audio, _ = librosa.load(
            os.path.join(self.noisy_wavs_dir, filename + '.wav'), sr=self.sampling_rate)
        length = min(len(clean_audio), len(noisy_audio))
        return clean_audio[:length], noisy_audio[:length]

    def __getitem__(self, index):
        """Return the normalised ``(clean, noisy)`` pair for *index*.

        While the cache counter is non-zero the previously loaded pair is
        returned regardless of *index*; a fresh file is read only when the
        counter has dropped to zero.

        Returns:
            Tuple of two tensors produced by ``squeeze()`` on the
            ``(1, samples)`` waveforms.
        """
        filename = self.audio_indexes[index]
        if self._cache_ref_count == 0:
            clean_audio, noisy_audio = self._load_pair(filename)
            self.cached_clean_wav = clean_audio
            self.cached_noisy_wav = noisy_audio
            self._cache_ref_count = self.n_cache_reuse
        else:
            clean_audio = self.cached_clean_wav
            noisy_audio = self.cached_noisy_wav
            self._cache_ref_count -= 1

        clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio)
        noisy_energy = torch.sum(noisy_audio ** 2.0)
        if noisy_energy > 0:
            norm_factor = torch.sqrt(len(noisy_audio) / noisy_energy)
        else:
            # A silent or empty noisy clip would give inf/nan here and turn
            # the whole item into NaN; leave such a pair unscaled instead.
            norm_factor = torch.tensor(1.0)
        clean_audio = (clean_audio * norm_factor).unsqueeze(0)
        noisy_audio = (noisy_audio * norm_factor).unsqueeze(0)

        assert clean_audio.size(1) == noisy_audio.size(1)

        if self.split:
            if clean_audio.size(1) >= self.segment_size:
                # Same random window for both signals so they stay aligned.
                max_audio_start = clean_audio.size(1) - self.segment_size
                audio_start = random.randint(0, max_audio_start)
                clean_audio = clean_audio[:, audio_start: audio_start + self.segment_size]
                noisy_audio = noisy_audio[:, audio_start: audio_start + self.segment_size]
            else:
                # Too short: zero-pad on the right up to segment_size.
                clean_audio = torch.nn.functional.pad(
                    clean_audio, (0, self.segment_size - clean_audio.size(1)), 'constant')
                noisy_audio = torch.nn.functional.pad(
                    noisy_audio, (0, self.segment_size - noisy_audio.size(1)), 'constant')

        return (clean_audio.squeeze(), noisy_audio.squeeze())

    def __len__(self):
        """Return the number of indexed recordings."""
        return len(self.audio_indexes)