| | import os
|
| | import numpy as np
|
| | import torch
|
| | from torch.utils.data import Dataset, DataLoader, Sampler
|
| | from tqdm import tqdm
|
| | import librosa
|
| | import logging
|
| | import argparse
|
| | import json
|
| | import time
|
| | import torchaudio
|
| | from torchvision import transforms
|
| | import pickle
|
| | import random
|
| |
|
def configure_logging(level=logging.DEBUG):
    """Configure the root logger to emit timestamped records to stderr.

    ``force=True`` (Python 3.8+) removes any handlers already attached to the
    root logger before installing ours. Without it ``logging.basicConfig`` is
    a silent no-op whenever anything has logged earlier: a module-level
    ``logging.info`` call auto-installs a WARNING-level root handler, after
    which DEBUG/INFO records are dropped.

    Args:
        level: Root logger level. Defaults to ``logging.DEBUG``, as before.
    """
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler()],
        force=True,
    )
    logging.info("Logging is set up.")
    print("Logging is set up.")
|
| |
|
def parse_args():
    """Read the command line; the only option, ``--config``, is mandatory.

    Returns:
        ``argparse.Namespace`` whose ``config`` attribute holds the config path.
    """
    cli = argparse.ArgumentParser(description='Spectrogram Dataset Preparation')
    cli.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to the config file',
    )
    parsed = cli.parse_args()
    return parsed
|
| |
|
def load_config(config_path):
    """Parse the JSON file at *config_path* and return the resulting object.

    Progress is reported via ``logging`` and ``print``. Any error while opening
    or decoding the file is logged with its traceback and then re-raised
    unchanged.
    """
    logging.info(f"Loading configuration from {config_path}")
    print(f"Loading configuration from {config_path}")
    try:
        with open(config_path, 'r') as handle:
            loaded = json.load(handle)
    except Exception as err:
        logging.error(f"Failed to load config file: {err}", exc_info=True)
        print(f"Failed to load config file: {err}")
        raise
    logging.info("Configuration loaded successfully")
    print("Configuration loaded successfully")
    return loaded
