AKMESSI committed on
Commit
44dff6a
·
verified ·
1 Parent(s): f2d115f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -5
app.py CHANGED
@@ -32,20 +32,63 @@ def load_model_and_map():
32
  # Load the checkpoint
33
  checkpoint = torch.load("multi_species_model.pth", map_location="cpu")
34
 
 
 
 
 
 
 
 
 
 
35
  # Create model directly from torchvision instead of torch.hub
36
  model = models.mobilenet_v3_small(pretrained=False)
37
- num_classes = len(checkpoint['label_map'])
 
 
 
 
 
38
  model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, num_classes)
39
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
40
  model.eval()
41
 
42
- # Get class names (scientific names)
43
- class_names = list(checkpoint['label_map'].keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  return model, class_names
46
 
47
  model, class_names = load_model_and_map()
48
 
 
 
49
  # Show status of audio backend
50
  if not TORCHAUDIO_AVAILABLE:
51
  st.info("ℹ️ Using soundfile backend for audio processing (torchaudio not available)")
@@ -124,6 +167,9 @@ if audio_data:
124
  audio_bytes = audio_data.read()
125
  audio_data.seek(0) # Reset file pointer
126
 
 
 
 
127
  if TORCHAUDIO_AVAILABLE:
128
  try:
129
  waveform, original_sr = torchaudio.load(io.BytesIO(audio_bytes))
@@ -143,6 +189,10 @@ if audio_data:
143
  waveform = waveform.mean(dim=1)
144
  waveform = waveform.unsqueeze(0)
145
 
 
 
 
 
146
  # Resample to 22050 if needed
147
  if original_sr != 22050:
148
  if TORCHAUDIO_AVAILABLE:
@@ -170,6 +220,8 @@ if audio_data:
170
  else:
171
  waveform = waveform[:, :target_samples]
172
 
 
 
173
  # Compute Mel spectrogram
174
  if TORCHAUDIO_AVAILABLE:
175
  mel = full_transform(waveform) # (1, 128, time)
@@ -178,6 +230,16 @@ if audio_data:
178
  mel = full_transform(waveform) # (1, 128, time)
179
  mel = mel.squeeze(0) # (128, time)
180
 
 
 
 
 
 
 
 
 
 
 
181
  # Normalize for visualization
182
  mel_min = mel.min()
183
  mel_max = mel.max()
@@ -185,15 +247,33 @@ if audio_data:
185
 
186
  # Prepare for model: resize to 224x224, add batch & RGB channels
187
  mel_input = mel.unsqueeze(0).unsqueeze(0) # (1, 1, 128, time)
 
 
188
  mel_input = torch.nn.functional.interpolate(mel_input, size=(224, 224), mode='bilinear', align_corners=False)
 
 
189
  mel_input = mel_input.repeat(1, 3, 1, 1) # to RGB
 
 
 
 
190
 
191
  # Inference
192
  with torch.no_grad():
193
  output = model(mel_input)
 
 
 
194
  probs = torch.nn.functional.softmax(output[0], dim=0)
 
 
195
  top5_probs, top5_idx = torch.topk(probs, 5)
196
 
 
 
 
 
 
197
  # Determine confidence level
198
  top1_confidence = top5_probs[0].item()
199
  top1_species = class_names[top5_idx[0]]
@@ -258,7 +338,7 @@ if audio_data:
258
  st.markdown("---")
259
  with st.expander("πŸ“Š View Audio Spectrogram"):
260
  mel_vis = mel_norm.cpu().numpy()
261
- st.image(mel_vis, caption="Mel Spectrogram of your audio", use_column_width=True, clamp=True)
262
  st.caption("This visualization shows the frequency content of the bird call over time.")
263
 
264
  except Exception as e:
 
32
  # Load the checkpoint
33
  checkpoint = torch.load("multi_species_model.pth", map_location="cpu")
34
 
35
+ # Debug: Check what's in the checkpoint
36
+ st.write("πŸ” **Checkpoint Keys:**", list(checkpoint.keys()))
37
+
38
+ # Get label map
39
+ label_map = checkpoint['label_map']
40
+ st.write(f"πŸ“‹ **Number of classes in checkpoint:** {len(label_map)}")
41
+ st.write(f"πŸ“ **First 5 species in label_map:**", list(label_map.keys())[:5])
42
+ st.write(f"πŸ”’ **Label map type:**", type(label_map))
43
+
44
  # Create model directly from torchvision instead of torch.hub
45
  model = models.mobilenet_v3_small(pretrained=False)
46
+ num_classes = len(label_map)
47
+
48
+ st.write(f"🧠 **Model output classes:** {num_classes}")
49
+ st.write(f"πŸ”§ **Original classifier final layer:** {model.classifier[3]}")
50
+
51
+ # Replace final layer
52
  model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, num_classes)
53
+ st.write(f"βœ… **New classifier final layer:** {model.classifier[3]}")
54
+
55
+ # Load state dict
56
+ try:
57
+ model.load_state_dict(checkpoint['model_state_dict'])
58
+ st.success(f"βœ… Model weights loaded successfully!")
59
+ except Exception as e:
60
+ st.error(f"❌ Error loading model weights: {e}")
61
+ st.stop()
62
+
63
  model.eval()
64
 
