Spaces:
Sleeping
Sleeping
Config file logic and prediction logic.
Browse files
- src/config/config.py +16 -1
- src/data/augment.py +1 -10
- src/models/predict.py +169 -1
src/config/config.py
CHANGED
|
@@ -3,8 +3,23 @@ parameters = {
|
|
| 3 |
"n_mels" : 128,
|
| 4 |
"frame_size" : 1024,
|
| 5 |
"hop_size": 1024,
|
| 6 |
-
"sample_rate":
|
| 7 |
"fft_size": 8192,
|
| 8 |
}
|
| 9 |
|
|
|
|
|
|
|
| 10 |
sample_rate = 44100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"n_mels" : 128,
|
| 4 |
"frame_size" : 1024,
|
| 5 |
"hop_size": 1024,
|
| 6 |
+
"sample_rate": 44100,
|
| 7 |
"fft_size": 8192,
|
| 8 |
}
|
| 9 |
|
| 10 |
+
# Number of spectrogram frames per CNN input patch; used as the default
# `patch_length` of the prediction helpers in src/models/predict.py.
cnn_input_length = 128
|
| 11 |
+
|
| 12 |
# Audio sample rate (presumably in Hz -- confirm). Holds the same value as
# parameters["sample_rate"]; keep the two in sync.
sample_rate = 44100
|
| 13 |
+
|
| 14 |
+
# Class names looked up by predicted class index when results are printed
# (50 entries). NOTE(review): the order presumably follows the ESC-50 target
# ids used during training -- confirm before reordering.
esc50_labels = [
|
| 15 |
+
'dog', 'rooster', 'pig', 'cow', 'frog',
|
| 16 |
+
'cat', 'hen', 'insects', 'sheep', 'crow',
|
| 17 |
+
'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds',
|
| 18 |
+
'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm',
|
| 19 |
+
'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing',
|
| 20 |
+
'footsteps', 'laughing', 'brushing_teeth', 'snoring', 'drinking_sipping',
|
| 21 |
+
'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening',
|
| 22 |
+
'washing_machine', 'vacuum_cleaner', 'clock_alarm', 'clock_tick', 'glass_breaking',
|
| 23 |
+
'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine',
|
| 24 |
+
'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw'
|
| 25 |
+
]
|
src/data/augment.py
CHANGED
|
@@ -5,16 +5,7 @@ import numpy as np
|
|
| 5 |
import os
|
| 6 |
import soundfile as sf
|
| 7 |
|
| 8 |
-
sample_rate
|
| 9 |
-
|
| 10 |
-
parameters = {
|
| 11 |
-
"n_bands" : 128,
|
| 12 |
-
"n_mels" : 128,
|
| 13 |
-
"frame_size" : 1024,
|
| 14 |
-
"hop_size": 1024,
|
| 15 |
-
"sample_rate": sample_rate,
|
| 16 |
-
"fft_size": 8192,
|
| 17 |
-
}
|
| 18 |
|
| 19 |
def data_treatment(
|
| 20 |
audio_path,
|
|
|
|
| 5 |
import os
|
| 6 |
import soundfile as sf
|
| 7 |
|
| 8 |
+
from src.config.config import sample_rate, parameters, cnn_input_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def data_treatment(
|
| 11 |
audio_path,
|
src/models/predict.py
CHANGED
|
@@ -1,7 +1,14 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cuda"):
|
| 7 |
model.eval()
|
|
@@ -35,3 +42,164 @@ def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_
|
|
| 35 |
predicted_class = mean_activations.argmax().item()
|
| 36 |
|
| 37 |
return predicted_class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import essentia.standard as es
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
|
| 9 |
+
from src.models.cnn import CNN
|
| 10 |
+
from src.data.augment import data_treatment
|
| 11 |
+
from src.config.config import sample_rate, parameters, cnn_input_length, esc50_labels
|
| 12 |
|
| 13 |
def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cuda"):
|
| 14 |
model.eval()
|
|
|
|
| 42 |
predicted_class = mean_activations.argmax().item()
|
| 43 |
|
| 44 |
return predicted_class
|
| 45 |
+
|
| 46 |
+
def predict_top_k(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cpu", top_k=5):
    """Return the ``top_k`` most probable classes for one spectrogram.

    The spectrogram is cut into overlapping windows of ``patch_length``
    frames taken every ``hop`` frames. Each window is scored by ``model``,
    the per-window outputs are averaged, and a softmax over that average
    yields the class probabilities.

    Args:
        model: torch module. It is called with float32 tensors of shape
            ``(batch, 1, patch_length, n_mels)`` and is assumed to return
            one row of class scores per patch -- TODO confirm against CNN.
        spectrogram: 2-D numpy array laid out as ``(n_frames, n_mels)``.
            It is zero-padded at the end when it has fewer than
            ``patch_length`` frames.
        patch_length: number of frames per window.
        hop: stride, in frames, between consecutive windows (must be >= 1).
        batch_size: number of windows sent through the model at once.
        device: torch device the batches are moved to. The model is
            expected to already live there.
        top_k: number of classes to return; clamped to the number of
            classes the model outputs.

    Returns:
        ``(top_probs, top_indices)`` as numpy arrays, as returned by
        ``torch.topk`` on the probabilities.

    Raises:
        ValueError: if ``hop`` is smaller than 1.
    """
    if hop < 1:
        raise ValueError("hop must be a positive integer")

    model.eval()

    n_frames, n_mels = spectrogram.shape

    # Short clips are zero-padded so that at least one full window exists.
    if n_frames < patch_length:
        missing = patch_length - n_frames
        spectrogram = np.pad(spectrogram, ((0, missing), (0, 0)), mode='constant')
        n_frames = patch_length

    # Each window gets leading batch and channel axes: (1, 1, frames, mels).
    window_starts = range(0, n_frames - patch_length + 1, hop)
    patches = np.concatenate(
        [
            spectrogram[np.newaxis, np.newaxis, start:start + patch_length, :]
            for start in window_starts
        ],
        axis=0,
    )
    patches = torch.tensor(patches, dtype=torch.float32)

    # Only the current batch is moved to the device, which keeps accelerator
    # memory bounded by `batch_size` instead of the full number of windows.
    all_outputs = []
    with torch.no_grad():
        for i in range(0, len(patches), batch_size):
            batch = patches[i:i + batch_size].to(device)
            all_outputs.append(model(batch))

    all_outputs = torch.cat(all_outputs, dim=0)
    mean_logits = all_outputs.mean(dim=0)
    probabilities = torch.nn.functional.softmax(mean_logits, dim=0)

    # Clamp to the model's real class count rather than a hard-coded 50, so
    # models with fewer outputs do not make torch.topk fail.
    k = min(top_k, probabilities.shape[0])
    top_probs, top_indices = torch.topk(probabilities, k)

    return top_probs.cpu().numpy(), top_indices.cpu().numpy()
|
| 81 |
+
|
| 82 |
+
def predict_file(model, audio_file, device="cpu", top_k=5):
    """Classify one audio file.

    The file is converted to a spectrogram with ``data_treatment`` and then
    scored twice: once with ``predict_with_overlapping_patches`` for the
    single predicted class and once with ``predict_top_k`` for the ranked
    probabilities.

    Args:
        model: trained torch module, already placed on ``device``.
        audio_file: path of the audio file passed to ``data_treatment``.
        device: torch device used for inference.
        top_k: number of ranked predictions to request.

    Returns:
        ``(predicted_class, label, top_probs, top_indices)`` where ``label``
        is the second value returned by ``data_treatment``.

    Raises:
        ValueError: if the spectrogram is not 2-D after squeezing.
    """
    # Start from the shared config so feature settings are defined in one
    # place; "n_bands" is only a fallback in case the config does not set it.
    # NOTE(review): assumes data_treatment accepts every key of
    # `parameters` as a keyword argument -- confirm.
    feature_kwargs = {"n_bands": 128, **parameters}
    spectrogram, label = data_treatment(audio_file, **feature_kwargs)

    spectrogram = np.array(spectrogram)
    print(f" Spectrogram shape from data_treatment: {spectrogram.shape}")

    spectrogram = spectrogram.squeeze()
    # squeeze() also drops a frame or mel axis of length 1; fail early with
    # a clear message instead of an unpacking error inside the predictors.
    if spectrogram.ndim != 2:
        raise ValueError(
            f"Expected a 2-D spectrogram after squeeze, got shape {spectrogram.shape}"
        )

    # Use the configured patch length rather than a literal 128 so the value
    # cannot drift from the one the network was built for.
    predicted_class = predict_with_overlapping_patches(
        model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device=device
    )
    top_probs, top_indices = predict_top_k(
        model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device=device, top_k=top_k
    )

    return predicted_class, label, top_probs, top_indices
|
| 107 |
+
|
| 108 |
+
def load_model(model_path, device='cpu'):
    """Build a ``CNN`` and load trained weights from ``model_path``.

    Two checkpoint layouts are accepted: a dict holding the weights under
    ``'model_state_dict'`` (optionally with ``'best_val_acc'``, which is
    printed), or a bare state dict.

    Args:
        model_path: path of the checkpoint file.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    print(f"Loading model from {model_path}...")

    # Size the output layer from the label list so the model and the names
    # used to print predictions cannot disagree.
    model = CNN(n_classes=len(esc50_labels))

    # NOTE(security): torch.load unpickles the file, and this path comes from
    # the command line. Only load checkpoints from trusted sources; consider
    # weights_only=True where the installed torch version supports it.
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'best_val_acc' in checkpoint:
            print(f"Model validation accuracy: {checkpoint['best_val_acc']:.4f}")
    else:
        # Anything else is treated as a bare state dict.
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    print("Model loaded successfully!\n")
    return model
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _abort(message, show_traceback=False):
    """Print ``message`` to stderr, optionally with the active traceback, and exit with status 1."""
    import traceback  # local import: only needed on the failure path

    print(message, file=sys.stderr)
    if show_traceback:
        traceback.print_exc()
    sys.exit(1)


def main(argv=None):
    """Command-line entry point: classify one audio file and print the ranking.

    Args:
        argv: optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Exits with status 1 when the audio file or checkpoint is missing, or when
    loading or prediction fails, and with argparse's usage error when
    ``--top-k`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description='Predict environmental sound class using trained ESC-50 model'
    )
    parser.add_argument(
        'audio_file',
        type=str,
        help='Path to .wav file to classify'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='best_model.pt',
        help='Path to trained model checkpoint (default: best_model.pt)'
    )
    parser.add_argument(
        '--top-k',
        type=int,
        default=5,
        help='Number of top predictions to show (default: 5)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to use (default: auto-detect)'
    )

    args = parser.parse_args(argv)

    # Reject a non-positive k here rather than letting it fail inside torch.
    if args.top_k < 1:
        parser.error('--top-k must be a positive integer')

    # isfile() also rejects directories, which would only fail later.
    if not os.path.isfile(args.audio_file):
        _abort(f"Error: Audio file not found: {args.audio_file}")

    if not os.path.isfile(args.model):
        _abort(f"Error: Model file not found: {args.model}")

    # Load model
    try:
        model = load_model(args.model, device=args.device)
    except Exception as e:
        _abort(f"Error loading model: {e}", show_traceback=True)

    # Predict
    try:
        predicted_class, _label, top_probs, top_indices = predict_file(
            model, args.audio_file, device=args.device, top_k=args.top_k
        )

        # Display results. The header reports the number of rows actually
        # returned, which can be smaller than the requested --top-k.
        print("\n" + "=" * 60)
        print(f"Top {len(top_probs)} Predictions:")
        print("=" * 60)

        for rank, (prob, idx) in enumerate(zip(top_probs, top_indices), start=1):
            class_name = esc50_labels[idx]
            # The star marks the class chosen by the single-class predictor.
            marker = "★" if idx == predicted_class else " "
            print(f"{marker} {rank}. {class_name:20s} - {prob*100:6.2f}%")

        print("=" * 60)
        print(f"\n✓ Predicted class: {esc50_labels[predicted_class]}")

    except Exception as e:
        _abort(f"\nError during prediction: {e}", show_traceback=True)


if __name__ == '__main__':
    main()
|