fevot commited on
Commit
a9eca6f
·
verified ·
1 Parent(s): 41701b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -73
app.py CHANGED
@@ -1,102 +1,82 @@
 
1
  import torch
2
- import torch.nn as nn
3
- import torchvision.models as models
4
- import librosa
5
- import numpy as np
6
  import cv2
 
7
  import json
8
- import gradio as gr
 
9
 
10
- # --------------------------
11
- # Define the Model Architecture
12
- # --------------------------
13
  class BirdCallRNN(nn.Module):
14
- def __init__(self, resnet, num_classes, num_features):
15
- super(BirdCallRNN, self).__init__()
16
  self.resnet = resnet
17
- # RNN expects input of shape (batch, seq_len, feature_dim)
18
- self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2,
19
- batch_first=True, bidirectional=True)
20
- self.fc = nn.Linear(512, num_classes) # 512 = 2 * hidden_size (bidirectional)
21
 
22
  def forward(self, x):
23
- # x shape: (batch, seq_len, 3, 224, 224)
24
  batch, seq_len, C, H, W = x.size()
25
- x = x.view(batch * seq_len, C, H, W) # (batch * seq_len, 3, 224, 224)
26
- features = self.resnet(x) # (batch * seq_len, feature_dim)
27
- features = features.view(batch, seq_len, -1) # (batch, seq_len, feature_dim)
28
- rnn_out, _ = self.rnn(features) # (batch, seq_len, 512)
29
- output = self.fc(rnn_out[:, -1, :]) # Use last time step for classification
30
  return output
31
 
32
- # --------------------------
33
- # Load Model Weights and Class Mapping
34
- # --------------------------
35
- # Load class mapping from JSON file (index -> class name)
36
- with open("class_mapping.json", "r") as f:
37
- class_mapping = json.load(f)
38
- num_classes = len(class_mapping)
39
-
40
- # Load pre-trained ResNet50 and capture the in_features attribute before modification
41
- resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
42
- num_features = resnet.fc.in_features # Capture in_features before replacing fc
43
- resnet.fc = nn.Identity() # Remove the classification head
44
-
45
- # Initialize the BirdCallRNN model and load trained weights
46
- model = BirdCallRNN(resnet, num_classes, num_features)
47
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
- model.to(device)
49
- model.load_state_dict(torch.load("model_weights.pth", map_location=device))
50
- model.eval()
51
-
52
- # --------------------------
53
- # Inference Function
54
- # --------------------------
55
- def predict_bird(mp3_file):
56
- """
57
- Given an uploaded MP3 file, process it and predict the bird species.
58
- """
59
- # Load the audio file (Gradio provides a temporary file path)
60
  y, sr = librosa.load(mp3_file, sr=None)
61
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
62
  log_S = librosa.power_to_db(S, ref=np.max)
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # Define segment length and segment the spectrogram
65
- segment_length = 500
 
 
 
66
  num_segments = log_S.shape[1] // segment_length
67
  if num_segments == 0:
68
  segments = [log_S]
69
  else:
70
  segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
71
-
72
  segment_tensors = []
73
  for seg in segments:
74
- # Resize each segment to 224x224 and replicate the single channel to 3 channels
75
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
76
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
77
- seg_tensor = torch.tensor(seg_rgb, dtype=torch.float32).permute(2, 0, 1) # (3, 224, 224)
78
  segment_tensors.append(seg_tensor)
79
-
80
- # Stack segments to form a sequence: (1, seq_len, 3, 224, 224)
81
  sequence = torch.stack(segment_tensors, dim=0).unsqueeze(0).to(device)
82
- with torch.no_grad():
83
- output = model(sequence)
84
- pred = torch.argmax(output, dim=1).cpu().numpy()[0]
85
-
86
- # Look up the predicted class name
87
- predicted_bird = class_mapping.get(str(pred), "Unknown")
88
  return predicted_bird
89
 
90
- # --------------------------
91
- # Create Gradio Interface
92
- # --------------------------
93
- iface = gr.Interface(
94
- fn=predict_bird,
95
- inputs=gr.Audio(source="upload", type="filepath"),
96
- outputs="text",
97
- title="BirdCall Classification",
98
- description="Upload an MP3 file of a bird call to classify the bird species."
99
- )
 
 
 
 
100
 
101
- if __name__ == "__main__":
102
- iface.launch()
 
1
+ import gradio as gr
2
  import torch
3
+ from torch import nn
 
 
 
4
  import cv2
5
+ import numpy as np
6
  import json
7
+ from torchvision import models
8
+ import librosa
9
 
 
 
 
10
  class BirdCallRNN(nn.Module):
11
+ def __init__(self, resnet, num_features, num_classes):
12
+ super(BirdCallRNN, self).__init__()
13
  self.resnet = resnet
14
+ self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
15
+ self.fc = nn.Linear(512, num_classes)
 
 
16
 
17
  def forward(self, x):
 
18
  batch, seq_len, C, H, W = x.size()
19
+ x = x.view(batch * seq_len, C, H, W)
20
+ features = self.resnet(x)
21
+ features = features.view(batch, seq_len, -1)
22
+ rnn_out, _ = self.rnn(features)
23
+ output = self.fc(rnn_out[:, -1, :])
24
  return output
25
 
26
+ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  y, sr = librosa.load(mp3_file, sr=None)
28
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
29
  log_S = librosa.power_to_db(S, ref=np.max)
30
+ current_time_steps = log_S.shape[1]
31
+ target_time_steps = target_shape[1]
32
+ if current_time_steps < target_time_steps:
33
+ pad_width = target_time_steps - current_time_steps
34
+ log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
35
+ elif current_time_steps > target_time_steps:
36
+ log_S_resized = log_S[:, :target_time_steps]
37
+ else:
38
+ log_S_resized = log_S
39
+ log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
40
+ return log_S_resized
41
 
42
+ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
43
+ model.eval()
44
+ y, sr = librosa.load(mp3_file, sr=None)
45
+ S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
46
+ log_S = librosa.power_to_db(S, ref=np.max)
47
  num_segments = log_S.shape[1] // segment_length
48
  if num_segments == 0:
49
  segments = [log_S]
50
  else:
51
  segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
 
52
  segment_tensors = []
53
  for seg in segments:
 
54
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
55
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
56
+ seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float()
57
  segment_tensors.append(seg_tensor)
 
 
58
  sequence = torch.stack(segment_tensors, dim=0).unsqueeze(0).to(device)
59
+ output = model(sequence)
60
+ pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
61
+ with open('class_names.json', 'r') as f:
62
+ class_names = json.load(f)
63
+ predicted_bird = class_names[pred]
 
64
  return predicted_bird
65
 
66
+ resnet = models.resnet50(weights='IMAGENET1K_V2')
67
+ num_features = resnet.fc.in_features
68
+ resnet.fc = nn.Identity()
69
+ with open('class_names.json', 'r') as f:
70
+ class_names = json.load(f)
71
+ num_classes = len(class_names)
72
+ model = BirdCallRNN(resnet, num_features, num_classes)
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ model.to(device)
75
+ model.load_state_dict(torch.load('birdcall_model.pth', map_location=device))
76
+ model.eval()
77
+
78
+ def predict_bird(file_path):
79
+ return infer_birdcall(model, file_path, segment_length=500, device=str(device))
80
 
81
+ interface = gr.Interface(fn=predict_bird, inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']), outputs=gr.Textbox(label="Predicted Bird Species"))
82
+ interface.launch()