|
| |
|
def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    """Bring *y* to *target_sr* and ensure a minimum length.

    If *sr* differs from *target_sr*, the signal is resampled with
    ``librosa.resample``. If it then holds fewer than
    ``min_duration * target_sr`` samples, zeros are appended to the end.

    Returns:
        Tuple ``(y, target_sr)``.
    """
    logging.debug(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    print(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    needs_resample = sr != target_sr
    if needs_resample:
        logging.warning(f"Resampling from {sr} to {target_sr}")
        print(f"Resampling from {sr} to {target_sr}")
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)

    min_samples = min_duration * target_sr
    if len(y) < min_samples:
        shortfall = int(min_samples - len(y))
        y = np.pad(y, (0, shortfall), mode='constant')
        logging.info(f"Audio file padded with {shortfall} samples")
        print(f"Audio file padded with {shortfall} samples")
    return y, target_sr
|
| |
|
def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    """Trim leading/trailing quiet audio, then add a short zero margin.

    Trimming uses ``librosa.effects.trim`` with the given *top_db*. The
    trimmed signal is then zero-padded by ``int(pad_duration * sr)`` samples
    at both ends (a scalar ``np.pad`` width pads every axis on both sides).
    """
    logging.debug(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    print(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    trimmed = librosa.effects.trim(y, top_db=top_db)[0]
    margin = int(pad_duration * sr)
    return np.pad(trimmed, margin, mode='constant')
|
| |
|
def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    """Load an audio file and turn it into a log-power (dB) mel spectrogram.

    Stages:
        1. Load at the native rate.
        2. Resample/pad via ``validate_audio``.
        3. Trim silence via ``strip_silence``.

    Any exception during these loading stages is logged and ``None`` is
    returned. Peak normalisation and the mel/dB conversion run outside that
    guard, so errors raised there propagate to the caller.

    Returns:
        dB mel spectrogram relative to its own peak (``ref=np.max``), or
        ``None`` if loading failed.
    """
    try:
        logging.info(f"Loading file: {file_path}")
        print(f"Loading file: {file_path}")
        signal, native_sr = librosa.load(file_path, sr=None)
        logging.debug(f"Loaded file: {file_path} with sr={native_sr}")
        print(f"Loaded file: {file_path} with sr={native_sr}")
        signal, rate = validate_audio(signal, native_sr, target_sr, min_duration)
        signal = strip_silence(signal, rate)
    except Exception as e:
        logging.error(f"Error reading {file_path}: {e}", exc_info=True)
        print(f"Error reading {file_path}: {e}")
        return None
    else:
        signal = librosa.util.normalize(signal)
        mel_power = librosa.feature.melspectrogram(y=signal, sr=rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
        mel_db = librosa.power_to_db(mel_power, ref=np.max)
        logging.debug(f"Generated spectrogram for file: {file_path}")
        print(f"Generated spectrogram for file: {file_path}")
        return mel_db
|
| |
|
def validate_spectrogram(spectrogram, n_mels=128):
    """Check that *spectrogram* has *n_mels* rows and a non-empty second axis.

    Returns:
        ``True`` when both checks pass.

    Raises:
        ValueError: Wrong number of mel bands (checked first) or zero frames.
    """
    logging.debug(f"Validating spectrogram with n_mels={n_mels}")
    print(f"Validating spectrogram with n_mels={n_mels}")
    bands = spectrogram.shape[0]
    if bands != n_mels:
        raise ValueError(f"Spectrogram has incorrect number of mel bands: {bands}")
    if not spectrogram.shape[1]:
        raise ValueError("Spectrogram has zero frames")
    return True
|
| |
|
def save_spectrogram(spectrogram, save_path):
    """Write *spectrogram* to *save_path* with ``np.save``.

    Missing parent directories are created first. A bare filename (no
    directory component) is written to the current working directory;
    previously this crashed because ``os.makedirs('')`` raises
    ``FileNotFoundError``.

    Note: ``np.save`` appends ``.npy`` if *save_path* lacks that suffix.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.save(save_path, spectrogram)
    logging.debug(f"Spectrogram saved at {save_path}")
    print(f"Spectrogram saved at {save_path}")
|
| |
|
class AddNoise(torch.nn.Module):
    """Add Gaussian noise scaled to a target signal-to-noise ratio in dB.

    The noise is rescaled so that ``||x|| / ||noise|| == 10 ** (snr / 20)``,
    using L2 norms taken over the whole tensor.

    Args:
        noise_type: Stored on the module but not read by ``forward``; the
            noise always comes from ``torch.randn_like``.
        snr: Target signal-to-noise ratio in decibels.
    """

    def __init__(self, noise_type='white', snr=10):
        super().__init__()
        self.noise_type = noise_type
        self.snr = snr

    def forward(self, waveform):
        raw = torch.randn_like(waveform)
        gain = waveform.norm(p=2) / raw.norm(p=2)
        attenuation = 10 ** (self.snr / 20)
        return waveform + raw * gain / attenuation
|
| |
|
class SpectrogramDataset(Dataset):
    """Mel-spectrogram dataset built from a directory tree of ``.wav`` files.

    Labels and storage:
        * Each file's label is the name of its parent directory.
        * Labels get integer indices in first-seen order. Files are processed
          in sorted path order, so these indices are reproducible.
        * Spectrograms are stored as
          ``<output_directory>/<relative path>_spectrogram.npy``.
        * After processing, the object is pickled to
          ``<output_directory>/spectrogram_dataset.pkl``. Later constructions
          reload it instead of reprocessing.

    Config keys read: ``output_directory``, ``augment``, ``noise_snr``,
    ``sample_rate``, ``pitch_steps``, ``n_fft``, ``hop_length``, ``n_mels``,
    ``min_duration``, ``max_frames``.
    """

    # Data attributes restored from a pickled dataset. Everything else
    # (config, directories, process_new) always comes from the current
    # constructor call, so config edits take effect on reload.
    _CACHED_FIELDS = ('spectrograms', 'labels', 'label_to_index')

    def __init__(self, config, directory, process_new=True):
        logging.info("Initializing SpectrogramDataset...")
        print("Initializing SpectrogramDataset...")
        self.directory = directory
        self.output_directory = config['output_directory']
        self.spectrograms = []
        self.labels = []
        self.label_to_index = {}
        self.process_new = process_new
        self.config = config

        self.cache_path = os.path.join(self.output_directory, 'cache_data.npy')
        self.dataset_path = os.path.join(self.output_directory, 'spectrogram_dataset.pkl')

        if os.path.exists(self.dataset_path):
            self.load_dataset()
        else:
            if os.path.exists(self.cache_path):
                os.remove(self.cache_path)
                logging.info(f"Cache cleared at {self.cache_path}")
                print(f"Cache cleared at {self.cache_path}")
            self.load_data()
            # Pickled before the transforms below are attached, so the saved
            # object holds no transform modules.
            self.save_dataset()

        self.transforms = transforms.Compose([
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=30),
        ]) if self.config['augment'] else None

        self.audio_transforms = torch.nn.Sequential(
            AddNoise(snr=self.config['noise_snr']),
            torchaudio.transforms.PitchShift(self.config['sample_rate'], n_steps=self.config['pitch_steps']),
        ) if self.config['augment'] else None
        logging.info("SpectrogramDataset initialized successfully")
        print("SpectrogramDataset initialized successfully")

    def save_dataset(self):
        """Pickle this whole instance to ``self.dataset_path``.

        The on-disk format is unchanged. The parent directory is created if
        it is missing.
        """
        parent = os.path.dirname(self.dataset_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.dataset_path, 'wb') as f:
            pickle.dump(self, f)
        logging.info(f"Dataset object saved at {self.dataset_path}")
        print(f"Dataset object saved at {self.dataset_path}")

    def load_dataset(self):
        """Restore spectrograms, labels and the label mapping from the pickle.

        Only the fields in ``_CACHED_FIELDS`` are copied. Previously the
        pickled object's entire ``__dict__`` overwrote this instance, so a
        stale ``config`` (e.g. an old ``augment`` flag), ``directory`` or
        ``process_new`` silently replaced the constructor's arguments.

        Security: ``pickle.load`` can execute arbitrary code. Only load
        datasets from trusted ``output_directory`` locations.
        """
        with open(self.dataset_path, 'rb') as f:
            obj = pickle.load(f)
        state = vars(obj)
        for field in self._CACHED_FIELDS:
            setattr(self, field, state[field])
        logging.info(f"Dataset object loaded from {self.dataset_path}")
        print(f"Dataset object loaded from {self.dataset_path}")

    def process_file(self, file_path):
        """Compute (when needed) and load the spectrogram for one audio file.

        Steps:
            1. If the ``.npy`` file is missing and ``process_new`` is true, a
               spectrogram is computed, truncated to ``max_frames`` frames,
               validated and saved.
            2. If the ``.npy`` file then exists, it is loaded, validated and
               appended together with its label index.

        All errors are logged and the file is skipped.
        """
        logging.debug(f"Processing file: {file_path}")
        print(f"Processing file: {file_path}")
        try:
            label = os.path.basename(os.path.dirname(file_path))
            # The label is registered even if this particular file later fails.
            if label not in self.label_to_index:
                self.label_to_index[label] = len(self.label_to_index)
            relative_path = os.path.relpath(file_path, self.directory)
            spectrogram_path = os.path.join(
                self.output_directory,
                os.path.splitext(relative_path)[0] + '_spectrogram.npy',
            )
            if not os.path.exists(spectrogram_path) and self.process_new:
                spectrogram = audio_to_spectrogram(
                    file_path,
                    n_fft=self.config['n_fft'],
                    hop_length=self.config['hop_length'],
                    n_mels=self.config['n_mels'],
                    target_sr=self.config['sample_rate'],
                    min_duration=self.config['min_duration'],
                )
                if spectrogram is not None:
                    if spectrogram.shape[1] > self.config['max_frames']:
                        spectrogram = spectrogram[:, :self.config['max_frames']]
                    try:
                        validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                        save_spectrogram(spectrogram, spectrogram_path)
                        logging.debug(f"Spectrogram saved: {spectrogram_path}")
                        print(f"Spectrogram saved: {spectrogram_path}")
                    except Exception as e:
                        logging.error(f"Error validating/saving spectrogram: {e}", exc_info=True)
                        print(f"Error validating/saving spectrogram: {e}")
            if os.path.exists(spectrogram_path):
                try:
                    spectrogram = np.load(spectrogram_path)
                    validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                    spectrogram_tensor = torch.tensor(spectrogram, dtype=torch.float32)
                    self.spectrograms.append(spectrogram_tensor)
                    self.labels.append(self.label_to_index[label])
                    logging.debug(f"Spectrogram loaded and appended for file: {file_path}")
                    print(f"Spectrogram loaded and appended for file: {file_path}")
                except Exception as e:
                    logging.error(f"Error loading spectrogram {spectrogram_path}: {e}", exc_info=True)
                    print(f"Error loading spectrogram {spectrogram_path}: {e}")
        except Exception as e:
            logging.error(f"Exception in process_file: {e}", exc_info=True)
            print(f"Exception in process_file: {e}")

    def load_data(self):
        """Process every ``.wav`` file under ``self.directory`` and cache the result.

        Matching is case-insensitive. Paths are sorted because ``os.walk``
        order is filesystem dependent, and label indices are assigned in
        processing order.
        """
        start_time = time.time()
        logging.info("Starting to load and process files...")
        print("Starting to load and process files...")
        files_to_process = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(self.directory)
            for name in files
            if name.lower().endswith('.wav')
        )
        total_files = len(files_to_process)
        logging.info(f"Total files to process: {total_files}")
        print(f"Total files to process: {total_files}")

        for file_path in tqdm(files_to_process, desc="Processing files"):
            self.process_file(file_path)

        end_time = time.time()
        logging.info(f"Data loading and processing took {end_time - start_time:.2f} seconds")
        print(f"Data loading and processing took {end_time - start_time:.2f} seconds")

        self.save_cached_data(self.cache_path)

    def save_cached_data(self, cache_path):
        """Save spectrograms and labels to *cache_path* as a pickled-object ``.npy``."""
        parent = os.path.dirname(cache_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        np.save(cache_path, {'spectrograms': self.spectrograms, 'labels': self.labels})
        logging.debug(f"Cached data saved at {cache_path}")
        print(f"Cached data saved at {cache_path}")

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label_index)``, augmented if ``config['augment']``."""
        spectrogram, label = self.spectrograms[idx], self.labels[idx]
        if self.config['augment']:
            # NOTE(review): AddNoise + PitchShift are waveform transforms, yet
            # they are applied here to a spectrogram (presumably
            # (1, n_mels, frames) after unsqueeze). PitchShift then treats
            # each mel row as a waveform. Confirm this is intended.
            if spectrogram.shape[1] >= 256:
                spectrogram = self.audio_transforms(spectrogram.unsqueeze(0)).squeeze(0)
            spectrogram = self.transforms(spectrogram.unsqueeze(0)).squeeze(0)
        return spectrogram, label
|
| |
|
def collate_fn(batch, pad_value=0.0):
    """Stack variable-size spectrograms into a single padded batch tensor.

    Args:
        batch: Sequence of ``(spectrogram, label)`` pairs. Spectrograms are
            2-D ``(freq, time)`` tensors; a trailing singleton axis
            ``(freq, time, 1)`` is squeezed away.
        pad_value: Fill value for padded cells. Defaults to ``0.0`` (the
            previous behaviour). For spectrograms from
            ``librosa.power_to_db(..., ref=np.max)``, 0 dB is the loudest
            level, so a floor such as ``-80.0`` (librosa's default ``top_db``)
            may be more suitable.

    Returns:
        ``(spectrograms, labels)``:
            * ``spectrograms``: shape ``(batch, max_freq, max_time)`` in
              torch's default float dtype.
            * ``labels``: an int64 tensor.
    """
    spectrograms, labels = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.long)
    # Dims 0 and 1 are unaffected by squeezing a trailing singleton axis, so
    # the maxima can be taken before squeezing.
    max_length = max(s.size(1) for s in spectrograms)
    max_freq = max(s.size(0) for s in spectrograms)
    # Explicit dtype: an integer pad_value would otherwise yield an int64 tensor.
    spectrograms_padded = torch.full(
        (len(spectrograms), max_freq, max_length),
        pad_value,
        dtype=torch.get_default_dtype(),
    )
    for i, s in enumerate(spectrograms):
        if s.dim() == 3 and s.size(2) == 1:
            s = s.squeeze(2)
        spectrograms_padded[i, :s.size(0), :s.size(1)] = s
    return spectrograms_padded, labels
|
| |
|
class SmartBatchingSampler(Sampler):
    """Order indices so that each run of ``batch_size`` holds similar lengths.

    Length is ``shape[1]`` of the spectrogram. Grouping similar lengths keeps
    the padding added by ``collate_fn`` small.

    How the order is built:
        1. Items are sorted by length (longest first).
        2. The sorted list is cut into chunks of ``batch_size``.
        3. Full chunks are shuffled.
        4. The short final chunk, if any, is yielded last, so chunk
           boundaries stay aligned with the ``DataLoader``'s batch boundaries.

    Every index is yielded exactly once. The original yielded the final
    partial chunk twice. Pass this as ``sampler=`` together with the same
    ``batch_size`` to ``DataLoader``.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def _length(self, idx):
        """Return the time-axis length of item *idx*."""
        # Read stored tensors directly when available: indexing the dataset
        # runs __getitem__, which may apply expensive random augmentation.
        stored = getattr(self.data_source, 'spectrograms', None)
        item = stored[idx] if stored is not None else self.data_source[idx][0]
        return item.shape[1]

    def __iter__(self):
        order = sorted(range(len(self.data_source)), key=self._length, reverse=True)
        chunks = [order[i:i + self.batch_size] for i in range(0, len(order), self.batch_size)]
        tail = chunks.pop() if chunks and len(chunks[-1]) < self.batch_size else []
        random.shuffle(chunks)
        for chunk in chunks:
            yield from chunk
        yield from tail

    def __len__(self):
        # A sampler passed as ``sampler=`` yields individual indices, so its
        # length is the number of items, not the number of batches.
        return len(self.data_source)
|
| |
|
if __name__ == '__main__':
    # Script entry point. Steps:
    #   1. Parse the CLI and load the JSON config.
    #   2. Build (or reload) the dataset.
    #   3. Log the shapes of one batch as a smoke test.
    print("Starting script")
    try:
        # Configure logging before anything logs. load_config calls
        # logging.info, and the first module-level logging call on a root
        # logger without handlers auto-installs a WARNING-level handler.
        # After that, a later basicConfig is a silent no-op and all
        # INFO/DEBUG records are lost.
        configure_logging()
        print("Logging configured")

        args = parse_args()
        print(f"Arguments parsed: {args}")
        config = load_config(args.config)
        print(f"Config loaded: {config}")

        logging.info("Script started.")
        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        dataloader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            sampler=SmartBatchingSampler(dataset, config['batch_size']),
        )
        # Only the first batch is inspected.
        for spectrograms, labels in dataloader:
            logging.info(f"Spectrograms batch shape: {spectrograms.shape}")
            logging.info(f"Labels batch shape: {labels.shape}")
            print(f"Spectrograms batch shape: {spectrograms.shape}")
            print(f"Labels batch shape: {labels.shape}")
            break

        logging.info(f"Total files processed: {len(dataset)}")
        print(f"Total files processed: {len(dataset)}")
    except Exception as e:
        logging.error(f"Exception occurred: {e}", exc_info=True)
        print(f"Exception occurred: {e}")
    finally:
        logging.info("Script ended.")
        print("Script ended")