|
|
import os |
|
|
import torchaudio |
|
|
import torch |
|
|
import numpy as np |
|
|
import soundfile |
|
|
class AudioLoader:
    """Load audio files as mono 1-D waveforms at a fixed target sample rate."""

    def __init__(self, sample_rate=16000):
        """
        Args:
            sample_rate: Target sample rate (Hz); every loaded waveform is
                resampled to this rate when its native rate differs.
        """
        self.sample_rate = sample_rate
        # Resample transforms keyed by source sample rate. Constructing one
        # builds a filter kernel, so reuse it for files sharing a rate.
        self._resamplers = {}

    def load_audio(self, file_path):
        """Load *file_path* and return a 1-D waveform at ``self.sample_rate``.

        Multi-channel files are down-mixed to mono by averaging the channels,
        so the result is always 1-D (mono files are returned unchanged apart
        from dropping the channel axis).

        Args:
            file_path: Path to an audio file readable by torchaudio's
                ``soundfile`` backend.

        Returns:
            1-D tensor of shape ``(num_samples,)``.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        # torchaudio.load returns (channels, frames). squeeze(0) alone would
        # leave stereo input 2-D, which breaks the frame-axis indexing done on
        # the STFT output downstream; averaging over channels always yields 1-D.
        audio = audio.mean(dim=0)

        if sample_rate != self.sample_rate:
            resampler = self._resamplers.get(sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                self._resamplers[sample_rate] = resampler
            audio = resampler(audio)

        return audio
|
|
|
|
|
class STFT:
    """Short-time Fourier transform using a Hamming analysis window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        """
        Args:
            n_fft: FFT size.
            hop_length: Number of samples between successive frames.
            win_length: Length of the Hamming window.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex STFT of *signal*.

        Args:
            signal: Real-valued waveform tensor, presumably 1-D ``(time,)``
                (``torch.stft`` also accepts a leading batch dimension).

        Returns:
            Complex tensor; for a real 1-D input its shape is
            ``(n_fft // 2 + 1, num_frames)``.
        """
        # Create the window on the signal's device: torch.stft refuses a
        # window that lives on a different device than its input (e.g. a CUDA
        # signal with a CPU window).
        window = torch.hamming_window(self.win_length, device=signal.device)
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )
|
|
|
|
|
class SpectrogramSaver:
    """Write spectrogram tensors to disk."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize *spectrogram* to *save_path* with ``torch.save``.

        Args:
            spectrogram: Object (normally a tensor) to persist.
            save_path: Destination file path; an existing file is overwritten.
        """
        torch.save(obj=spectrogram, f=save_path)
|
|
|
|
|
class Preprocessing:
    """Turn a directory of ``.wav`` files into fixed-length spectrogram tensors.

    Each file is loaded (and resampled) by ``AudioLoader`` and transformed by
    ``STFT``. The complex result is split into stacked real/imaginary channels,
    padded or truncated along the frame axis to ``fixed_length``, and written
    to disk with ``SpectrogramSaver``.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
        self.loader = AudioLoader(sample_rate)
        self.stft = STFT(n_fft, hop_length, win_length)
        self.saver = SpectrogramSaver()
        # Target number of STFT frames (axis 1 of a preprocessed spectrogram).
        # Filled lazily by determine_fixed_length() when left as None.
        self.fixed_length = None

    @staticmethod
    def _list_wav_files(directory):
        """Return full paths of ``.wav`` entries directly in *directory*, sorted by name."""
        # sorted() makes processing order deterministic; os.listdir order is arbitrary.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.wav')
        ]

    def preprocess(self, signal):
        """Return the STFT of *signal* with real and imaginary parts stacked last.

        Returns:
            Real tensor whose final dimension has size 2 (real, imag). For a
            1-D signal the shape is ``(n_freq, n_frames, 2)``.
        """
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    def determine_fixed_length(self, noisy_dir):
        """Set ``fixed_length`` to the median frame count of the files in *noisy_dir*.

        Every ``.wav`` file is loaded and transformed to measure its frame
        count (axis 1 of the preprocessed spectrogram). The median is
        truncated to an int.

        Raises:
            ValueError: If *noisy_dir* contains no ``.wav`` files.
        """
        noisy_files = self._list_wav_files(noisy_dir)
        if not noisy_files:
            # np.median([]) would return NaN and int(NaN) an opaque error.
            raise ValueError(
                f"No .wav files found in {noisy_dir!r}; cannot determine a fixed length"
            )

        lengths = [
            self.preprocess(self.loader.load_audio(path)).shape[1]
            for path in noisy_files
        ]
        self.fixed_length = int(np.median(lengths))
        print(f"Determined fixed length: {self.fixed_length}")

    def create_dataset(self, noisy_dir, save_dir):
        """Preprocess every ``.wav`` in *noisy_dir* and save it under ``save_dir/noisy``.

        Each output is named ``noisy_<stem>.pt``, where ``<stem>`` is the
        input file name without its final extension. If ``fixed_length`` is
        unset, it is first derived from *noisy_dir*.
        """
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir)

        noisy_save_dir = os.path.join(save_dir, 'noisy')
        os.makedirs(noisy_save_dir, exist_ok=True)

        for noisy_file in self._list_wav_files(noisy_dir):
            noisy_audio = self.loader.load_audio(noisy_file)
            noisy_spectrogram = self.pad_spectrogram(self.preprocess(noisy_audio))
            # splitext strips only the final extension; split('.')[0] would map
            # "a.b.wav" and "a.c.wav" to the same output file.
            stem = os.path.splitext(os.path.basename(noisy_file))[0]
            noisy_save_path = os.path.join(noisy_save_dir, f"noisy_{stem}.pt")
            self.saver.save_spectrogram(noisy_spectrogram, noisy_save_path)

    def pad_spectrogram(self, spectrogram):
        """Zero-pad or truncate *spectrogram* along axis 1 to ``fixed_length``.

        Args:
            spectrogram: Rank-3 tensor; axis 1 is the frame axis.

        Returns:
            Tensor whose axis 1 has exactly ``fixed_length`` entries. Padding
            matches the input's dtype and device.

        Raises:
            RuntimeError: If ``fixed_length`` has not been set.
        """
        if self.fixed_length is None:
            raise RuntimeError(
                "fixed_length is not set; call determine_fixed_length() first"
            )

        pad_length = self.fixed_length - spectrogram.shape[1]
        if pad_length > 0:
            # new_zeros inherits dtype/device so torch.cat never mixes devices.
            pad = spectrogram.new_zeros(
                (spectrogram.shape[0], pad_length, spectrogram.shape[2])
            )
            spectrogram = torch.cat((spectrogram, pad), dim=1)
        elif pad_length < 0:
            spectrogram = spectrogram[:, :self.fixed_length, :]
        return spectrogram
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|