fevot commited on
Commit
cbdd927
·
verified ·
1 Parent(s): 41339fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -67
app.py CHANGED
@@ -6,11 +6,8 @@ import numpy as np
6
  import json
7
  from torchvision import models
8
  import librosa
9
- import matplotlib.pyplot as plt
10
- from io import BytesIO
11
- import PIL.Image
12
 
13
- # Define the BirdCallRNN model class
14
  class BirdCallRNN(nn.Module):
15
  def __init__(self, resnet, num_features, num_classes):
16
  super(BirdCallRNN, self).__init__()
@@ -24,87 +21,77 @@ class BirdCallRNN(nn.Module):
24
  features = self.resnet(x)
25
  features = features.view(batch, seq_len, -1)
26
  rnn_out, _ = self.rnn(features)
27
- output = self.fc(rnn_out[:, -1, :])
28
  return output
29
 
30
- # Function to plot mel spectrogram
31
- def plot_spectrogram(log_S, sr):
32
- fig, ax = plt.subplots(figsize=(10, 4))
33
- img = librosa.display.specshow(log_S, sr=sr, x_axis='time', y_axis='mel', ax=ax)
34
- fig.colorbar(img, ax=ax, format='%+2.0f dB')
35
- ax.set_title('Mel Spectrogram')
36
- buf = BytesIO()
37
- plt.savefig(buf, format='png')
38
- buf.seek(0)
39
- img = PIL.Image.open(buf)
40
- plt.close(fig)
41
- return img
 
 
 
 
42
 
43
- # Load class mapping
44
  with open('class_mapping.json', 'r') as f:
45
  class_names = json.load(f)
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Initialize the model
48
  resnet = models.resnet50(weights='IMAGENET1K_V2')
49
  num_features = resnet.fc.in_features
50
  resnet.fc = nn.Identity()
51
- num_classes = len(class_names)
52
  model = BirdCallRNN(resnet, num_features, num_classes)
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  model.to(device)
55
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
56
  model.eval()
57
 
58
- # Prediction function
59
- def predict_bird(audio):
60
- # Load audio file
61
- y, sr = librosa.load(audio, sr=None)
62
- S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
63
- log_S = librosa.power_to_db(S, ref=np.max)
64
-
65
- # Generate spectrogram image
66
- spectrogram_img = plot_spectrogram(log_S, sr)
67
-
68
- # Segment audio and predict
69
- predictions = []
70
- segment_length = 500
71
- num_segments = log_S.shape[1] // segment_length
72
- segments = [log_S] if num_segments == 0 else [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
73
- for seg in segments:
74
- seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
75
- seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
76
- seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
77
- output = model(seg_tensor)
78
- probs = torch.softmax(output, dim=1)
79
- confidence, pred = torch.max(probs, dim=1)
80
- pred = pred.cpu().numpy()[0]
81
- confidence = confidence.cpu().numpy()[0]
82
- predicted_bird = class_names[str(pred)]
83
- predictions.append((predicted_bird, confidence))
84
-
85
- # Format predictions as HTML
86
- predictions_html = "<ol>"
87
- for i, (bird, conf) in enumerate(predictions, 1):
88
- predictions_html += f"<li>{bird} (Confidence: {conf*100:.1f}%)</li>"
89
- predictions_html += "</ol>"
90
-
91
- return spectrogram_img, predictions_html
92
 
93
- # Gradio interface
94
  interface = gr.Interface(
95
  fn=predict_bird,
96
- inputs=gr.Audio(label="Upload MP3 file", type="filepath"),
97
- outputs=[
98
- gr.Image(label="Mel Spectrogram"),
99
- gr.HTML(label="Predicted Bird Species")
100
- ],
101
- description="""
102
- <h3>Bird Species</h3>
103
- <img src='1.jpeg' width='300'>
104
- <h3>Bird Description</h3>
105
- <img src='2.jpeg' width='300'>
106
- <h3>Bird Origins</h3>
107
- <img src='3.jpeg' width='300'>
108
- """
109
  )
110
  interface.launch()
 
6
  import json
7
  from torchvision import models
8
  import librosa
 
 
 
9
 
10
+ # Define the BirdCallRNN model
11
  class BirdCallRNN(nn.Module):
12
  def __init__(self, resnet, num_features, num_classes):
13
  super(BirdCallRNN, self).__init__()
 
21
  features = self.resnet(x)
22
  features = features.view(batch, seq_len, -1)
23
  rnn_out, _ = self.rnn(features)
24
+ output = self.fc(rnn_out[:, -1, :]) # Note: We’ll use this for single-segment sequences
25
  return output
26
 
27
+ # Function to convert MP3 to mel spectrogram (unchanged)
28
+ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
29
+ y, sr = librosa.load(mp3_file, sr=None)
30
+ S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
31
+ log_S = librosa.power_to_db(S, ref=np.max)
32
+ current_time_steps = log_S.shape[1]
33
+ target_time_steps = target_shape[1]
34
+ if current_time_steps < target_time_steps:
35
+ pad_width = target_time_steps - current_time_steps
36
+ log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
37
+ elif current_time_steps > target_time_steps:
38
+ log_S_resized = log_S[:, :target_time_steps]
39
+ else:
40
+ log_S_resized = log_S
41
+ log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
42
+ return log_S_resized
43
 
44
+ # Load class mapping globally
45
  with open('class_mapping.json', 'r') as f:
46
  class_names = json.load(f)
47
 
48
+ # Revised inference function to predict per segment
49
+ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
50
+ model.eval()
51
+ # Load audio and compute mel spectrogram
52
+ y, sr = librosa.load(mp3_file, sr=None)
53
+ S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
54
+ log_S = librosa.power_to_db(S, ref=np.max)
55
+ # Segment the spectrogram
56
+ num_segments = log_S.shape[1] // segment_length
57
+ if num_segments == 0:
58
+ segments = [log_S]
59
+ else:
60
+ segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
61
+
62
+ predictions = []
63
+ # Process each segment individually
64
+ for seg in segments:
65
+ seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
66
+ seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
67
+ # Create a tensor with batch size 1 and sequence length 1
68
+ seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
69
+ output = model(seg_tensor)
70
+ pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
71
+ predicted_bird = class_names[str(pred)] # Convert pred to string to match JSON keys
72
+ predictions.append(predicted_bird)
73
+ return predictions
74
+
75
  # Initialize the model
76
  resnet = models.resnet50(weights='IMAGENET1K_V2')
77
  num_features = resnet.fc.in_features
78
  resnet.fc = nn.Identity()
79
+ num_classes = len(class_names) # Should be 114
80
  model = BirdCallRNN(resnet, num_features, num_classes)
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
  model.to(device)
83
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
84
  model.eval()
85
 
86
+ # Prediction function for Gradio
87
+ def predict_bird(file_path):
88
+ predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
89
+ return ", ".join(predictions) # Join predictions into a single string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # Launch Gradio interface
92
  interface = gr.Interface(
93
  fn=predict_bird,
94
+ inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']),
95
+ outputs=gr.Textbox(label="Predicted Bird Species")
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
  interface.launch()