#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main
@File :dataset_online.py
@IDE :PyCharm
@Author :Aisaka/Hao Ma @SDU
@Date :2023/11/1 6:47 PM
'''
import os
import random
import torch
import torchaudio
import torchaudio.transforms as AT
import csv
import json
import numpy as np
import librosa


def labels2caption(labels):
    # Build a caption like "The sound of X" / "The sounds of X, Y".
    prefix = "The sound of " if len(labels) == 1 else "The sounds of "
    caption = prefix + ', '.join(labels)
    return caption
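
# Illustrative behavior (example labels, not from the data):
#   labels2caption(['dog barking'])          -> 'The sound of dog barking'
#   labels2caption(['dog barking', 'rain'])  -> 'The sounds of dog barking, rain'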


class CLAPSepDataSet(torch.utils.data.Dataset):  # type: ignore
    def __init__(self, data_list, dset='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()
        self.text_dict = dict()
        # Each CSV row is expected to be: wav_path, label_1, label_2, ...
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0])
                self.data_meta[row[0]] = row[1:]
                # Group files by their joined label string so a sample with
                # different labels can be drawn later.
                label = ', '.join(row[1:])
                if label not in self.text_dict:
                    self.text_dict[label] = []
                self.text_dict[label].append(row[0])
        # self.data_meta.pop('file_name')
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])
        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Pre-draw one noise partner per file so validation pairs stay
            # fixed across epochs.
            self.noise_names = []
            for name in self.data_names:
                noise_name = self.choose_other_samples(', '.join(self.data_meta[name]), 1)[0]
                self.noise_names.append(noise_name)
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = sr
            self.resample_rate = resample_rate
        else:
            # Pass-through fallback: `__getitem__` always calls
            # `self.resampler`, which would otherwise raise AttributeError.
            self.resampler = torch.nn.Identity()
            self.sr = sr
            self.resample_rate = sr

    def __len__(self):
        return len(self.data_names)

    def choose_other_samples(self, target_text, num):
        # Draw `num` files whose joined label string differs from `target_text`.
        candidates = list(self.text_dict.keys())
        candidates.remove(target_text)
        chosen_text = random.sample(candidates, num)
        chosen_samples = [random.choice(self.text_dict[text]) for text in chosen_text]
        return chosen_samples

    def load_wav(self, path):
        # Load at self.sr and force an exact `chunk_dur`-second length:
        # truncate long clips, zero-pad short ones (10 s for AudioCaps).
        max_length = self.sr * self.chunk_dur
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Rejection-sample a noise file with no label overlap.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]
        snr = torch.zeros((1,))  # mix target and noise at 0 dB SNR
        # snr = (torch.rand((1,)) * 10 - 5) if self.dset == 'train' else torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        # assert not torch.isnan(tgt).any()
        # assert not torch.isnan(noise).any()
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        # Speed-perturbed copies serve as the positive/negative audio queries.
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        neg_sample, _ = self.augmentation(self.resampler(noise.squeeze()))
        # Peak-normalize to 0.9 if the mixture clips; scale the target by the
        # same factor to keep the pair consistent.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value
        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)
        # Silence query: with probability `silence_rate`, ask for a sound that
        # is absent from the mixture; the correct output is then all zeros.
        if self.dset == 'train' and random.random() < self.silence_rate:
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            neg_sample, _ = self.augmentation(mixed_resample)
        # Returned: mixture at self.sr, mixture at the resampled rate, target
        # and negative captions, target waveform, and padded query audio.
        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        # Queries are padded/trimmed to `chunk_dur` seconds at 48 kHz (the
        # rate SpeedPerturbation is configured for), then peak-normalized.
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
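
# Minimal usage sketch (hypothetical CSV path and sample rates; rows are
# assumed to look like "path/to/clip.wav,label one,label two"):
#
#     from torch.utils.data import DataLoader
#     dataset = CLAPSepDataSet('train_list.csv', dset='train',
#                              sr=32000, resample_rate=48000)
#     loader = DataLoader(dataset, batch_size=4, shuffle=True)
#     mixed, mixed_48k, tgt_cap, neg_cap, tgt, pos_q, neg_q = next(iter(loader))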


class CLAPSepDataEngineDataSet(torch.utils.data.Dataset):  # type: ignore
    def __init__(self, data_list, dset='', data_engine_json='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]
        # self.data_meta.pop('file_name')
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])
        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Pre-draw one label-disjoint noise partner per file so validation
            # pairs stay fixed across epochs.
            self.noise_names = []
            for name in self.data_names:
                noise_name = name
                while set(self.data_meta[noise_name]) & set(self.data_meta[name]):
                    noise_name = random.choice(self.data_names)
                self.noise_names.append(noise_name)
        # Optional data-engine metadata: maps a clip id to the separated
        # sources of that clip (see the expected shape sketched below).
        self.data_engine_dict = {}
        if os.path.exists(data_engine_json):
            with open(data_engine_json, 'r') as f:
                self.data_engine_dict = json.load(f)
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = sr
            self.resample_rate = resample_rate
        else:
            # Pass-through fallback, mirroring CLAPSepDataSet.
            self.resampler = torch.nn.Identity()
            self.sr = sr
            self.resample_rate = sr
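
    # Expected data-engine JSON shape, inferred from the access pattern in
    # __getitem__ (an assumption, not documented in the original file):
    # {
    #     "video_id": [["path/to/source1.wav", "caption for source 1"],
    #                  ["path/to/source2.wav", "caption for source 2"]]
    # }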

    def __len__(self):
        return len(self.data_names)

    def load_wav(self, path):
        # Load at self.sr and force an exact `chunk_dur`-second length:
        # truncate long clips, zero-pad short ones (10 s for AudioCaps).
        max_length = self.sr * self.chunk_dur
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Rejection-sample a noise file with no label overlap.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]
        snr = torch.zeros((1,))  # mix target and noise at 0 dB SNR
        # snr = (torch.rand((1,)) * 10 - 5) if self.dset == 'train' else torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        # assert not torch.isnan(tgt).any()
        # assert not torch.isnan(noise).any()
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        # assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        noise = noise.squeeze()
        # Peak-normalize to 0.9 if the mixture clips; scale the target by the
        # same factor to keep the pair consistent.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value
        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)
        # Disabled variant A(A1, A2) + B: A1 as target, A2 + B as noise.
        # video = tgt_name.split('/')[-1][:-4]
        # if self.dset == 'train' and video in self.data_engine_dict and random.random() > 0.5:
        #     items = self.data_engine_dict[video]
        #     tgt_idx = random.choice(range(0, len(items)))
        #     tgt_item = items[tgt_idx]
        #     items.pop(tgt_idx)
        #     tgt = torch.tensor(self.load_wav(tgt_item[0]))
        #     max_value = torch.max(torch.abs(tgt))
        #     if max_value > 1:
        #         tgt *= 0.9 / max_value
        #     tgt_cap = tgt_item[1]
        #     if len(items) > 0:
        #         noises = [torch.tensor(self.load_wav(x[0])) for x in items]
        #         noises.append(noise)
        #         noise_caps = [neg_cap.replace('sound', 'sounds')] + [x[1] for x in items]
        #         noise = torch.mean(torch.stack(noises, dim=0), dim=0)
        #         neg_cap = ', '.join(noise_caps)
        # Variant A(A1, A2): within-clip separation. One data-engine source A1
        # becomes the target, the remaining sources become the noise, and the
        # original clip A serves as the mixture.
        video = tgt_name.split('/')[-1][:-4]
        if self.dset == 'train' and video in self.data_engine_dict and random.random() > 0.5:
            mixed = tgt
            mixed_resample = self.resampler(mixed)
            # Copy the list so popping the target below does not mutate the
            # shared dict entry across epochs.
            items = list(self.data_engine_dict[video])
            tgt_idx = random.choice(range(0, len(items)))
            tgt_item = items[tgt_idx]
            items.pop(tgt_idx)
            tgt = torch.tensor(self.load_wav(tgt_item[0]))
            max_value = torch.max(torch.abs(tgt))
            if max_value > 1:
                tgt *= 0.9 / max_value
            tgt_cap = tgt_item[1]  # pos_sample stays drawn from the full clip A
            if len(items) > 0:
                noises = [torch.tensor(self.load_wav(x[0])) for x in items]
                noise_caps = [x[1] for x in items]
                noise = torch.mean(torch.stack(noises, dim=0), dim=0)
                neg_cap = labels2caption(noise_caps)
        # Silence query: ask for a sound absent from the mixture; the correct
        # separation output is then all zeros.
        elif self.dset == 'train' and random.random() < self.silence_rate:
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            noise = mixed
        neg_sample, _ = self.augmentation(self.resampler(noise))
        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
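

# Minimal smoke test, assuming hypothetical file names (the data-engine JSON
# shape is sketched in CLAPSepDataEngineDataSet.__init__ above):
if __name__ == '__main__':
    dataset = CLAPSepDataEngineDataSet('train_list.csv', dset='train',
                                       data_engine_json='data_engine.json',
                                       sr=32000, resample_rate=48000)
    mixed, mixed_48k, tgt_cap, neg_cap, tgt, pos_q, neg_q = dataset[0]
    print(mixed.shape, mixed_48k.shape, tgt_cap, neg_cap)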