mateo496 committed on
Commit
d339e38
·
1 Parent(s): 8a89899

Config file logic and prediction logic.

Browse files
Files changed (3) hide show
  1. src/config/config.py +16 -1
  2. src/data/augment.py +1 -10
  3. src/models/predict.py +169 -1
src/config/config.py CHANGED
@@ -3,8 +3,23 @@ parameters = {
3
  "n_mels" : 128,
4
  "frame_size" : 1024,
5
  "hop_size": 1024,
6
- "sample_rate": sample_rate,
7
  "fft_size": 8192,
8
  }
9
 
 
 
10
  sample_rate = 44100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  "n_mels" : 128,
4
  "frame_size" : 1024,
5
  "hop_size": 1024,
6
+ "sample_rate": 44100,
7
  "fft_size": 8192,
8
  }
9
 
10
# Number of spectrogram frames in one CNN input patch (used as the default
# ``patch_length`` by the prediction helpers).
cnn_input_length = 128

# Audio sample rate in Hz; mirrors the "sample_rate" entry of ``parameters``.
sample_rate = 44100

# Human-readable class names, indexed by the integer class id predicted by the
# model. The order must not change: position == class id.
# Category headings follow the ESC-50 dataset layout (10 classes per category).
esc50_labels = [
    # 0-9: animals
    "dog", "rooster", "pig", "cow", "frog",
    "cat", "hen", "insects", "sheep", "crow",
    # 10-19: natural soundscapes and water sounds
    "rain", "sea_waves", "crackling_fire", "crickets", "chirping_birds",
    "water_drops", "wind", "pouring_water", "toilet_flush", "thunderstorm",
    # 20-29: human, non-speech sounds
    "crying_baby", "sneezing", "clapping", "breathing", "coughing",
    "footsteps", "laughing", "brushing_teeth", "snoring", "drinking_sipping",
    # 30-39: interior / domestic sounds
    "door_wood_knock", "mouse_click", "keyboard_typing", "door_wood_creaks", "can_opening",
    "washing_machine", "vacuum_cleaner", "clock_alarm", "clock_tick", "glass_breaking",
    # 40-49: exterior / urban noises
    "helicopter", "chainsaw", "siren", "car_horn", "engine",
    "train", "church_bells", "airplane", "fireworks", "hand_saw",
]
src/data/augment.py CHANGED
@@ -5,16 +5,7 @@ import numpy as np
5
  import os
6
  import soundfile as sf
7
 
8
- sample_rate = 44100
9
-
10
- parameters = {
11
- "n_bands" : 128,
12
- "n_mels" : 128,
13
- "frame_size" : 1024,
14
- "hop_size": 1024,
15
- "sample_rate": sample_rate,
16
- "fft_size": 8192,
17
- }
18
 
