fevot committed (verified)
Commit dd28e0e · Parent(s): 1d4ed0f

Update app.py

Files changed (1):
  1. app.py (+54, −44)
app.py CHANGED
@@ -21,61 +21,63 @@ class BirdCallRNN(nn.Module):
         features = self.resnet(x)
         features = features.view(batch, seq_len, -1)
         rnn_out, _ = self.rnn(features)
-        output = self.fc(rnn_out[:, -1, :])
+        output = self.fc(rnn_out[:, -1, :])  # Note: We’ll use this for single-segment sequences
         return output

-# Function to convert MP3 to mel spectrogram
+# Function to convert MP3 to mel spectrogram (unchanged)
 def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
     y, sr = librosa.load(mp3_file, sr=None)
     S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
     log_S = librosa.power_to_db(S, ref=np.max)
-
-    # Ensure the correct time step size
     current_time_steps = log_S.shape[1]
     target_time_steps = target_shape[1]
     if current_time_steps < target_time_steps:
         pad_width = target_time_steps - current_time_steps
         log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
-    else:
+    elif current_time_steps > target_time_steps:
         log_S_resized = log_S[:, :target_time_steps]
-
+    else:
+        log_S_resized = log_S
     log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
     return log_S_resized

-# Load class mapping
+# Load class mapping globally
 with open('class_mapping.json', 'r') as f:
     class_names = json.load(f)

-# Inference function for bird call identification
+# Revised inference function to predict per segment
 def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
     model.eval()
+    # Load audio and compute mel spectrogram
     y, sr = librosa.load(mp3_file, sr=None)
     S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
     log_S = librosa.power_to_db(S, ref=np.max)
-
     # Segment the spectrogram
     num_segments = log_S.shape[1] // segment_length
-    segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)] if num_segments > 0 else [log_S]
-
+    if num_segments == 0:
+        segments = [log_S]
+    else:
+        segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
+
     predictions = []
+    # Process each segment individually
     for seg in segments:
         seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
         seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
-        seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
+        # Create a tensor with batch size 1 and sequence length 1
+        seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)  # Shape: (1, 1, 3, 224, 224)
         output = model(seg_tensor)
         pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
-        predicted_bird = class_names[str(pred)]
+        predicted_bird = class_names[str(pred)]  # Convert pred to string to match JSON keys
         predictions.append(predicted_bird)
-
-    return "\n".join([f"{i+1}. {pred}" for i, pred in enumerate(predictions)])
+    return predictions

-# Initialize model
+# Initialize the model
 resnet = models.resnet50(weights='IMAGENET1K_V2')
 num_features = resnet.fc.in_features
 resnet.fc = nn.Identity()
-num_classes = len(class_names)
+num_classes = len(class_names)  # Should be 114
 model = BirdCallRNN(resnet, num_features, num_classes)
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 model.load_state_dict(torch.load('model_weights.pth', map_location=device))
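
The per-segment tensor built in infer_birdcall has shape (1, 1, 3, 224, 224), the (batch, seq_len, channels, height, width) layout that BirdCallRNN.forward flattens for the ResNet backbone and then restores for the RNN. Here is a minimal shape check, as a sketch rather than part of the commit; it assumes the BirdCallRNN class defined earlier in app.py and skips loading the trained weights:

import torch
import torch.nn as nn
from torchvision import models

resnet = models.resnet50(weights=None)  # untrained backbone is enough for a shape check
num_features = resnet.fc.in_features    # 2048 for ResNet-50
resnet.fc = nn.Identity()
model = BirdCallRNN(resnet, num_features, num_classes=114)  # 114 per the comment in the diff

dummy = torch.randn(1, 1, 3, 224, 224)  # (batch, seq_len, C, H, W), as built in infer_birdcall
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 114])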
@@ -83,29 +85,37 @@ model.eval()

 # Prediction function for Gradio
 def predict_bird(file_path):
-    return infer_birdcall(model, file_path, segment_length=500, device=str(device))
-
-# Gradio interface
-with gr.Blocks() as interface:
-    gr.Markdown("### Bird Call Identification")
-
-    # File upload and audio preview
-    with gr.Row():
-        file_input = gr.File(label="Upload MP3 file", file_types=['.mp3'])
-        audio_player = gr.Audio(label="Uploaded MP3 File")
-
-    # Update audio player after upload
-    file_input.change(lambda file: file.name if file else None, inputs=file_input, outputs=audio_player)
-
-    # Prediction UI
-    with gr.Row():
-        submit_button = gr.Button("Submit")
-
-    prediction_output = gr.Textbox(label="Predicted Bird Species")
-    bird_species_image = gr.Image(label="Bird Species")
-    bird_description_image = gr.Image(label="Bird Description")
-    bird_origins_image = gr.Image(label="Bird Origins")
-
-    submit_button.click(predict_bird, inputs=file_input, outputs=prediction_output)
-
-interface.launch()
+    predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
+    # Format predictions as a numbered list
+    formatted_predictions = "\n".join([f"{i+1}. {pred}" for i, pred in enumerate(predictions)])
+    return formatted_predictions  # Return formatted list of predictions
+
+# Custom Gradio interface with additional components
+def gradio_interface(file_path):
+    # Predict bird species
+    prediction = predict_bird(file_path)
+
+    # Display the uploaded MP3 file with a play button
+    audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=True)
+
+    # Display images with titles
+    bird_species_image = gr.Image("1.jpg", label="Bird Species")
+    bird_description_image = gr.Image("2.jpg", label="Bird Description")
+    bird_origins_image = gr.Image("3.jpg", label="Bird Origins")
+
+    return prediction, audio_player, bird_species_image, bird_description_image, bird_origins_image
+
+# Launch Gradio interface: one file input; outputs match the five values returned above
+interface = gr.Interface(
+    fn=gradio_interface,
+    inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']),
+    outputs=[
+        gr.Textbox(label="Predicted Bird Species"),
+        gr.Audio(label="Uploaded MP3 File"),
+        gr.Image(label="Bird Species"),
+        gr.Image(label="Bird Description"),
+        gr.Image(label="Bird Origins")
+    ]
+)
+interface.launch()
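
Both the old and new inference paths look species names up with class_names[str(pred)], so class_mapping.json is expected to map stringified class indices to species names. A hypothetical illustration of that layout (the real entries come from the training data):

import json

# Hypothetical entries; the real class_mapping.json holds 114 of them.
example_mapping = {"0": "species_a", "1": "species_b"}
print(json.dumps(example_mapping, indent=2))  # the shape the file is expected to have
pred = 1  # integer index, as produced by torch.max(output, dim=1)[1]
print(example_mapping[str(pred)])  # mirrors class_names[str(pred)] in infer_birdcall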
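
For a quick end-to-end check outside Gradio, the inference path can be exercised directly. This is a sketch under stated assumptions: sample.mp3 is a hypothetical local recording, and model_weights.pth and class_mapping.json are present as app.py expects.

if __name__ == "__main__":
    # Segment-wise inference on a local file, printed as the same numbered list
    # that predict_bird returns to the Gradio Textbox.
    preds = infer_birdcall(model, "sample.mp3", segment_length=500, device=str(device))
    for i, species in enumerate(preds, start=1):
        print(f"{i}. {species}")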