| | import torch |
| | import torchaudio |
| | import torchaudio.transforms as T |
| | import torch.nn.functional as F |
| | import torchaudio.functional as AF |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | from pathlib import Path |
| | import random |
| |
|
| | import noisereduce as nr |
| | import librosa |
| |
|
| |
|
| | import scipy |
| |
|
| | import pickle |
| | import os |
| | from tqdm import tqdm |
| |
|
| |
|
class Load:
    """Reads an audio file from disk into a normalized float tensor."""

    def __init__(self):
        pass

    def load(self, file_path):
        """Return (signal, sample_rate) for `file_path`.

        The signal comes back channels-first with sample values converted
        to normalized floats (torchaudio's `normalize=True`).
        """
        return torchaudio.load(file_path, channels_first=True, normalize=True)
| |
|
class StereoToMono:
    """Collapses a multi-channel signal into a single channel."""

    def __init__(self):
        pass

    def stereo_to_mono(self, stereo_signal):
        """Average across the channel axis (dim 0), keeping a (1, T) shape."""
        return stereo_signal.mean(dim=0, keepdim=True)
| | |
class Resample:
    """Resamples a signal from one sample rate to another.

    Fixes vs. original: the "No remsampling needed" typo in the user-facing
    message, and the 'Resampling the signal...' print that was emitted
    unconditionally even though the method takes a `debug` flag.
    """

    def __init__(self):
        self.sr_in = None   # last input sample rate seen
        self.sr_out = None  # last target sample rate seen

    def resample(self, signal, sr_in, sr_out, debug=True):
        """Resample `signal` from `sr_in` to `sr_out`.

        Returns (signal, sr_out). When the rates already match, the input
        tensor is returned unchanged (no resampler is constructed).
        """
        self.sr_in = sr_in
        self.sr_out = sr_out
        if sr_in == sr_out:
            if debug:
                print('No resampling needed')
            return signal, sr_out
        if debug:
            print('Resampling the signal...')
        resampler = torchaudio.transforms.Resample(orig_freq=self.sr_in, new_freq=self.sr_out)
        return resampler(signal), sr_out
| |
|
class NoiseRemover:
    """Applies noisereduce spectral denoising to a (1, T) torch signal."""

    def __init__(self):
        # Last call's inputs/outputs, kept for inspection after the fact.
        self._sr = None
        self._signal = None
        self._denoised_signal = None

    def remove_noise(self, signal, sr):
        """Denoise `signal` at sample rate `sr`.

        Returns (denoised (1, T) tensor, sr). The tensor round-trips
        through numpy because noisereduce operates on arrays.
        """
        self._sr = sr
        mono = signal.squeeze(0).numpy()
        self._signal = mono
        cleaned = nr.reduce_noise(y=mono, sr=sr)
        self._denoised_signal = torch.tensor(cleaned).unsqueeze(0)
        return self._denoised_signal, sr
| | |
| |
|
class TruncateOrPad:
    """Truncates or zero-pads a signal to a fixed number of samples.

    Fix vs. original: the 'Truncating'/'Padding' messages now honor the
    `debug` flag, matching the already-gated 'max duration' message.
    """

    def __init__(self, max_duration: int, sr_out: int = 16_000):
        self.max_duration = max_duration
        self.sr_out = sr_out
        # Target length in samples at the output rate.
        self.tot_samples_expected = sr_out * max_duration

    def truncate_or_pad(self, signal, debug=True):
        """Return `signal` cut or right-padded to the expected length.

        Operates on the last dimension, so any (..., T) tensor works.
        """
        tot_samples = signal.shape[-1]
        if tot_samples == self.tot_samples_expected:
            if debug:
                print('Signal already at max duration')
            return signal
        if tot_samples > self.tot_samples_expected:
            if debug:
                print('Truncating the signal')
            return self._truncate(signal)
        if debug:
            print('Padding the signal')
        return self._pad(signal)

    def _truncate(self, signal):
        # Keep only the leading expected samples.
        return signal[..., :self.tot_samples_expected]

    def _pad(self, signal):
        # Zero-pad on the right up to the expected length.
        pad_amount = self.tot_samples_expected - signal.shape[-1]
        return F.pad(signal, (0, pad_amount))
