fevot committed on
Commit
5cc1efc
·
verified ·
1 Parent(s): 1e646fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -53
app.py CHANGED
@@ -1,83 +1,100 @@
1
- import gradio as gr
2
  import torch
3
- from torch import nn
4
- import cv2
 
5
  import numpy as np
 
6
  import json
7
- from torchvision import models
8
- import librosa
9
 
 
 
 
10
  class BirdCallRNN(nn.Module):
11
  def __init__(self, resnet, num_classes):
12
- super(BirdCallRNN, self).__init__()
13
  self.resnet = resnet
14
- num_features = self.resnet.fc.in_features
15
- self.resnet.fc = nn.Identity()
16
- self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
17
- self.fc = nn.Linear(512, num_classes)
18
 
19
  def forward(self, x):
 
20
  batch, seq_len, C, H, W = x.size()
21
- x = x.view(batch * seq_len, C, H, W)
22
- features = self.resnet(x)
23
- features = features.view(batch, seq_len, -1)
24
- rnn_out, _ = self.rnn(features)
25
- output = self.fc(rnn_out[:, -1, :])
26
  return output
27
 
28
- def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
29
- y, sr = librosa.load(mp3_file, sr=None)
30
- S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
31
- log_S = librosa.power_to_db(S, ref=np.max)
32
- current_time_steps = log_S.shape[1]
33
- target_time_steps = target_shape[1]
34
- if current_time_steps < target_time_steps:
35
- pad_width = target_time_steps - current_time_steps
36
- log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
37
- elif current_time_steps > target_time_steps:
38
- log_S_resized = log_S[:, :target_time_steps]
39
- else:
40
- log_S_resized = log_S
41
- log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
42
- return log_S_resized
43
 
44
- def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
45
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  y, sr = librosa.load(mp3_file, sr=None)
47
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
48
  log_S = librosa.power_to_db(S, ref=np.max)
 
 
 
49
  num_segments = log_S.shape[1] // segment_length
50
  if num_segments == 0:
51
  segments = [log_S]
52
  else:
53
  segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
 
54
  segment_tensors = []
55
  for seg in segments:
 
56
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
57
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
58
- seg_tensor = torch.Tensor(seg_rgb).permute(2, 0, 1).float()
59
  segment_tensors.append(seg_tensor)
 
 
60
  sequence = torch.stack(segment_tensors, dim=0).unsqueeze(0).to(device)
61
- output = model(sequence)
62
- pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
63
- with open('class_names.json', 'r') as f:
64
- class_names = json.load(f)
65
- predicted_bird = class_names[pred]
66
- return predicted_bird
67
 
68
- # Load model and set up
69
- resnet = models.resnet50(weights='IMAGENET1K_V2')
70
- with open('class_names.json', 'r') as f:
71
- class_names = json.load(f)
72
- num_classes = len(class_names)
73
- model = BirdCallRNN(resnet, num_classes)
74
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
- model.to(device)
76
- model.load_state_dict(torch.load('birdcall_model.pth', map_location=device))
77
- model.eval()
78
 
79
- def predict_bird(file_path):
80
- return infer_birdcall(model, file_path, segment_length=500, device=str(device))
 
 
 
 
 
 
 
 
81
 
