Spaces:
Runtime error
Runtime error
| # Copyright (c) 2022 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| import os | |
| import numpy as np | |
| from scipy.io.wavfile import read as wavread | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torch.utils.data.distributed import DistributedSampler | |
| import random | |
| random.seed(0) | |
| torch.manual_seed(0) | |
| np.random.seed(0) | |
| from torchvision import datasets, models, transforms | |
| import torchaudio | |
class CleanNoisyPairDataset(Dataset):
    """
    Dataset of paired clean and noisy audio files.

    Each element is a tuple of the form (clean waveform, noisy waveform, file_id),
    where file_id is the (clean_path, noisy_path) tuple the waveforms were loaded from.
    """

    def __init__(self, root='./', subset='training', crop_length_sec=0):
        """
        Collect the (clean, noisy) file path pairs for the requested subset.

        Args:
            root: directory that holds ``training_set/`` (training subset) or
                ``datasets/test_set/synthetic/no_reverb/`` (testing subset).
            subset: ``"training"`` or ``"testing"``.
            crop_length_sec: length in seconds of the random crop taken from
                each training item; 0 disables cropping. It is forced to 0
                for the testing subset.

        Raises:
            AssertionError: if ``subset`` is not None/"training"/"testing", or
                if the clean and noisy directories do not pair up.
            NotImplementedError: if ``subset`` is None.
        """
        super().__init__()

        assert subset is None or subset in ["training", "testing"]
        self.crop_length_sec = crop_length_sec
        self.subset = subset

        # Only touch the directories of the requested subset, so that a
        # test-only (or train-only) checkout of the data still works.
        if subset == "training":
            self.files = self._training_pairs(root)
        elif subset == "testing":
            self.files = self._testing_pairs(root)
            # test items are always returned uncropped
            self.crop_length_sec = 0
        else:
            raise NotImplementedError

    @staticmethod
    def _training_pairs(root):
        """Return the (clean_path, noisy_path) pairs ``fileid_<i>.wav`` of the training set."""
        clean_dir = os.path.join(root, 'training_set/clean')
        noisy_dir = os.path.join(root, 'training_set/noisy')

        N_clean = len(os.listdir(clean_dir))
        N_noisy = len(os.listdir(noisy_dir))
        assert N_clean == N_noisy, \
            "training set has {} clean but {} noisy files".format(N_clean, N_noisy)

        # file names are derived from the count, not from the directory listing
        return [(os.path.join(clean_dir, 'fileid_{}.wav'.format(i)),
                 os.path.join(noisy_dir, 'fileid_{}.wav'.format(i))) for i in range(N_clean)]

    @staticmethod
    def _testing_pairs(root):
        """Return the (clean_path, noisy_path) pairs of the DNS no-reverb test set."""

        def sortkey(name):
            # specific for dns due to test sample names: pair files on their
            # last two '_'-separated fields (e.g. "fileid_12.wav")
            return '_'.join(name.split('_')[-2:])

        _p = os.path.join(root, 'datasets/test_set/synthetic/no_reverb')  # path for DNS
        clean_files = sorted(os.listdir(os.path.join(_p, 'clean')), key=sortkey)
        noisy_files = sorted(os.listdir(os.path.join(_p, 'noisy')), key=sortkey)

        # zip() would silently drop the unmatched tail of the longer listing
        assert len(clean_files) == len(noisy_files), \
            "test set has {} clean but {} noisy files".format(len(clean_files), len(noisy_files))

        pairs = []
        for _c, _n in zip(clean_files, noisy_files):
            assert sortkey(_c) == sortkey(_n), \
                "unpaired test files: {} / {}".format(_c, _n)
            pairs.append((os.path.join(_p, 'clean', _c),
                          os.path.join(_p, 'noisy', _n)))
        return pairs

    def __getitem__(self, n):
        """
        Load the n-th (clean, noisy) pair.

        Returns:
            (clean_audio, noisy_audio, fileid): both waveforms with a leading
            dimension of size 1 added by ``unsqueeze(0)``, and the
            (clean_path, noisy_path) tuple. Training items are randomly
            cropped to ``crop_length_sec`` seconds when that value is > 0.

        Raises:
            AssertionError: if the two waveforms differ in length or are
                shorter than the requested crop.
        """
        fileid = self.files[n]
        clean_audio, _ = torchaudio.load(fileid[0])
        # the crop length is computed from the noisy file's sample rate
        noisy_audio, sample_rate = torchaudio.load(fileid[1])
        # NOTE(review): assumes single-channel files, so that squeeze(0)
        # leaves 1-D tensors and len() counts samples -- confirm for new data
        clean_audio, noisy_audio = clean_audio.squeeze(0), noisy_audio.squeeze(0)
        assert len(clean_audio) == len(noisy_audio)

        crop_length = int(self.crop_length_sec * sample_rate)
        # a clip of exactly crop_length samples is valid: the only possible
        # start offset is then 0 (randint's high bound below is exclusive)
        assert crop_length <= len(clean_audio)

        # random crop
        if self.subset != 'testing' and crop_length > 0:
            start = np.random.randint(low=0, high=len(clean_audio) - crop_length + 1)
            clean_audio = clean_audio[start:(start + crop_length)]
            noisy_audio = noisy_audio[start:(start + crop_length)]

        clean_audio, noisy_audio = clean_audio.unsqueeze(0), noisy_audio.unsqueeze(0)
        return (clean_audio, noisy_audio, fileid)

    def __len__(self):
        """Return the number of (clean, noisy) pairs."""
        return len(self.files)
def load_CleanNoisyPairDataset(root, subset, crop_length_sec, batch_size, sample_rate, num_gpus=1):
    """
    Build a DataLoader over a CleanNoisyPairDataset.

    With more than one GPU the loader draws indices from a DistributedSampler;
    otherwise it shuffles the dataset itself. ``sample_rate`` is accepted so a
    config dict can be splatted into this call, but it is not used here.
    """
    pair_dataset = CleanNoisyPairDataset(root=root, subset=subset, crop_length_sec=crop_length_sec)
    loader_options = dict(batch_size=batch_size, num_workers=4, pin_memory=False, drop_last=False)

    # single-process case: let the DataLoader shuffle on its own
    if not num_gpus > 1:
        return torch.utils.data.DataLoader(pair_dataset, sampler=None, shuffle=True, **loader_options)

    # multi-GPU case: sampling is delegated to the distributed sampler
    distributed_sampler = DistributedSampler(pair_dataset)
    return torch.utils.data.DataLoader(pair_dataset, sampler=distributed_sampler, **loader_options)
if __name__ == '__main__':
    # Smoke test: build the training and testing loaders from the DNS config
    # and print the shapes of one training batch.
    import json

    with open('./configs/DNS-large-full.json') as f:
        config = json.load(f)
    trainset_config = config["trainset_config"]

    trainloader = load_CleanNoisyPairDataset(**trainset_config, subset='training', batch_size=2, num_gpus=1)
    testloader = load_CleanNoisyPairDataset(**trainset_config, subset='testing', batch_size=2, num_gpus=1)
    print(len(trainloader), len(testloader))

    # fall back to the CPU so the smoke test also runs on hosts without CUDA
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for clean_audio, noisy_audio, fileid in trainloader:
        clean_audio = clean_audio.to(device)
        noisy_audio = noisy_audio.to(device)
        print(clean_audio.shape, noisy_audio.shape, fileid)
        break