| |
|
| | class FeatureExtractor: |
| | """Extracts features: linear, log spectrograms, mel spectrograms""" |
| |
|
| | def __init__(self, n_fft=1024, hop_length=256, sr=16000, n_mels=80): |
| | self.n_fft = n_fft |
| | self.hop_length = hop_length |
| | self.sr = sr |
| | self.n_mels = n_mels |
| | self._window = torch.hann_window(n_fft) |
| |
|
| | def stft_spec(self, signal): |
| | return torch.stft( |
| | signal, |
| | n_fft=self.n_fft, |
| | hop_length=self.hop_length, |
| | window=self._window.to(device=signal.device, dtype=signal.dtype), |
| | center=True, |
| | return_complex=True |
| | ) |
| |
|
| | def linear_mag(self, signal): |
| | """stft -> abs""" |
| | return self.stft_spec(signal).abs() |
| |
|
| | def linear_power(self, signal): |
| | """stft -> abs -> **2""" |
| | return self.linear_mag(signal).pow(2) |
| |
|
| | def mel_scale(self, signal): |
| | """Mel spectrogram (power)""" |
| | mel_spec = torchaudio.transforms.MelSpectrogram( |
| | sample_rate=self.sr, |
| | n_fft=self.n_fft, |
| | hop_length=self.hop_length, |
| | n_mels=self.n_mels, |
| | center=True, |
| | power=2.0 |
| | )(signal) |
| | return mel_spec |
| |
|
| | def log_mag(self, signal, eps=1e-10): |
| | return 20 * torch.log10(self.linear_mag(signal) + eps) |
| |
|
| | def log_power(self, signal, eps=1e-10): |
| | return 10 * torch.log10(self.linear_power(signal) + eps) |
| |
|
| | def log_mel_scale(self, signal): |
| | """Log-mel spectrogram for classification""" |
| | mel_spec = self.mel_scale(signal) |
| | log_mel_spec = torchaudio.transforms.AmplitudeToDB(top_db=80)(mel_spec) |
| | return log_mel_spec |
| |
|
| |
|
class NormalizeFeatures:
    """Feature-normalization helpers."""

    @staticmethod
    def min_max_normalize(mel: torch.Tensor):
        """Scale `mel` into [0, 1] via min-max normalization.

        Returns (normalized tensor, min, max); the small epsilon in the
        denominator avoids division by zero for constant inputs.
        """
        lo, hi = mel.min(), mel.max()
        scaled = (mel - lo) / (hi - lo + 1e-8)
        return scaled, lo, hi
| |
|
| |
|
class BirdDatasetSaver:
    """Persists extracted features to disk, grouped per bird category.

    Layout: <save_dir>/<category>/classification/<stem>_logmel.pt and
            <save_dir>/<category>/generation/<stem>_mel.pt
    """

    def __init__(self, save_dir):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def save(self, bird_category: str, audio_file_name: str, log_mel: torch.Tensor, mel_norm: torch.Tensor):
        """Write the log-mel (classification) and normalized mel (generation) tensors."""
        base = Path(self.save_dir) / bird_category
        classification_dir = base / "classification"
        generation_dir = base / "generation"
        classification_dir.mkdir(parents=True, exist_ok=True)
        generation_dir.mkdir(parents=True, exist_ok=True)

        stem = Path(audio_file_name).stem
        torch.save(log_mel, classification_dir / f"{stem}_logmel.pt")
        torch.save(mel_norm, generation_dir / f"{stem}_mel.pt")
| |
|
| |
|
class PreprocessingPipeline:
    """End-to-end preprocessing: load -> mono -> resample -> fix length -> features -> save."""

    def __init__(self, save_dir, max_duration=4, sr_out=22050, n_fft=1024, hop_length=256, n_mels=80, debug=False):
        self.loader = Load()
        self.stereo2mono = StereoToMono()
        self.resampler = Resample()
        self.truncate_pad = TruncateOrPad(max_duration=max_duration, sr_out=sr_out)
        self.fe = FeatureExtractor(n_fft=n_fft, hop_length=hop_length, sr=sr_out, n_mels=n_mels)
        self.normer = NormalizeFeatures()
        self.saver = BirdDatasetSaver(save_dir)
        self.sr_out = sr_out
        self.debug = debug

    def process_file(self, bird_category, audio_file_path):
        """Run the full preprocessing chain on one audio file and persist its features."""
        audio_file_name = Path(audio_file_path).name

        signal, sr = self.loader.load(audio_file_path)
        signal = self.stereo2mono.stereo_to_mono(signal)
        signal, sr = self.resampler.resample(signal, sr, self.sr_out, self.debug)
        signal = self.truncate_pad.truncate_or_pad(signal, self.debug)

        log_mel = self.fe.log_mel_scale(signal)  # classification feature
        mel = self.fe.mel_scale(signal)          # generation feature
        mel_norm, _, _ = self.normer.min_max_normalize(mel)

        self.saver.save(bird_category, audio_file_name, log_mel, mel_norm)

    def process_dataset(self, root_dir):
        """Process every .wav under root_dir/<category>/ subdirectories."""
        for bird_category in tqdm(os.listdir(root_dir)):
            category_path = os.path.join(root_dir, bird_category)
            if not os.path.isdir(category_path):
                continue  # skip stray files at the root level
            for audio_file in os.listdir(category_path):
                if not audio_file.endswith(".wav"):
                    continue
                self.process_file(bird_category, os.path.join(category_path, audio_file))
| |
|