Update app.py
Browse files
app.py
CHANGED
|
@@ -6,9 +6,7 @@ import numpy as np
|
|
| 6 |
import json
|
| 7 |
from torchvision import models
|
| 8 |
import librosa
|
| 9 |
-
import
|
| 10 |
-
import io
|
| 11 |
-
from PIL import Image
|
| 12 |
|
| 13 |
# Define the BirdCallRNN model
|
| 14 |
class BirdCallRNN(nn.Module):
|
|
@@ -27,7 +25,7 @@ class BirdCallRNN(nn.Module):
|
|
| 27 |
output = self.fc(rnn_out[:, -1, :]) # Note: We'll use this for single-segment sequences
|
| 28 |
return output
|
| 29 |
|
| 30 |
-
# Function to convert MP3 to mel spectrogram
|
| 31 |
def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
|
| 32 |
y, sr = librosa.load(mp3_file, sr=None)
|
| 33 |
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
|
|
@@ -44,26 +42,11 @@ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224,
|
|
| 44 |
log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
|
| 45 |
return log_S_resized
|
| 46 |
|
| 47 |
-
# Generate mel spectrogram image for display
|
| 48 |
-
def generate_mel_spectrogram_plot(log_S):
|
| 49 |
-
plt.figure(figsize=(10, 4))
|
| 50 |
-
plt.imshow(log_S, aspect='auto', origin='lower', cmap='viridis')
|
| 51 |
-
plt.colorbar(format='%+2.0f dB')
|
| 52 |
-
plt.title('Mel Spectrogram')
|
| 53 |
-
plt.tight_layout()
|
| 54 |
-
|
| 55 |
-
# Save plot to a bytes buffer
|
| 56 |
-
buf = io.BytesIO()
|
| 57 |
-
plt.savefig(buf, format='png')
|
| 58 |
-
plt.close()
|
| 59 |
-
buf.seek(0)
|
| 60 |
-
return Image.open(buf)
|
| 61 |
-
|
| 62 |
# Load class mapping globally
|
| 63 |
with open('class_mapping.json', 'r') as f:
|
| 64 |
class_names = json.load(f)
|
| 65 |
|
| 66 |
-
# Revised inference function to predict per segment
|
| 67 |
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
|
| 68 |
model.eval()
|
| 69 |
# Load audio and compute mel spectrogram
|
|
@@ -71,39 +54,24 @@ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
|
|
| 71 |
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
|
| 72 |
log_S = librosa.power_to_db(S, ref=np.max)
|
| 73 |
# Segment the spectrogram
|
| 74 |
-
num_segments =
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
predictions = []
|
| 78 |
-
confidence_scores = []
|
| 79 |
-
spectrogram_images = []
|
| 80 |
-
|
| 81 |
# Process each segment individually
|
| 82 |
for seg in segments:
|
| 83 |
-
# Generate spectrogram image first
|
| 84 |
-
spec_img = generate_mel_spectrogram_plot(seg)
|
| 85 |
-
spectrogram_images.append(spec_img)
|
| 86 |
-
|
| 87 |
-
# Prepare for model input
|
| 88 |
seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
|
| 89 |
seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
|
| 90 |
# Create a tensor with batch size 1 and sequence length 1
|
| 91 |
seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
output = model(seg_tensor)
|
| 95 |
-
|
| 96 |
-
# Get prediction
|
| 97 |
-
probabilities = torch.nn.functional.softmax(output, dim=1)
|
| 98 |
-
confidence, pred_idx = torch.max(probabilities, dim=1)
|
| 99 |
-
pred = pred_idx.cpu().numpy()[0]
|
| 100 |
-
conf = confidence.cpu().numpy()[0]
|
| 101 |
-
|
| 102 |
predicted_bird = class_names[str(pred)] # Convert pred to string to match JSON keys
|
| 103 |
predictions.append(predicted_bird)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
return predictions, confidence_scores, spectrogram_images
|
| 107 |
|
| 108 |
# Initialize the model
|
| 109 |
resnet = models.resnet50(weights='IMAGENET1K_V2')
|
|
@@ -117,75 +85,54 @@ model.load_state_dict(torch.load('model_weights.pth', map_location=device))
|
|
| 117 |
model.eval()
|
| 118 |
|
| 119 |
# Prediction function for Gradio
|
| 120 |
-
def predict_bird(
|
| 121 |
-
if
|
| 122 |
-
return "Please upload an
|
| 123 |
-
|
| 124 |
-
predictions, confidence_scores, spectrograms = infer_birdcall(model, file_path, segment_length=500, device=str(device))
|
| 125 |
|
| 126 |
-
|
| 127 |
-
formatted_predictions = [f"{i+1}. {bird} (Confidence: {conf:.2%})" for i, (bird, conf) in enumerate(zip(predictions, confidence_scores))]
|
| 128 |
-
prediction_text = "\n".join(formatted_predictions)
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
bird_origins_img = "3.jpeg"
|
| 134 |
|
| 135 |
-
|
|
|
|
| 136 |
|
| 137 |
-
# Create Gradio
|
| 138 |
-
with gr.Blocks() as
|
| 139 |
-
gr.Markdown("# Bird Call
|
| 140 |
|
| 141 |
with gr.Row():
|
| 142 |
with gr.Column():
|
| 143 |
-
|
| 144 |
-
input_audio = gr.Audio(
|
| 145 |
-
type="filepath",
|
| 146 |
-
label="Upload MP3 file"
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
# Submit button
|
| 150 |
-
submit_btn = gr.Button("Identify Bird Species")
|
| 151 |
|
| 152 |
-
# Results section
|
| 153 |
with gr.Row():
|
| 154 |
-
|
| 155 |
|
| 156 |
-
# Spectrograms gallery - removed style method
|
| 157 |
with gr.Row():
|
| 158 |
-
|
| 159 |
-
label="Mel Spectrograms by Segment",
|
| 160 |
-
show_label=True,
|
| 161 |
-
# Removed style() method that was causing errors
|
| 162 |
-
# Instead using direct parameters if available
|
| 163 |
-
grid=[2, 2],
|
| 164 |
-
height=400
|
| 165 |
-
)
|
| 166 |
|
| 167 |
-
# Bird
|
| 168 |
with gr.Row():
|
| 169 |
-
|
|
|
|
| 170 |
|
|
|
|
| 171 |
with gr.Row():
|
| 172 |
-
|
|
|
|
| 173 |
|
|
|
|
| 174 |
with gr.Row():
|
| 175 |
-
|
|
|
|
| 176 |
|
| 177 |
-
# Set up the
|
| 178 |
submit_btn.click(
|
| 179 |
fn=predict_bird,
|
| 180 |
-
inputs=
|
| 181 |
-
outputs=
|
| 182 |
-
prediction_output,
|
| 183 |
-
spectrogram_gallery,
|
| 184 |
-
bird_species_image,
|
| 185 |
-
bird_description_image,
|
| 186 |
-
bird_origins_image
|
| 187 |
-
]
|
| 188 |
)
|
| 189 |
|
| 190 |
-
# Launch the
|
| 191 |
-
|
|
|
|
| 6 |
import json
|
| 7 |
from torchvision import models
|
| 8 |
import librosa
|
| 9 |
+
import os
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Define the BirdCallRNN model
|
| 12 |
class BirdCallRNN(nn.Module):
|
|
|
|
| 25 |
output = self.fc(rnn_out[:, -1, :]) # Note: We'll use this for single-segment sequences
|
| 26 |
return output
|
| 27 |
|
| 28 |
+
# Function to convert MP3 to mel spectrogram (unchanged)
|
| 29 |
def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
|
| 30 |
y, sr = librosa.load(mp3_file, sr=None)
|
| 31 |
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
|
|
|
|
| 42 |
log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
|
| 43 |
return log_S_resized
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Load class mapping globally
|
| 46 |
with open('class_mapping.json', 'r') as f:
|
| 47 |
class_names = json.load(f)
|
| 48 |
|
| 49 |
+
# Revised inference function to predict per segment
|
| 50 |
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
|
| 51 |
model.eval()
|
| 52 |
# Load audio and compute mel spectrogram
|
|
|
|
| 54 |
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
|
| 55 |
log_S = librosa.power_to_db(S, ref=np.max)
|
| 56 |
# Segment the spectrogram
|
| 57 |
+
num_segments = log_S.shape[1] // segment_length
|
| 58 |
+
if num_segments == 0:
|
| 59 |
+
segments = [log_S]
|
| 60 |
+
else:
|
| 61 |
+
segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
|
| 62 |
|
| 63 |
predictions = []
|
|
|
|
|
|
|
|
|
|
| 64 |
# Process each segment individually
|
| 65 |
for seg in segments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
|
| 67 |
seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
|
| 68 |
# Create a tensor with batch size 1 and sequence length 1
|
| 69 |
seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
|
| 70 |
+
output = model(seg_tensor)
|
| 71 |
+
pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
predicted_bird = class_names[str(pred)] # Convert pred to string to match JSON keys
|
| 73 |
predictions.append(predicted_bird)
|
| 74 |
+
return predictions
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Initialize the model
|
| 77 |
resnet = models.resnet50(weights='IMAGENET1K_V2')
|
|
|
|
| 85 |
model.eval()
|
| 86 |
|
| 87 |
# Prediction function for Gradio
|
| 88 |
+
def predict_bird(audio_file):
|
| 89 |
+
if audio_file is None:
|
| 90 |
+
return "Please upload an MP3 file."
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
predictions = infer_birdcall(model, audio_file, segment_length=500, device=str(device))
|
|
|
|
|
|
|
| 93 |
|
| 94 |
+
# Format the predictions with numbering
|
| 95 |
+
if not predictions:
|
| 96 |
+
return "No birds identified."
|
|
|
|
| 97 |
|
| 98 |
+
numbered_predictions = [f"{i+1}. {bird}" for i, bird in enumerate(predictions)]
|
| 99 |
+
return "\n".join(numbered_predictions)
|
| 100 |
|
| 101 |
+
# Create Gradio Blocks for more complex layout
|
| 102 |
+
with gr.Blocks() as demo:
|
| 103 |
+
gr.Markdown("# Bird Call Identification")
|
| 104 |
|
| 105 |
with gr.Row():
|
| 106 |
with gr.Column():
|
| 107 |
+
audio_input = gr.Audio(type="filepath", label="Upload Bird Call Audio")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
|
|
|
| 109 |
with gr.Row():
|
| 110 |
+
submit_btn = gr.Button("Identify Birds")
|
| 111 |
|
|
|
|
| 112 |
with gr.Row():
|
| 113 |
+
output_text = gr.Textbox(label="Predicted Bird Species")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
# Bird Species Image
|
| 116 |
with gr.Row():
|
| 117 |
+
gr.Markdown("## Bird Species")
|
| 118 |
+
species_image = gr.Image("1.jpeg", label="")
|
| 119 |
|
| 120 |
+
# Bird Description Image
|
| 121 |
with gr.Row():
|
| 122 |
+
gr.Markdown("## Bird Description")
|
| 123 |
+
description_image = gr.Image("2.jpeg", label="")
|
| 124 |
|
| 125 |
+
# Bird Origins Image
|
| 126 |
with gr.Row():
|
| 127 |
+
gr.Markdown("## Bird Origins")
|
| 128 |
+
origins_image = gr.Image("3.jpeg", label="")
|
| 129 |
|
| 130 |
+
# Set up the prediction event
|
| 131 |
submit_btn.click(
|
| 132 |
fn=predict_bird,
|
| 133 |
+
inputs=audio_input,
|
| 134 |
+
outputs=output_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
)
|
| 136 |
|
| 137 |
+
# Launch the app
|
| 138 |
+
demo.launch()
|