fevot commited on
Commit
fcf045e
·
verified ·
1 Parent(s): f604fd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -93
app.py CHANGED
@@ -6,9 +6,7 @@ import numpy as np
6
  import json
7
  from torchvision import models
8
  import librosa
9
- import matplotlib.pyplot as plt
10
- import io
11
- from PIL import Image
12
 
13
  # Define the BirdCallRNN model
14
  class BirdCallRNN(nn.Module):
@@ -27,7 +25,7 @@ class BirdCallRNN(nn.Module):
27
  output = self.fc(rnn_out[:, -1, :]) # Note: We'll use this for single-segment sequences
28
  return output
29
 
30
- # Function to convert MP3 to mel spectrogram
31
  def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
32
  y, sr = librosa.load(mp3_file, sr=None)
33
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
@@ -44,26 +42,11 @@ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224,
44
  log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
45
  return log_S_resized
46
 
47
- # Generate mel spectrogram image for display
48
- def generate_mel_spectrogram_plot(log_S):
49
- plt.figure(figsize=(10, 4))
50
- plt.imshow(log_S, aspect='auto', origin='lower', cmap='viridis')
51
- plt.colorbar(format='%+2.0f dB')
52
- plt.title('Mel Spectrogram')
53
- plt.tight_layout()
54
-
55
- # Save plot to a bytes buffer
56
- buf = io.BytesIO()
57
- plt.savefig(buf, format='png')
58
- plt.close()
59
- buf.seek(0)
60
- return Image.open(buf)
61
-
62
  # Load class mapping globally
63
  with open('class_mapping.json', 'r') as f:
64
  class_names = json.load(f)
65
 
66
- # Revised inference function to predict per segment with confidence scores
67
  def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
68
  model.eval()
69
  # Load audio and compute mel spectrogram
@@ -71,39 +54,24 @@ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
71
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
72
  log_S = librosa.power_to_db(S, ref=np.max)
73
  # Segment the spectrogram