19
  def data_treatment(
20
  audio_path,
 
5
  import os
6
  import soundfile as sf
7
 
8
+ from src.config.config import sample_rate, parameters, cnn_input_length
 
 
 
 
 
 
 
 
 
9
 
10
  def data_treatment(
11
  audio_path,
src/models/predict.py CHANGED
@@ -1,7 +1,14 @@
1
  import numpy as np
2
  import torch
 
 
 
 
 
3
 
4
- cnn_input_length = 128
 
 
5
 
6
  def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cuda"):
7
  model.eval()
@@ -35,3 +42,164 @@ def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_
35
  predicted_class = mean_activations.argmax().item()
36
 
37
  return predicted_class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
+ import torch.nn as nn
4
+ import essentia.standard as es
5
+ import argparse
6
+ import os
7
+ import sys
8
 
9
+ from src.models.cnn import CNN
10
+ from src.data.augment import data_treatment
11
+ from src.config.config import sample_rate, parameters, cnn_input_length, esc50_labels
12
 
13
  def predict_with_overlapping_patches(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cuda"):
14
  model.eval()
 
42
  predicted_class = mean_activations.argmax().item()
43
 
44
  return predicted_class
45
+
46
def predict_top_k(model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device="cpu", top_k=5):
    """Return the ``top_k`` most probable classes for one spectrogram.

    The spectrogram is cut into overlapping patches of ``patch_length`` frames,
    the model is run on every patch, the per-patch outputs are averaged and a
    softmax is applied to the averaged output.

    Args:
        model: Torch module called as ``model(batch)``; it is put in eval mode.
        spectrogram: 2-D NumPy array unpacked as ``(n_frames, n_mels)``.
        patch_length: Number of frames per patch.
        hop: Step, in frames, between the starts of consecutive patches.
        batch_size: Number of patches sent through the model per forward pass.
        device: Torch device the patches are moved to before the forward pass.
        top_k: Number of classes to return; clamped to the number of model
            outputs.

    Returns:
        ``(top_probs, top_indices)`` as NumPy arrays, as produced by
        ``torch.topk`` on the class probabilities.

    Raises:
        ValueError: If ``spectrogram`` is not 2-D, or if ``hop`` yields no
            patch at all.
    """
    model.eval()

    n_frames, _ = spectrogram.shape

    # Zero-pad short inputs along the frame axis so one full patch exists.
    if n_frames < patch_length:
        missing = patch_length - n_frames
        spectrogram = np.pad(spectrogram, ((0, missing), (0, 0)), mode='constant')
        n_frames = patch_length

    # Result shape: (n_patches, 1, patch_length, n_mels). The singleton axis is
    # presumably the channel axis the CNN expects -- confirm against the model.
    windows = np.stack([
        spectrogram[start:start + patch_length]
        for start in range(0, n_frames - patch_length + 1, hop)
    ])[:, np.newaxis, :, :]
    windows = torch.tensor(windows, dtype=torch.float32)

    all_outputs = []
    with torch.no_grad():
        for i in range(0, len(windows), batch_size):
            # Move one batch at a time so device memory holds at most
            # ``batch_size`` patches instead of the whole clip.
            batch = windows[i:i + batch_size].to(device)
            all_outputs.append(model(batch))

    mean_logits = torch.cat(all_outputs, dim=0).mean(dim=0)
    probabilities = torch.nn.functional.softmax(mean_logits, dim=0)

    # Clamp to the model's real number of outputs rather than a hard-coded 50,
    # so torch.topk never receives k larger than the class dimension.
    k = min(top_k, probabilities.shape[0])
    top_probs, top_indices = torch.topk(probabilities, k)

    return top_probs.cpu().numpy(), top_indices.cpu().numpy()
81
+
82
def predict_file(model, audio_file, device="cpu", top_k=5):
    """Classify a single audio file.

    Args:
        model: Trained torch model passed through to the prediction helpers.
        audio_file: Path handed to ``data_treatment`` for feature extraction.
        device: Torch device used for inference.
        top_k: Number of ranked classes requested from ``predict_top_k``.

    Returns:
        ``(predicted_class, label, top_probs, top_indices)`` where
        ``predicted_class`` comes from ``predict_with_overlapping_patches``,
        ``label`` is whatever ``data_treatment`` returned alongside the
        spectrogram, and the last two come from ``predict_top_k``.
    """
    # Build the feature-extraction settings from the shared config instead of
    # a private copy, so they cannot drift from the values used elsewhere.
    # "n_bands" is given as a fallback in case the config dict does not carry
    # it; a value in the config takes precedence.
    treatment_kwargs = {"n_bands": 128, **parameters}
    spectrogram, label = data_treatment(audio_file, **treatment_kwargs)

    spectrogram = np.array(spectrogram)
    print(f" Spectrogram shape from data_treatment: {spectrogram.shape}")

    # Drop singleton axes; the helpers below unpack a 2-D array.
    spectrogram = spectrogram.squeeze()

    # NOTE(review): the model is run over the clip twice (once per helper).
    # The two helpers may aggregate patch outputs differently, so the calls
    # are kept separate rather than deriving one result from the other.
    predicted_class = predict_with_overlapping_patches(
        model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device=device
    )
    top_probs, top_indices = predict_top_k(
        model, spectrogram, patch_length=cnn_input_length, hop=1, batch_size=100, device=device, top_k=top_k
    )

    return predicted_class, label, top_probs, top_indices
107
+
108
def load_model(model_path, device='cpu'):
    """Build the 50-class CNN and restore its weights from ``model_path``.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under ``'model_state_dict'``. In the latter case a
    ``'best_val_acc'`` entry, if present, is printed.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device used as ``map_location`` and as the model's target.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    print(f"Loading model from {model_path}...")

    network = CNN(n_classes=50)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    loaded = torch.load(model_path, map_location=device)

    # A "wrapped" checkpoint carries the weights under 'model_state_dict';
    # anything else is treated as the state dict itself.
    is_wrapped = isinstance(loaded, dict) and 'model_state_dict' in loaded
    network.load_state_dict(loaded['model_state_dict'] if is_wrapped else loaded)
    if is_wrapped and 'best_val_acc' in loaded:
        print(f"Model validation accuracy: {loaded['best_val_acc']:.4f}")

    network.to(device)
    network.eval()
    print("Model loaded successfully!\n")
    return network
128
+
129
+
130
+
131
def main():
    """Command-line entry point: classify one audio file and print the result.

    Parses the arguments, loads the checkpoint, runs the prediction and prints
    a ranked table of classes. Exits with status 1 when an input file is
    missing or when loading/prediction fails, and with argparse's usage error
    (status 2) when ``--top-k`` is not positive.
    """
    # Imported once here instead of separately inside each except block.
    import traceback

    parser = argparse.ArgumentParser(
        description='Predict environmental sound class using trained ESC-50 model'
    )
    parser.add_argument(
        'audio_file',
        type=str,
        help='Path to .wav file to classify'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='best_model.pt',
        help='Path to trained model checkpoint (default: best_model.pt)'
    )
    parser.add_argument(
        '--top-k',
        type=int,
        default=5,
        help='Number of top predictions to show (default: 5)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to use (default: auto-detect)'
    )

    args = parser.parse_args()

    # Reject 0 / negative values up front rather than failing deep inside
    # torch.topk or printing an empty table.
    if args.top_k < 1:
        parser.error("--top-k must be a positive integer")

    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)

    if not os.path.exists(args.model):
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)

    # Load model
    try:
        model = load_model(args.model, device=args.device)
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Predict
    try:
        predicted_class, label, top_probs, top_indices = predict_file(
            model, args.audio_file, device=args.device, top_k=args.top_k
        )

        # Display results. The banner reports how many rows are actually
        # listed, which can be fewer than requested when --top-k exceeds the
        # number of classes.
        print("\n" + "=" * 60)
        print(f"Top {len(top_probs)} Predictions:")
        print("=" * 60)

        for rank, (prob, idx) in enumerate(zip(top_probs, top_indices), start=1):
            class_name = esc50_labels[idx]
            # Star the row that matches the overlapping-patch prediction.
            marker = "★" if idx == predicted_class else " "
            print(f"{marker} {rank}. {class_name:20s} - {prob*100:6.2f}%")

        print("=" * 60)
        print(f"\n✓ Predicted class: {esc50_labels[predicted_class]}")

    except Exception as e:
        print(f"\nError during prediction: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    main()