82
- interface = gr.Interface(fn=predict_bird, inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']), outputs=gr.Textbox(label="Predicted Bird Species"))
83
- interface.launch()
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import librosa
5
  import numpy as np
6
+ import cv2
7
  import json
8
+ import gradio as gr
 
9
 
10
# --------------------------
# Define the Model Architecture
# --------------------------
class BirdCallRNN(nn.Module):
    """CNN+RNN bird-call classifier.

    A ResNet backbone encodes each spectrogram segment of a sequence; a
    bidirectional LSTM aggregates the per-segment features, and the final
    time step is classified into one of ``num_classes`` species.
    """

    def __init__(self, resnet, num_classes):
        super(BirdCallRNN, self).__init__()
        self.resnet = resnet
        # Determine the backbone's feature dimension. The previous code read
        # resnet.fc.in_features unconditionally, which raises AttributeError
        # when the caller has already replaced fc with nn.Identity() (as the
        # module-level setup in this file does). Handle both orderings.
        head = getattr(resnet, "fc", None)
        if head is not None and hasattr(head, "in_features"):
            num_features = head.in_features
            # Strip the classification head so the backbone emits features,
            # not class logits (matches the LSTM input size chosen below).
            self.resnet.fc = nn.Identity()
        else:
            # Head already stripped: probe the backbone with a dummy input.
            with torch.no_grad():
                num_features = self.resnet(torch.zeros(1, 3, 224, 224)).shape[-1]
        # RNN expects input of shape (batch, seq_len, feature_dim)
        self.rnn = nn.LSTM(input_size=num_features, hidden_size=256,
                           num_layers=2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(512, num_classes)  # 512 = 2 * hidden_size (bidirectional)

    def forward(self, x):
        # x shape: (batch, seq_len, 3, 224, 224)
        batch, seq_len, C, H, W = x.size()
        x = x.view(batch * seq_len, C, H, W)          # merge batch and time
        features = self.resnet(x)                     # (batch * seq_len, feature_dim)
        features = features.view(batch, seq_len, -1)  # (batch, seq_len, feature_dim)
        rnn_out, _ = self.rnn(features)               # (batch, seq_len, 512)
        output = self.fc(rnn_out[:, -1, :])           # classify the last time step
        return output
30
 
31
# --------------------------
# Load Model Weights and Class Mapping
# --------------------------
# Load class mapping from JSON file (index -> class name)
with open("class_mapping.json", "r") as f:
    class_mapping = json.load(f)
num_classes = len(class_mapping)

# Load a pre-trained ResNet50 backbone. NOTE: the classification head must
# NOT be replaced before BirdCallRNN is constructed — BirdCallRNN.__init__
# reads resnet.fc.in_features to size its LSTM, and nn.Identity has no
# in_features attribute (the previous ordering raised AttributeError here).
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Initialize the BirdCallRNN model, then strip the backbone's classification
# head so it emits 2048-d features for the LSTM instead of 1000-class logits.
# (Stripping after construction is safe, and a no-op if __init__ already did.)
model = BirdCallRNN(resnet, num_classes)
model.resnet.fc = nn.Identity()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.eval()
49
+
50
# --------------------------
# Inference Function
# --------------------------
def predict_bird(mp3_file):
    """Predict the bird species present in an uploaded audio file.

    Parameters
    ----------
    mp3_file : str or None
        Filesystem path to the audio file (Gradio supplies a temp-file path),
        or None when the user submits without uploading anything.

    Returns
    -------
    str
        The predicted species name, "Unknown" if the predicted index is not
        in the class mapping, or an explanatory message when no file was given.
    """
    # Guard: Gradio passes None when submit is pressed with no file uploaded.
    if mp3_file is None:
        return "Please upload an audio file."

    # Compute a log-mel spectrogram of the full recording.
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)

    # Split the spectrogram into fixed-length segments (trailing remainder is
    # dropped); a clip shorter than one segment is used whole.
    segment_length = 500
    num_segments = log_S.shape[1] // segment_length
    if num_segments == 0:
        segments = [log_S]
    else:
        segments = [log_S[:, i * segment_length:(i + 1) * segment_length]
                    for i in range(num_segments)]

    # Resize each segment to 224x224 and replicate the single channel to the
    # 3 channels the ResNet backbone expects.
    segment_tensors = []
    for seg in segments:
        seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
        seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
        seg_tensor = torch.tensor(seg_rgb, dtype=torch.float32).permute(2, 0, 1)  # (3, 224, 224)
        segment_tensors.append(seg_tensor)

    # Stack segments into a sequence: (1, seq_len, 3, 224, 224)
    sequence = torch.stack(segment_tensors, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(sequence)
    pred = int(torch.argmax(output, dim=1).item())

    # Look up the predicted class name (JSON object keys are strings).
    return class_mapping.get(str(pred), "Unknown")
 
 
 
 
 
 
 
87
 
88
# --------------------------
# Create Gradio Interface
# --------------------------
# NOTE(review): gr.Audio(source="upload") is Gradio 3.x syntax; Gradio 4+
# renamed the parameter to sources=["upload"]. Confirm the pinned gradio
# version before upgrading this call.
iface = gr.Interface(
    fn=predict_bird,
    # type="filepath" hands predict_bird a temp-file path, not raw audio data
    inputs=gr.Audio(source="upload", type="filepath"),
    outputs="text",
    title="BirdCall Classification",
    description="Upload an MP3 file of a bird call to classify the bird species."
)

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()