65
+ # Get class names - THIS IS CRITICAL
66
+ # The label_map from your checkpoint should be {species_name: index}
67
+ # We need to create a list where list[index] = species_name
68
+
69
+ if isinstance(list(label_map.keys())[0], str):
70
+ # label_map is {species_name: index}, need to invert it
71
+ st.info("πŸ“– Label map format: {species_name: index}")
72
+ # Create inverse mapping: index -> species_name
73
+ index_to_species = {v: k for k, v in label_map.items()}
74
+ # Create ordered list by index
75
+ class_names = [index_to_species[i] for i in range(len(label_map))]
76
+ else:
77
+ # label_map is {index: species_name}
78
+ st.info("πŸ“– Label map format: {index: species_name}")
79
+ class_names = [label_map[i] for i in sorted(label_map.keys())]
80
+
81
+ st.write(f"🐦 **Total species loaded:** {len(class_names)}")
82
+ st.write(f"πŸ”€ **Class names sample (indices 0-4):**")
83
+ for i in range(min(5, len(class_names))):
84
+ st.write(f" Index {i}: {class_names[i]}")
85
 
86
  return model, class_names
87
 
88
  model, class_names = load_model_and_map()
89
 
90
+ st.markdown("---")
91
+
92
  # Show status of audio backend
93
  if not TORCHAUDIO_AVAILABLE:
94
  st.info("ℹ️ Using soundfile backend for audio processing (torchaudio not available)")
 
167
  audio_bytes = audio_data.read()
168
  audio_data.seek(0) # Reset file pointer
169
 
170
+ # Debug: Show file info
171
+ st.info(f"πŸ“ File size: {len(audio_bytes) / 1024:.1f} KB")
172
+
173
  if TORCHAUDIO_AVAILABLE:
174
  try:
175
  waveform, original_sr = torchaudio.load(io.BytesIO(audio_bytes))
 
189
  waveform = waveform.mean(dim=1)
190
  waveform = waveform.unsqueeze(0)
191
 
192
+ # Debug info
193
+ st.info(f"🎡 Original sample rate: {original_sr} Hz, Duration: {waveform.shape[1] / original_sr:.2f} seconds")
194
+ st.info(f"πŸ“Š Waveform shape: {waveform.shape}")
195
+
196
  # Resample to 22050 if needed
197
  if original_sr != 22050:
198
  if TORCHAUDIO_AVAILABLE:
 
220
  else:
221
  waveform = waveform[:, :target_samples]
222
 
223
+ st.info(f"βœ‚οΈ Processed to 5 seconds: {waveform.shape}")
224
+
225
  # Compute Mel spectrogram
226
  if TORCHAUDIO_AVAILABLE:
227
  mel = full_transform(waveform) # (1, 128, time)
 
230
  mel = full_transform(waveform) # (1, 128, time)
231
  mel = mel.squeeze(0) # (128, time)
232
 
233
+ st.info(f"🎼 Mel spectrogram shape: {mel.shape}")
234
+
235
+ # Check if mel spectrogram is valid
236
+ if torch.isnan(mel).any() or torch.isinf(mel).any():
237
+ st.error("⚠️ Invalid mel spectrogram detected (NaN or Inf values)")
238
+ st.stop()
239
+
240
+ # Show mel spectrogram statistics
241
+ st.info(f"πŸ“ˆ Mel stats - Min: {mel.min():.2f}, Max: {mel.max():.2f}, Mean: {mel.mean():.2f}")
242
+
243
  # Normalize for visualization
244
  mel_min = mel.min()
245
  mel_max = mel.max()
 
247
 
248
  # Prepare for model: resize to 224x224, add batch & RGB channels
249
  mel_input = mel.unsqueeze(0).unsqueeze(0) # (1, 1, 128, time)
250
+ st.info(f"πŸ”§ Before resize: {mel_input.shape}")
251
+
252
  mel_input = torch.nn.functional.interpolate(mel_input, size=(224, 224), mode='bilinear', align_corners=False)
253
+ st.info(f"πŸ“ After resize to 224x224: {mel_input.shape}")
254
+
255
  mel_input = mel_input.repeat(1, 3, 1, 1) # to RGB
256
+ st.info(f"🎨 After RGB conversion: {mel_input.shape}")
257
+
258
+ # Show input statistics
259
+ st.info(f"πŸ”’ Model input stats - Min: {mel_input.min():.2f}, Max: {mel_input.max():.2f}, Mean: {mel_input.mean():.2f}")
260
 
261
  # Inference
262
  with torch.no_grad():
263
  output = model(mel_input)
264
+ st.info(f"🧠 Raw model output shape: {output.shape}")
265
+ st.info(f"πŸ“Š Raw output stats - Min: {output.min():.2f}, Max: {output.max():.2f}")
266
+
267
  probs = torch.nn.functional.softmax(output[0], dim=0)
268
+ st.info(f"🎲 Probabilities sum: {probs.sum():.4f} (should be ~1.0)")
269
+
270
  top5_probs, top5_idx = torch.topk(probs, 5)
271
 
272
+ # Show raw top 5 for debugging
273
+ with st.expander("πŸ” DEBUG: Raw Top 5 Predictions"):
274
+ for i in range(5):
275
+ st.write(f"{i+1}. Index: {top5_idx[i].item()}, Prob: {top5_probs[i].item():.4f}, Species: {class_names[top5_idx[i]]}")
276
+
277
  # Determine confidence level
278
  top1_confidence = top5_probs[0].item()
279
  top1_species = class_names[top5_idx[0]]
 
338
  st.markdown("---")
339
  with st.expander("πŸ“Š View Audio Spectrogram"):
340
  mel_vis = mel_norm.cpu().numpy()
341
+ st.image(mel_vis, caption="Mel Spectrogram of your audio", use_container_width=True, clamp=True)
342
  st.caption("This visualization shows the frequency content of the bird call over time.")
343
 
344
  except Exception as e: