| | import glob |
| | import os |
| | import random |
| |
|
| | import librosa |
| | import numpy as np |
| | import soundfile as sf |
| | import torch |
| | from numpy.random import default_rng |
| | from pydtmc import MarkovChain |
| | from sklearn.model_selection import train_test_split |
| | from torch.utils.data import Dataset |
| |
|
| | from config import CONFIG |
| |
|
# Seed NumPy's legacy global RNG so calls made through ``np.random.*`` are
# repeatable from run to run.
np.random.seed(0)
# Separate module-level Generator. It is created without a seed, so it is NOT
# affected by the ``np.random.seed(0)`` call above.
# NOTE(review): `rng` is not referenced in the visible part of this file -
# confirm whether another module imports it.
rng = default_rng()
| |
|
| |
|
def load_audio(
        path,
        sample_rate: int = 16000,
        chunk_len=None,
):
    """Read an audio file and return it with channels on the first axis.

    Args:
        path: File to open with ``soundfile.SoundFile``.
        sample_rate: Rate the audio is resampled to when the file's own rate
            differs from it.
        chunk_len: When given and smaller than the file's frame count, only a
            randomly positioned excerpt of ``chunk_len`` frames is read. The
            length is counted at the file's native rate, i.e. before any
            resampling. Otherwise the whole file is read.

    Returns:
        A ``numpy.ndarray`` read as float32 and transposed to
        ``(channels, samples)``.
    """
    with sf.SoundFile(path) as f:
        sr = f.samplerate
        audio_len = f.frames

        if chunk_len is not None and chunk_len < audio_len:
            # The start is drawn from torch's RNG; ``high`` is exclusive, so
            # the excerpt always fits inside the file. Converted to a plain
            # int before it is handed to soundfile.
            start_index = int(torch.randint(0, audio_len - chunk_len, (1,))[0])

            # Public seek + read(frames) instead of the private
            # ``SoundFile._prepare_read`` helper.
            f.seek(start_index)
            audio = f.read(chunk_len, always_2d=True, dtype="float32")
        else:
            audio = f.read(always_2d=True, dtype="float32")

    if sr != sample_rate:
        # Keyword arguments are required by librosa >= 0.10 and are accepted
        # by older releases as well.
        # NOTE(review): squeeze + newaxis only round-trips for single-channel
        # audio - confirm that multi-channel files never reach this branch.
        audio = librosa.resample(
            np.squeeze(audio), orig_sr=sr, target_sr=sample_rate
        )[:, np.newaxis]

    return audio.T
| |
|
| |
|
def pad(sig, length):
    """Fit a 2-D tensor to exactly ``length`` entries along dimension 1.

    Args:
        sig: Tensor indexed as ``(rows, samples)``.
        length: Number of samples the result must have.

    Returns:
        ``sig`` right-padded with zeros when it is shorter than ``length``;
        otherwise a randomly positioned slice of ``length`` samples (start
        drawn with ``random.randint``, so it follows the ``random`` module's
        seed).
    """
    n_samples = sig.shape[1]
    if n_samples < length:
        # new_zeros inherits dtype and device from ``sig``; a bare
        # torch.zeros is always a CPU float32 tensor, which fails for tensors
        # on another device and silently promotes integer inputs.
        filler = sig.new_zeros((sig.shape[0], length - n_samples))
        return torch.hstack((sig, filler))

    start = random.randint(0, n_samples - length)
    return sig[:, start:start + length]
| |
|
| |
|
class MaskGenerator:
    """Draws 0/1 masks by walking two-state Markov chains labelled '1' and '0'."""

    def __init__(self, is_train=True, probs=((0.9, 0.1), (0.5, 0.1), (0.5, 0.5))):
        """
        is_train: if True, one chain is built per entry of ``probs`` and a
            random one is used for every mask; otherwise a single chain is
            built for evaluation.
        probs: sequence of ``(p_N, p_L)`` pairs. Each pair yields the
            transition matrix ``[[p_N, 1 - p_N], [1 - p_L, p_L]]`` over the
            states ``'1'`` and ``'0'``. Must hold exactly one pair when
            ``is_train`` is False.
        """
        self.is_train = is_train
        self.probs = probs

        if self.is_train:
            chain_params = probs
        else:
            assert len(probs) == 1
            chain_params = [self.probs[0]]

        self.mcs = [self._make_chain(pair) for pair in chain_params]

    @staticmethod
    def _make_chain(pair):
        # Row 0 holds the transitions out of state '1', row 1 those out of '0'.
        stay_one, stay_zero = pair[0], pair[1]
        transitions = [
            [stay_one, 1 - stay_one],
            [1 - stay_zero, stay_zero],
        ]
        return MarkovChain(transitions, ['1', '0'])

    def gen_mask(self, length, seed=0):
        """Walk a chain for ``length - 1`` steps and return the states as an int array.

        In training mode the chain is picked with ``random.choice``; in
        evaluation mode the single configured chain is used. ``seed`` is
        forwarded to the chain's ``walk``.
        # presumably ``walk`` also yields the initial state, giving ``length``
        # entries - verify against the installed pydtmc version.
        """
        chain = random.choice(self.mcs) if self.is_train else self.mcs[0]
        states = chain.walk(length - 1, seed=seed)
        return np.array([int(state) for state in states])
| |
|
| |
|
class TestLoader(Dataset):
    """Evaluation dataset yielding packet-masked and clean audio for each listed file.

    All settings are read from ``CONFIG``. With ``CONFIG.DATA.EVAL.masking ==
    'real'`` the masks come from trace text files found under
    ``CONFIG.DATA.EVAL.trace_path``; otherwise they are produced by a
    ``MaskGenerator`` in evaluation mode.
    """

    def __init__(self):
        dataset_name = CONFIG.DATA.dataset
        self.mask = CONFIG.DATA.EVAL.masking

        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']
        txt_list = CONFIG.DATA.data_dir[dataset_name]['test']
        self.data_list = self.load_txt(txt_list)
        if self.mask == 'real':
            trace_txt = sorted(glob.glob(os.path.join(CONFIG.DATA.EVAL.trace_path, '*.txt')))
            self.trace_list = [self._load_trace(txt) for txt in trace_txt]
        else:
            self.mask_generator = MaskGenerator(is_train=False, probs=CONFIG.DATA.EVAL.transition_probs)

        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.window_size = CONFIG.DATA.window_size
        self.audio_chunk_len = CONFIG.DATA.audio_chunk_len
        self.p_size = CONFIG.DATA.EVAL.packet_size
        self.hann = torch.sqrt(torch.hann_window(self.window_size))

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _load_trace(path):
        # One integer per line; the values are inverted (1 - x) to form the
        # multiplicative mask. The file is closed as soon as it has been read.
        # presumably a 1 in the trace marks a lost packet - verify against the
        # trace format.
        with open(path, 'r') as f:
            flags = [int(value) for value in f.read().strip('\n').split('\n')]
        return 1 - np.array(flags)

    def _stft(self, wav):
        # return_complex=False is deprecated in torch.stft; view_as_real on
        # the complex result yields the same trailing (real, imag) axis, which
        # is then moved to the front.
        spec = torch.stft(wav, self.window_size, self.stride, window=self.hann,
                          return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1).float()

    def load_txt(self, txt_list):
        """Return the sorted, de-duplicated paths listed in ``txt_list``.

        Every non-empty line is joined onto ``self.target_root``. Blank lines
        are skipped so they do not turn into an entry equal to the root
        directory.
        """
        entries = set()
        with open(txt_list) as f:
            for line in f:
                name = line.strip('\n')
                if name:
                    entries.add(os.path.join(self.target_root, name))
        return sorted(entries)

    def __getitem__(self, index):
        """Return ``(masked_spec, clean_spec, masked_wav, clean_wav)`` for one file.

        The audio is trimmed to a whole number of packets of ``self.p_size``
        samples, each packet is multiplied by its mask entry, and both the
        masked and the clean waveform are transformed with an STFT.
        """
        target = load_audio(self.data_list[index], sample_rate=self.sr)
        target = target[:, :(target.shape[1] // self.p_size) * self.p_size]

        # Copy so that masking the packets leaves ``target`` untouched.
        sig = np.reshape(target, (-1, self.p_size)).copy()
        if self.mask == 'real':
            trace = self.trace_list[index % len(self.trace_list)]
            # np.repeat needs an integer count; np.ceil returns a float.
            # NOTE(review): np.repeat stretches every trace entry rather than
            # tiling the trace - kept as in the original, confirm it is intended.
            repeats = int(np.ceil(len(sig) / len(trace)))
            mask = np.repeat(trace, repeats, 0)[:len(sig)][:, np.newaxis]
        else:
            mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask
        sig = torch.tensor(sig).reshape(-1)

        target = torch.tensor(target).squeeze(0)

        sig_wav = sig.clone()
        target_wav = target.clone()

        target = self._stft(target)
        sig = self._stft(sig)
        return sig, target, sig_wav, target_wav
| |
|
| |
|
class BlindTestLoader(Dataset):
    """Dataset over every ``*.wav`` file in a directory, yielding STFT features only."""

    def __init__(self, test_dir):
        """Collect the ``*.wav`` files directly inside ``test_dir``.

        STFT settings are read from ``CONFIG.DATA``.
        """
        # glob returns files in an OS-dependent order; sort so that an index
        # always refers to the same file, as the other loaders in this file do.
        self.data_list = sorted(glob.glob(os.path.join(test_dir, '*.wav')))
        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the float STFT of file ``index`` with the (real, imag) axis first."""
        sig = load_audio(self.data_list[index], sample_rate=self.sr)
        sig = torch.from_numpy(sig).squeeze(0)
        # return_complex=False is deprecated in torch.stft; view_as_real on
        # the complex result gives the same trailing (real, imag) axis.
        spec = torch.stft(sig, self.chunk_len, self.stride, window=self.hann,
                          return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1).float()
| |
|
| |
|
class TrainDataset(Dataset):
    """Training/validation dataset yielding (masked, clean) STFT pairs.

    All settings are read from ``CONFIG``. The file list is split with
    ``train_test_split`` using a fixed ``random_state``, so the 'train' and
    'val' modes see complementary parts of the same split.
    """

    def __init__(self, mode='train'):
        """
        mode: 'train' keeps the training part of the split, 'val' the held-out
            part; any other value keeps the full list.
        """
        dataset_name = CONFIG.DATA.dataset
        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']

        txt_list = CONFIG.DATA.data_dir[dataset_name]['train']
        self.data_list = self.load_txt(txt_list)

        if mode == 'train':
            self.data_list, _ = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)

        elif mode == 'val':
            _, self.data_list = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)

        self.p_sizes = CONFIG.DATA.TRAIN.packet_sizes
        self.mode = mode
        self.sr = CONFIG.DATA.sr
        self.window = CONFIG.DATA.audio_chunk_len
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))
        self.mask_generator = MaskGenerator(is_train=True, probs=CONFIG.DATA.TRAIN.transition_probs)

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Return the sorted, de-duplicated paths listed in ``txt_list``.

        Every non-empty line is joined onto ``self.target_root``. Blank lines
        are skipped so they do not turn into an entry equal to the root
        directory.
        """
        entries = set()
        with open(txt_list) as f:
            for line in f:
                name = line.strip('\n')
                if name:
                    entries.add(os.path.join(self.target_root, name))
        return sorted(entries)

    def fetch_audio(self, index):
        """Load a chunk of ``self.window`` samples starting from file ``index``.

        When the chunk comes back short, it is extended until it is long
        enough: with zeros if fewer than 0.02 * ``self.sr`` samples are
        missing, otherwise with a chunk from a randomly chosen file.
        """
        sig = load_audio(self.data_list[index], sample_rate=self.sr, chunk_len=self.window)
        while sig.shape[1] < self.window:
            # Drawn on every pass, even when only zeros are appended, so the
            # torch RNG stream matches the original implementation.
            idx = int(torch.randint(0, len(self.data_list), (1,))[0])
            pad_len = self.window - sig.shape[1]
            if pad_len < 0.02 * self.sr:
                # np.float was removed in NumPy 1.24; float32 matches the
                # dtype that load_audio reads.
                padding = np.zeros((1, pad_len), dtype=np.float32)
            else:
                padding = load_audio(self.data_list[idx], sample_rate=self.sr, chunk_len=pad_len)
            sig = np.hstack((sig, padding))
        # A resampled chunk can come back longer than requested; clip so the
        # result never exceeds ``self.window`` samples.
        return sig[:, :self.window]

    def _stft(self, wav):
        # return_complex=False is deprecated in torch.stft; view_as_real on
        # the complex result yields the same trailing (real, imag) axis, which
        # is then moved to the front.
        spec = torch.stft(wav, self.chunk_len, self.stride, window=self.hann,
                          return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1).float()

    def __getitem__(self, index):
        """Return ``(masked_spec, clean_spec)`` for item ``index``.

        The waveform is cut into packets of a size drawn from
        ``self.p_sizes``, each packet is multiplied by its mask entry, and
        both the masked and the clean waveform are transformed with an STFT.
        """
        sig = self.fetch_audio(index)

        sig = sig.reshape(-1).astype(np.float32)

        target = torch.tensor(sig.copy())
        p_size = random.choice(self.p_sizes)

        sig = np.reshape(sig, (-1, p_size))
        mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask
        sig = torch.tensor(sig.copy()).reshape(-1)

        target = self._stft(target)
        sig = self._stft(sig)
        return sig, target
| |
|