|
|
|
|
|
|
|
|
''' |
|
|
@Project :Waveformer-main |
|
|
@File :dataset_online.py |
|
|
@IDE :PyCharm |
|
|
@Author :Aisaka/Hao Ma @SDU |
|
|
@Date :2023/11/1 6:47 PM
|
|
''' |
|
|
import os |
|
|
import random |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.transforms as AT |
|
|
import csv |
|
|
import json |
|
|
import numpy as np |
|
|
import librosa |
|
|
|
|
|
|
|
|
def labels2caption(labels):
    """Turn a list of sound-event labels into a natural-language caption.

    A single label yields "The sound of <label>"; multiple labels yield
    "The sounds of <label1>, <label2>, ...".
    """
    noun = "sound" if len(labels) == 1 else "sounds"
    return f"The {noun} of " + ', '.join(labels)
|
|
|
|
|
|
|
|
class CLAPSepDataSet(torch.utils.data.Dataset):
    """Online-mixing dataset for CLAPSep.

    Each item mixes a target clip with a noise clip whose labels are
    disjoint from the target's at 0 dB SNR, and returns the mixture
    (original and resampled), positive/negative captions, the target
    waveform, and speed-perturbed positive/negative query waveforms.
    """

    def __init__(self, data_list, dset='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        """
        Args:
            data_list: path to a CSV file; each row is `wav_path,label1,label2,...`.
                Every wav_path must exist on disk.
            dset: 'train' or 'val'. Val pairs each clip with a fixed noise clip
                chosen once at construction time.
            silence_rate: probability (train only) of emitting a "silent target"
                example whose query caption matches neither mixed source.
            chunk_dur: clip duration in seconds.
            sr: sample rate audio is loaded at (`sr * chunk_dur` is the fixed
                clip length; must not be None in practice).
            resample_rate: if given, rate used for the query/CLAP-facing copies;
                if None, waveforms are passed through unresampled.
        """
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()   # wav path -> list of labels
        self.text_dict = dict()   # joined label string -> list of wav paths
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]
                label = ', '.join(row[1:])
                if label not in self.text_dict:
                    self.text_dict[label] = []
                self.text_dict[label].append(row[0])

        # Query waveforms are speed-perturbed at 48 kHz (the rate the
        # downstream query encoder expects).
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Pick each validation clip's noise partner once up front so every
            # __getitem__ call for the same index reuses the same pairing.
            # NOTE(review): this still depends on the RNG state at construction
            # time — seed before building the dataset for reproducible folds.
            self.noise_names = []
            for name in self.data_names:
                noise_name = self.choose_other_samples(', '.join(self.data_meta[name]), 1)[0]
                self.noise_names.append(noise_name)

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate
        else:
            # BUGFIX: `self.resampler` used to be left undefined in this branch,
            # so __getitem__ crashed with AttributeError whenever
            # resample_rate was None. Identity passes waveforms through.
            self.resampler = torch.nn.Identity()
        self.sr = sr

    def __len__(self):
        return len(self.data_names)

    def choose_other_samples(self, target_text, num):
        """Pick `num` wav paths usable as noise for the given target labels.

        BUGFIX: the original only excluded the *exact* label string, so a
        candidate sharing a subset of the target's labels (e.g. 'dog' vs
        'dog, cat') could be chosen as noise. Candidates must now be
        label-disjoint, matching the train-time pairing rule in __getitem__.
        Falls back to the old behaviour if too few disjoint candidates exist.

        Args:
            target_text: the target's labels joined with ', ' (a text_dict key).
            num: number of distinct label groups to sample.

        Returns:
            A list of `num` wav paths, one per sampled label group.
        """
        target_labels = set(target_text.split(', '))
        candidates = [text for text in self.text_dict
                      if not target_labels & set(text.split(', '))]
        if len(candidates) < num:
            # Fallback: original behaviour (exclude only the exact label string).
            candidates = [text for text in self.text_dict if text != target_text]
        chosen_text = random.sample(candidates, num)
        return [random.choice(self.text_dict[text]) for text in chosen_text]

    def load_wav(self, path):
        """Load mono audio at self.sr, trimmed/zero-padded to chunk_dur seconds.

        Returns a 1-D numpy float array of exactly `sr * chunk_dur` samples.
        """
        max_length = self.sr * self.chunk_dur
        wav = librosa.core.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        elif len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one (mixture, captions, target, query samples) training example."""
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Re-draw the noise partner every call; its labels must be
            # disjoint from the target's.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]

        snr = torch.zeros((1,))  # 0 dB: target and noise mixed at equal level

        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)

        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        neg_sample, _ = self.augmentation(self.resampler(noise.squeeze()))

        # Peak-normalize mixture (and target, to keep them aligned) to 0.9.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)

        if self.dset == 'train' and random.random() < self.silence_rate:
            # Silent-target example: the query caption names a source that is
            # NOT in the mixture, so the model should output silence.
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            neg_sample, _ = self.augmentation(mixed_resample)

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Pad/trim a 1-D query waveform to chunk_dur seconds at 48 kHz.

        Also rescales so the peak does not exceed 0.9 (speed perturbation can
        push samples above 1.0).
        """
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
|
|
|
|
|
|
|
|
class CLAPSepDataEngineDataSet(torch.utils.data.Dataset):
    """Online-mixing dataset with optional pre-separated "data engine" stems.

    Behaves like CLAPSepDataSet, except that for clips listed in
    `data_engine_json` a training example may (with probability 0.5) use the
    original clip itself as the mixture and one of its separated stems as the
    target, with the remaining stems averaged into the negative sample.
    """

    def __init__(self, data_list, dset='', data_engine_json='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        """
        Args:
            data_list: path to a CSV file; each row is `wav_path,label1,label2,...`.
            dset: 'train' or 'val'.
            data_engine_json: optional JSON file mapping a clip id (file stem)
                to a list of `[stem_wav_path, caption]` pairs; missing file
                disables the data-engine branch.
            silence_rate: probability (train only) of a "silent target" example.
            chunk_dur: clip duration in seconds.
            sr: sample rate audio is loaded at (`sr * chunk_dur` is the fixed
                clip length; must not be None in practice).
            resample_rate: if given, rate used for the query/CLAP-facing copies;
                if None, waveforms are passed through unresampled.
        """
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()   # wav path -> list of labels
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]

        # Query waveforms are speed-perturbed at 48 kHz (the rate the
        # downstream query encoder expects).
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Fix each validation clip's label-disjoint noise partner up front.
            # NOTE(review): depends on the RNG state at construction time —
            # seed before building the dataset for reproducible folds.
            self.noise_names = []
            for name in self.data_names:
                noise_name = name
                while set(self.data_meta[noise_name]) & set(self.data_meta[name]):
                    noise_name = random.choice(self.data_names)
                self.noise_names.append(noise_name)

        self.data_engine_dict = {}
        if os.path.exists(data_engine_json):
            self.data_engine_dict = json.load(open(data_engine_json, 'r'))

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate
        else:
            # BUGFIX: `self.resampler` used to be left undefined in this branch,
            # so __getitem__ crashed with AttributeError whenever
            # resample_rate was None. Identity passes waveforms through.
            self.resampler = torch.nn.Identity()
        self.sr = sr

    def __len__(self):
        return len(self.data_names)

    def load_wav(self, path):
        """Load mono audio at self.sr, trimmed/zero-padded to chunk_dur seconds.

        Returns a 1-D numpy float array of exactly `sr * chunk_dur` samples.
        """
        max_length = self.sr * self.chunk_dur
        wav = librosa.core.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        elif len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one (mixture, captions, target, query samples) training example."""
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Re-draw a label-disjoint noise partner every call.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]

        snr = torch.zeros((1,))  # 0 dB: target and noise mixed at equal level

        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)

        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)

        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        noise = noise.squeeze()

        # Peak-normalize mixture (and target, to keep them aligned) to 0.9.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)

        # Clip id = file name without directory or its 4-character extension
        # (e.g. ".wav"); used to look up data-engine stems.
        video = tgt_name.split('/')[-1][:-4]
        if self.dset == 'train' and video in self.data_engine_dict and random.random() > 0.5:
            # Data-engine branch: the original clip is the mixture; one of its
            # separated stems becomes the target, the rest become the noise.
            mixed = tgt
            mixed_resample = self.resampler(mixed)
            # BUGFIX: copy before popping — the original popped from the cached
            # list inside data_engine_dict, permanently deleting one stem per
            # visit across epochs.
            items = list(self.data_engine_dict[video])
            tgt_idx = random.randrange(len(items))
            tgt_item = items.pop(tgt_idx)
            tgt = torch.tensor(self.load_wav(tgt_item[0]))
            max_value = torch.max(torch.abs(tgt))
            if max_value > 1:
                tgt *= 0.9 / max_value
            tgt_cap = tgt_item[1]
            if len(items) > 0:
                noises = [torch.tensor(self.load_wav(x[0])) for x in items]
                noise_caps = [x[1] for x in items]
                noise = torch.mean(torch.stack(noises, dim=0), dim=0)
                neg_cap = labels2caption(noise_caps)

        elif self.dset == 'train' and random.random() < self.silence_rate:
            # Silent-target example: the query caption names a source that is
            # NOT in the mixture, so the model should output silence.
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            noise = mixed

        neg_sample, _ = self.augmentation(self.resampler(noise))

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Pad/trim a 1-D query waveform to chunk_dur seconds at 48 kHz.

        Also rescales so the peak does not exceed 0.9 (speed perturbation can
        push samples above 1.0).
        """
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
|
|
|
|
|
|