74
- num_segments = max(1, log_S.shape[1] // segment_length)
75
- segments = [log_S[:, i * segment_length:min((i + 1) * segment_length, log_S.shape[1])] for i in range(num_segments)]
 
 
 
76
 
77
  predictions = []
78
- confidence_scores = []
79
- spectrogram_images = []
80
-
81
  # Process each segment individually
82
  for seg in segments:
83
- # Generate spectrogram image first
84
- spec_img = generate_mel_spectrogram_plot(seg)
85
- spectrogram_images.append(spec_img)
86
-
87
- # Prepare for model input
88
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
89
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
90
  # Create a tensor with batch size 1 and sequence length 1
91
  seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
92
-
93
- with torch.no_grad():
94
- output = model(seg_tensor)
95
-
96
- # Get prediction
97
- probabilities = torch.nn.functional.softmax(output, dim=1)
98
- confidence, pred_idx = torch.max(probabilities, dim=1)
99
- pred = pred_idx.cpu().numpy()[0]
100
- conf = confidence.cpu().numpy()[0]
101
-
102
  predicted_bird = class_names[str(pred)] # Convert pred to string to match JSON keys
103
  predictions.append(predicted_bird)
104
- confidence_scores.append(conf)
105
-
106
- return predictions, confidence_scores, spectrogram_images
107
 
108
  # Initialize the model
109
  resnet = models.resnet50(weights='IMAGENET1K_V2')
@@ -117,75 +85,54 @@ model.load_state_dict(torch.load('model_weights.pth', map_location=device))
117
  model.eval()
118
 
119
  # Prediction function for Gradio
120
- def predict_bird(file_path):
121
- if file_path is None:
122
- return "Please upload an audio file", [], None, None, None
123
-
124
- predictions, confidence_scores, spectrograms = infer_birdcall(model, file_path, segment_length=500, device=str(device))
125
 
126
- # Format the predictions with numbering and confidence
127
- formatted_predictions = [f"{i+1}. {bird} (Confidence: {conf:.2%})" for i, (bird, conf) in enumerate(zip(predictions, confidence_scores))]
128
- prediction_text = "\n".join(formatted_predictions)
129
 
130
- # Load the three static images
131
- bird_species_img = "1.jpeg"
132
- bird_description_img = "2.jpeg"
133
- bird_origins_img = "3.jpeg"
134
 
135
- return prediction_text, spectrograms, bird_species_img, bird_description_img, bird_origins_img
 
136
 
137
- # Create Gradio blocks interface
138
- with gr.Blocks() as interface:
139
- gr.Markdown("# Bird Call Recognition")
140
 
141
  with gr.Row():
142
  with gr.Column():
143
- # File upload - fixed parameter issue
144
- input_audio = gr.Audio(
145
- type="filepath",
146
- label="Upload MP3 file"
147
- )
148
-
149
- # Submit button
150
- submit_btn = gr.Button("Identify Bird Species")
151
 
152
- # Results section
153
  with gr.Row():
154
- prediction_output = gr.Textbox(label="Identified Bird Species")
155
 
156
- # Spectrograms gallery - removed style method
157
  with gr.Row():
158
- spectrogram_gallery = gr.Gallery(
159
- label="Mel Spectrograms by Segment",
160
- show_label=True,
161
- # Removed style() method that was causing errors
162
- # Instead using direct parameters if available
163
- grid=[2, 2],
164
- height=400
165
- )
166
 
167
- # Bird information images
168
  with gr.Row():
169
- bird_species_image = gr.Image(label="Bird Species")
 
170
 
 
171
  with gr.Row():
172
- bird_description_image = gr.Image(label="Bird Description")
 
173
 
 
174
  with gr.Row():
175
- bird_origins_image = gr.Image(label="Bird Origins")
 
176
 
177
- # Set up the submission event
178
  submit_btn.click(
179
  fn=predict_bird,
180
- inputs=input_audio,
181
- outputs=[
182
- prediction_output,
183
- spectrogram_gallery,
184
- bird_species_image,
185
- bird_description_image,
186
- bird_origins_image
187
- ]
188
  )
189
 
190
- # Launch the interface
191
- interface.launch()
 
6
  import json
7
  from torchvision import models
8
  import librosa
9
+ import os
 
 
10
 
11
  # Define the BirdCallRNN model
12
  class BirdCallRNN(nn.Module):
 
25
  output = self.fc(rnn_out[:, -1, :]) # Note: We'll use this for single-segment sequences
26
  return output
27
 
28
+ # Function to convert MP3 to mel spectrogram (unchanged)
29
  def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
30
  y, sr = librosa.load(mp3_file, sr=None)
31
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
 
42
  log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
43
  return log_S_resized
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
# Load the class-index -> bird-name mapping once at import time.
# Keys in the JSON are stringified integer class indices (see infer_birdcall).
# Explicit UTF-8: JSON is UTF-8 by spec; the platform default encoding is not.
with open('class_mapping.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
48
 
49
# Revised inference function to predict per segment
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
    """Predict one bird species per fixed-width spectrogram segment.

    Args:
        model: Trained BirdCallRNN-compatible model already on `device`.
        mp3_file: Path to the audio file to analyse.
        segment_length: Segment width in mel-spectrogram frames.
        device: Torch device string to run inference on.

    Returns:
        list[str]: Predicted bird name for each segment, in order.
    """
    model.eval()
    # Load audio and compute the log-mel spectrogram.
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)

    # Segment the spectrogram. Audio shorter than one segment is used whole;
    # otherwise any trailing remainder (< segment_length frames) is dropped.
    num_segments = max(1, log_S.shape[1] // segment_length)
    segments = [
        log_S[:, i * segment_length:min((i + 1) * segment_length, log_S.shape[1])]
        for i in range(num_segments)
    ]

    predictions = []
    # Process each segment individually.
    for seg in segments:
        # Resize to ResNet input size and replicate to 3 channels.
        seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
        seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
        # Create a tensor with batch size 1 and sequence length 1.
        seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)  # Shape: (1, 1, 3, 224, 224)
        # Inference only: disable autograd so no graph is built per segment.
        with torch.no_grad():
            output = model(seg_tensor)
        pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
        predicted_bird = class_names[str(pred)]  # Convert pred to string to match JSON keys
        predictions.append(predicted_bird)
    return predictions
 
 
75
 
76
  # Initialize the model
77
  resnet = models.resnet50(weights='IMAGENET1K_V2')
 
85
  model.eval()
86
 
87
# Prediction function for Gradio
def predict_bird(audio_file):
    """Gradio handler: return a numbered, newline-separated list of species."""
    # Guard: the button can fire with no file uploaded.
    if audio_file is None:
        return "Please upload an MP3 file."

    species = infer_birdcall(model, audio_file, segment_length=500, device=str(device))
    if not species:
        return "No birds identified."

    # One "N. name" line per segment prediction.
    return "\n".join(f"{idx}. {bird}" for idx, bird in enumerate(species, start=1))
100
 
101
# Build the Gradio Blocks UI: audio upload + button on top, text output,
# then three static informational images.
with gr.Blocks() as demo:
    gr.Markdown("# Bird Call Identification")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Bird Call Audio")

    with gr.Row():
        submit_btn = gr.Button("Identify Birds")

    with gr.Row():
        output_text = gr.Textbox(label="Predicted Bird Species")

    # Static reference images shown below the results.
    with gr.Row():
        gr.Markdown("## Bird Species")
        img_species = gr.Image("1.jpeg", label="")

    with gr.Row():
        gr.Markdown("## Bird Description")
        img_description = gr.Image("2.jpeg", label="")

    with gr.Row():
        gr.Markdown("## Bird Origins")
        img_origins = gr.Image("3.jpeg", label="")

    # Wire the button to the prediction handler.
    submit_btn.click(
        fn=predict_bird,
        inputs=audio_input,
        outputs=output_text
    )

# Launch the app
demo.launch()