Spaces:
Sleeping
Sleeping
Added README.md and src/data files
Browse files- README.md +51 -0
- src/data/__pycache__/augment.cpython-311.pyc +0 -0
- src/data/__pycache__/datasets.cpython-311.pyc +0 -0
- src/data/augment.py +187 -0
- src/data/datasets.py +96 -0
- src/data/download.py +75 -0
README.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environmental Sound Classification (ESC50) with Deep CNNs
|
| 2 |
+
|
| 3 |
+
A PyTorch reimplementation of the deep convolutional neural network approach from [Salamon & Bello (2017)](https://arxiv.org/pdf/1608.04363) for environmental sound classification, extended to handle 50 classes instead of the original 10.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This project implements a deep CNN architecture for environmental sound classification using log-mel spectrograms as input features. The implementation follows the methodology described in the paper "Deep Convolutional Neural Networks and Data Augmentation for Environmental Sound Classification" but is scaled to work with a more challenging 50-class classification task.
|
| 8 |
+
|
| 9 |
+
### Key Features
|
| 10 |
+
|
| 11 |
+
- Deep CNN with 3 convolutional layers + 2 fully connected layers
|
| 12 |
+
- Log-mel spectrogram feature extraction using Essentia
|
| 13 |
+
- Data augmentation (time stretching, pitch shifting, dynamic range compression)
|
| 14 |
+
- Overlapping patch prediction for validation (1-frame hop)
|
| 15 |
+
|
| 16 |
+
## Results
|
| 17 |
+
|
| 18 |
+
| Dataset | Classes | Accuracy (with augmentation) |
|
| 19 |
+
|---------|---------|--------------------------------|
|
| 20 |
+
| UrbanSound8K (paper) | 10 | 79% |
|
| 21 |
+
| **This project** | **50** | **74%** |
|
| 22 |
+
|
| 23 |
+
## Architecture
|
| 24 |
+
|
| 25 |
+
### Model Structure
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
Input: Log-mel Spectrogram (128 × 128)
|
| 29 |
+
↓
|
| 30 |
+
Conv2D(1→24, 5×5) + ReLU + MaxPool(4×2)
|
| 31 |
+
↓
|
| 32 |
+
Conv2D(24→48, 5×5) + ReLU + MaxPool(4×2)
|
| 33 |
+
↓
|
| 34 |
+
Conv2D(48→48, 5×5) + ReLU
|
| 35 |
+
↓
|
| 36 |
+
Flatten → Dense(2400→64) + ReLU + Dropout(0.5)
|
| 37 |
+
↓
|
| 38 |
+
Dense(64→50) + Softmax
|
| 39 |
+
↓
|
| 40 |
+
Output: 50 classes
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
### Training Configuration
|
| 44 |
+
|
| 45 |
+
- **Optimizer**: SGD with momentum (0.9)
|
| 46 |
+
- **Learning Rate**: 0.01
|
| 47 |
+
- **Batch Size**: 100 TF-patches
|
| 48 |
+
- **L2 Regularization**: 0.001 (on classifier layers only)
|
| 49 |
+
- **Dropout**: 0.5 (on classifier layers)
|
| 50 |
+
- **Gradient Clipping**: max_norm=1.0
|
| 51 |
+
- **Epochs**: 100
|
src/data/__pycache__/augment.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/data/__pycache__/datasets.cpython-311.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|
src/data/augment.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tqdm
|
| 2 |
+
import essentia.standard as es
|
| 3 |
+
import librosa
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
|
| 8 |
+
# Target sampling rate (Hz) used for loading, augmenting and writing audio.
sample_rate = 44100

# Feature-extraction settings passed to data_treatment(**parameters).
parameters = {
    "n_bands": 128,       # number of mel bands (spectrogram height)
    "n_mels": 128,        # NOTE(review): accepted by data_treatment but unused there — confirm
    "frame_size": 1024,   # analysis frame length in samples
    "hop_size": 1024,     # hop equals frame size, i.e. non-overlapping frames
    "sample_rate": sample_rate,
    "fft_size": 8192,     # frames are zero-padded up to this FFT length
}
|
| 18 |
+
|
| 19 |
+
def data_treatment(
    audio_path,
    n_bands, n_mels, frame_size, hop_size, sample_rate, fft_size
):
    """Compute a log-mel spectrogram for every audio file in *audio_path*.

    The class label is parsed from each filename (last ``-``-separated token,
    before the extension and before any ``_`` augmentation suffix).

    Args:
        audio_path: directory containing the audio files.
        n_bands: number of mel bands.
        n_mels: unused; kept for compatibility with the module-level
            ``parameters`` dict that is splatted into this function.
        frame_size: analysis frame length in samples.
        hop_size: hop between frames in samples.
        sample_rate: sampling rate in Hz.
        fft_size: FFT length; frames are zero-padded up to this size.

    Returns:
        (list of 2-D np.ndarray spectrograms (frames x bands), np.ndarray of int labels)
    """
    labels = []
    log_mel_spectrograms = []
    filenames = os.listdir(audio_path)

    # Hoisted out of the per-file loop: these Essentia algorithm objects are
    # configuration-only, so rebuilding them for every file was wasted work.
    window = es.Windowing(type="hann")
    spectrum = es.Spectrum(size=fft_size)
    mel = es.MelBands(
        numberBands=n_bands,
        inputSize=fft_size // 2 + 1,
        sampleRate=sample_rate,
        lowFrequencyBound=0,
        highFrequencyBound=sample_rate / 2,
    )

    for filename in tqdm.tqdm(filenames, desc="Processing audio files"):
        # Filename pattern: "...-<label>.wav", optionally "...-<label>_<aug>.wav".
        label = filename.split("-")[-1].split(".")[0].split("_")[0]
        labels.append(int(label))

        file_path = os.path.join(audio_path, filename)
        loader = es.MonoLoader(filename=file_path)
        audio = loader()

        frames = es.FrameGenerator(audio, frameSize=frame_size, hopSize=hop_size)
        log_mel_spectrogram = []
        for frame in frames:
            # Zero-pad each frame up to the FFT size before windowing.
            # NOTE(review): the Hann window is applied across the full padded
            # length, not just the original frame — conventionally one windows
            # first and then zero-pads; confirm this is intentional.
            frame_padded = np.pad(frame, (0, fft_size - len(frame)), mode='constant')
            windowed_frame = window(frame_padded)
            spec = spectrum(windowed_frame)
            mel_bands = mel(spec)
            log_mel_spectrogram.append(mel_bands)

        log_mel_spectrogram = np.array(log_mel_spectrogram)

        # Convert power mel bands to dB and shift so the peak sits at 0 dB.
        mel_spectrogram_db = 10 * np.log10(log_mel_spectrogram + 1e-10)
        mel_spectrogram_db = mel_spectrogram_db - mel_spectrogram_db.max()

        log_mel_spectrograms.append(mel_spectrogram_db)
    return log_mel_spectrograms, np.array(labels)
|
| 66 |
+
|
| 67 |
+
def pad(audio, target_seconds, sample_rate):
    """Right-pad *audio* with zeros up to target_seconds of samples.

    Input that is already long enough is returned unchanged.
    """
    required = int(sample_rate * target_seconds)
    shortfall = required - len(audio)
    if shortfall > 0:
        return np.pad(audio, (0, shortfall), mode="constant")
    return audio
|
| 74 |
+
|
| 75 |
+
def time_stretch_augmentation(file_path, sample_rate, rate):
    """Load a file, time-stretch it by *rate*, and zero-pad back to 5 seconds."""
    signal, _ = librosa.load(file_path, sr=sample_rate)
    stretched = librosa.effects.time_stretch(signal.astype(np.float32), rate=rate)
    return pad(stretched, 5, sample_rate)
|
| 79 |
+
|
| 80 |
+
def pitch_shift_augmentation(file_path, sample_rate, semitones):
    """Load a file and pitch-shift it by *semitones* (fractional values allowed)."""
    signal, _ = librosa.load(file_path, sr=sample_rate)
    shifted = librosa.effects.pitch_shift(
        signal.astype(np.float32), sr=sample_rate, n_steps=semitones
    )
    return shifted
|
| 83 |
+
|
| 84 |
+
def drc_augmentation(file_path, sample_rate, compression):
    """Apply feed-forward dynamic range compression to an audio file.

    Args:
        file_path: audio file to load.
        sample_rate: sampling rate in Hz used for loading and the
            attack/release smoothing coefficients.
        compression: preset name, one of "music_standard", "film_standard",
            "speech" or "radio".

    Returns:
        np.ndarray of the compressed signal.

    Raises:
        ValueError: for an unknown preset name. (Previously an unknown name
            fell through the if/elif chain and crashed later with a NameError
            on ``threshold_db``.)
    """
    # Preset table: (threshold_db, ratio, attack_ms, release_ms).
    presets = {
        "music_standard": (-20, 2.0, 5, 50),
        "film_standard": (-25, 4.0, 10, 100),
        "speech": (-18, 3.0, 2, 40),
        "radio": (-15, 3.5, 1, 200),
    }
    try:
        threshold_db, ratio, attack_ms, release_ms = presets[compression]
    except KeyError:
        raise ValueError(f"Unknown compression preset: {compression!r}") from None

    audio, _ = librosa.load(file_path, sr=sample_rate)
    threshold = 10 ** (threshold_db / 20)

    # One-pole smoothing coefficients for the gain envelope.
    attack_coeff = np.exp(-1.0 / (0.001 * attack_ms * sample_rate))
    release_coeff = np.exp(-1.0 / (0.001 * release_ms * sample_rate))

    audio_filtered = np.zeros_like(audio)
    gain = 1.0

    # Sample-by-sample gain tracking: fast attack when gain must drop,
    # slower release when it recovers.
    # NOTE(review): gain law (threshold/|x|)**(ratio-1) differs from the
    # textbook (|x|/threshold)**(1/ratio - 1); preserved as-is — confirm.
    for n in range(len(audio)):
        level = abs(audio[n])
        if level > threshold:
            desired_gain = (threshold / level) ** (ratio - 1)
        else:
            desired_gain = 1.0

        if desired_gain < gain:
            gain = attack_coeff * (gain - desired_gain) + desired_gain
        else:
            gain = release_coeff * (gain - desired_gain) + desired_gain

        audio_filtered[n] = audio[n] * gain

    return audio_filtered
|
| 114 |
+
|
| 115 |
+
def augment_dataset(audio_path, output_path, probability_list):
    """Write one (possibly multi-)augmented copy of each file in *audio_path*.

    Args:
        audio_path: directory of source audio files.
        output_path: directory for augmented files (created if missing).
        probability_list: (p1, p2, p3) skip-probabilities for time stretch,
            pitch shift and DRC respectively (0.0 -> always applied,
            1.0 -> never applied).

    Fixes over the previous version:
    - chained augmentations now compose on the same signal; before, each step
      reloaded the ORIGINAL file, discarding earlier augmentations even though
      the output filename claimed both were applied;
    - exactly one file is written per input; before, an intermediate file was
      also written for every partial suffix;
    - suffixes with decimals (e.g. "TS0.93") are no longer truncated by the
      old ``filename.split(".")`` reassembly.
    Files that receive no augmentation are skipped, as before.
    """
    p_skip_ts, p_skip_ps, p_skip_drc = probability_list
    os.makedirs(output_path, exist_ok=True)

    for filename in tqdm.tqdm(os.listdir(audio_path), desc="Processing audio files"):
        file_path = os.path.join(audio_path, filename)
        suffixes = []

        # DRC is applied first because its helper operates on the file on disk;
        # TS/PS then transform the in-memory result.
        if np.random.rand() > p_skip_drc:
            compressions = ["radio", "film_standard", "music_standard", "speech"]
            compression = np.random.choice(compressions)
            audio = drc_augmentation(file_path, sample_rate, compression)
            drc_suffix = f"DRC{compression}"
        else:
            audio, _ = librosa.load(file_path, sr=sample_rate)
            drc_suffix = None

        # TS
        if np.random.rand() > p_skip_ts:
            stretch_rate = np.random.choice([0.81, 0.93, 1.07, 1.23])
            audio = librosa.effects.time_stretch(audio.astype(np.float32), rate=stretch_rate)
            audio = pad(audio, 5, sample_rate)
            suffixes.append(f"TS{stretch_rate}")

        # PS
        if np.random.rand() > p_skip_ps:
            semitone = np.random.choice([-3.5, -2.5, -2, -1, 1, 2.5, 3, 3.5])
            audio = librosa.effects.pitch_shift(
                audio.astype(np.float32), sr=sample_rate, n_steps=semitone
            )
            suffixes.append(f"PS{semitone}")

        if drc_suffix is not None:
            suffixes.append(drc_suffix)  # keep the original TS/PS/DRC suffix order

        if suffixes:
            stem, _, extension = filename.rpartition(".")
            out_name = stem + "".join(f"_{s}" for s in suffixes) + "." + extension
            sf.write(os.path.join(output_path, out_name), audio, sample_rate)
|
| 149 |
+
|
| 150 |
+
def create_augmented_datasets(input_path, output_path):
    """Build five augmented dataset variants, each in its own numbered subfolder."""
    # Each triple is (p_skip_timestretch, p_skip_pitchshift, p_skip_drc):
    # 0.0 means the augmentation is always applied, 1.0 means never.
    variants = [
        [0.0, 1.0, 1.0],
        [1.0, 1.0, 0.0],
        [1.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
    ]
    for index, probabilities in enumerate(variants, start=1):
        target_dir = os.path.join(output_path, str(index))
        os.makedirs(target_dir, exist_ok=True)
        augment_dataset(input_path, target_dir, probabilities)
|
| 161 |
+
|
| 162 |
+
def create_log_mel(input_path, output_path):
    """Extract log-mel features for every subdirectory of *input_path*.

    Saves X.npy (object array of variable-length spectrograms) and y.npy
    (labels) into *output_path*, and returns (X, y) as Python-level objects.
    """
    X, y = [], []

    for directory in os.listdir(input_path):
        spectrograms, labels = data_treatment(
            os.path.join(input_path, directory), **parameters
        )
        X.extend(spectrograms)
        y.extend(labels)

    # Spectrograms have varying frame counts, so they must be stored
    # as a 1-D object array rather than a rectangular ndarray.
    X_array = np.empty(len(X), dtype=object)
    for position, spectrogram in enumerate(X):
        X_array[position] = spectrogram

    y = np.array(y)
    os.makedirs(output_path, exist_ok=True)

    np.save(os.path.join(output_path, "X.npy"), X_array, allow_pickle=True)
    np.save(os.path.join(output_path, 'y.npy'), y)
    return X, y
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
    # Folder "0" holds the original (unaugmented) clips; the five augmented
    # variants are written to numbered sibling folders under data/audio.
    input_path = "data/audio/0"
    output_base_path = "data/audio"

    create_augmented_datasets(input_path, output_base_path)
|
| 187 |
+
|
src/data/datasets.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
|
| 5 |
+
# Number of spectrogram frames per CNN input patch (patches are 128 x n_mels).
cnn_input_length = 128

class SpectrogramDataset(Dataset):
    """Log-mel spectrogram dataset.

    In 'train' mode each item is a random fixed-length time patch shaped
    (1, patch_length, n_mels) as a float32 tensor plus a long label; in any
    other mode the raw spectrogram and label are returned unchanged.
    """

    def __init__(self, spectrograms, labels, patch_length=cnn_input_length, mode='train'):
        self.spectrograms = spectrograms
        self.labels = labels
        self.patch_length = patch_length
        self.mode = mode

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        spectrogram = self.spectrograms[idx]
        target = self.labels[idx]

        if self.mode != 'train':
            return spectrogram, target

        total_frames = spectrogram.shape[0]
        if total_frames < self.patch_length:
            # Too short: zero-pad along the time axis up to patch_length.
            shortfall = self.patch_length - total_frames
            patch = np.pad(spectrogram, ((0, shortfall), (0, 0)), mode='constant')
        else:
            # Long enough: draw a uniformly random window of patch_length frames.
            offset = np.random.randint(0, total_frames - self.patch_length + 1)
            patch = spectrogram[offset:offset + self.patch_length]

        patch = patch[np.newaxis, :, :]  # add channel dimension
        return torch.tensor(patch, dtype=torch.float32), torch.tensor(target, dtype=torch.long)
|
| 36 |
+
|
| 37 |
+
class FullTFPatchesDataset(Dataset):
    """Every overlapping time-frequency patch (1-frame hop) as a separate item.

    Intended for validation: predictions over all patches of one clip can be
    aggregated. Spectrograms shorter than patch_length contribute a single
    zero-padded patch.
    """

    def __init__(self, spectrograms, labels, patch_length=128):
        self.patch_length = patch_length
        self.spectrograms = spectrograms
        # One (spectrogram index, start frame, label) entry per patch.
        self.patch_indices = []

        for spec_idx, spec in enumerate(spectrograms):
            label = labels[spec_idx]
            total_frames = spec.shape[0]
            if total_frames < patch_length:
                self.patch_indices.append((spec_idx, 0, label))
            else:
                self.patch_indices.extend(
                    (spec_idx, start, label)
                    for start in range(total_frames - patch_length + 1)
                )

    def __len__(self):
        return len(self.patch_indices)

    def __getitem__(self, idx):
        spec_idx, start, label = self.patch_indices[idx]
        spec = self.spectrograms[spec_idx]
        available = spec.shape[0]

        if available < self.patch_length:
            shortfall = self.patch_length - available
            patch = np.pad(spec, ((0, shortfall), (0, 0)), mode='constant')
        else:
            patch = spec[start:start + self.patch_length]

        patch = patch[np.newaxis, :, :]  # channel dimension

        return torch.tensor(patch, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RandomPatchDataset(Dataset):
    """One random fixed-length patch per spectrogram per access.

    Each __getitem__ call draws a fresh random window, so successive epochs
    see different patches of the same clip. Short clips are zero-padded.
    """

    def __init__(self, spectrograms, labels, patch_length=128):
        self.spectrograms = spectrograms
        self.labels = labels
        self.patch_length = patch_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        spectrogram = self.spectrograms[idx]
        target = self.labels[idx]
        total_frames = spectrogram.shape[0]

        if total_frames < self.patch_length:
            # Zero-pad along the time axis up to patch_length.
            shortfall = self.patch_length - total_frames
            patch = np.pad(spectrogram, ((0, shortfall), (0, 0)), mode='constant')
        else:
            offset = np.random.randint(0, total_frames - self.patch_length + 1)
            patch = spectrogram[offset:offset + self.patch_length]

        patch = patch[np.newaxis, :, :]  # channel dimension
        return torch.tensor(patch, dtype=torch.float32), torch.tensor(target, dtype=torch.long)
|
src/data/download.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import zipfile
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
|
| 7 |
+
# Zip archive of the ESC-50 dataset repository (master branch).
repo_url = "https://github.com/karolpiczak/ESC-50/archive/refs/heads/master.zip"
# Local root directory for all downloaded data.
repo_dst_dir = "data"
# Final destination for the original (unaugmented) audio clips.
audio_dst_dir = os.path.join(repo_dst_dir, "audio", "0")

# Repository entries that are not needed once the audio has been extracted.
paths_to_delete = [
    ".gitignore",
    "esc50.gif",
    "LICENSE",
    "pytest.ini",
    "README.md",
    "requirements.txt",
    "tests",
    "meta",
    ".github",
    ".circleci"
]
|
| 23 |
+
|
| 24 |
+
def download_and_extract(url, dst_dir, timeout=60):
    """Download a zip archive from *url* and extract it into *dst_dir*.

    Args:
        url: address of the zip archive.
        dst_dir: destination directory (created if missing).
        timeout: seconds to wait for the HTTP response (default 60).
            Previously no timeout was set, so a stalled server could hang
            the script forever.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    os.makedirs(dst_dir, exist_ok=True)
    print(f"Downloading from {url}")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # The archive is small enough to keep in memory, so unzip straight from
    # the response body without writing a temporary file.
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        print(f"Extracting to {dst_dir}")
        z.extractall(dst_dir)
    print("Done extracting.")
|
| 34 |
+
|
| 35 |
+
def clean_files(repo_dir, paths_to_delete):
    """Delete the listed files and directories under *repo_dir*.

    Entries that do not exist are silently skipped.
    """
    for entry in paths_to_delete:
        target = os.path.join(repo_dir, entry)
        if os.path.isdir(target):
            shutil.rmtree(target)
            print(f"Deleted directory: {target}")
        elif os.path.isfile(target):
            os.remove(target)
            print(f"Deleted file: {target}")
|
| 44 |
+
|
| 45 |
+
def move_audio_files(src_dir, dst_dir):
    """Move every regular file from *src_dir* into *dst_dir*.

    Subdirectories of src_dir are left in place; dst_dir is created if needed.
    """
    os.makedirs(dst_dir, exist_ok=True)
    print(f"Moving audio files from {src_dir} to {dst_dir}")

    for entry in os.listdir(src_dir):
        source = os.path.join(src_dir, entry)
        if os.path.isfile(source):
            shutil.move(source, os.path.join(dst_dir, entry))
    print(f"Moved all audio files to {dst_dir}")
|
| 55 |
+
|
| 56 |
+
def download_clean():
    """Fetch the ESC-50 repository, keep only the audio, and tidy up."""
    download_and_extract(repo_url, repo_dst_dir)

    # GitHub branch zips extract to "<repo>-<branch>/", i.e. data/ESC-50-master/.
    extracted_dir = os.path.join(repo_dst_dir, "ESC-50-master")
    audio_src_dir = os.path.join(extracted_dir, "audio")

    # Drop repository files we never use.
    clean_files(extracted_dir, paths_to_delete)

    # Relocate the clips to data/audio/0.
    move_audio_files(audio_src_dir, audio_dst_dir)

    # Nothing useful remains in the extracted tree once the audio is moved.
    shutil.rmtree(extracted_dir)
    print(f"Cleanup complete. Audio files are in {audio_dst_dir}")
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
    # Running this module directly downloads and prepares the ESC-50 audio.
    download_clean()
|