Calotriton committed on
Commit
38503c5
Β·
verified Β·
1 Parent(s): 7143a96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -49
app.py CHANGED
@@ -53,7 +53,7 @@ class EfficientNetSE(nn.Module):
53
  return self.classifier(x)
54
 
55
  # ----------------------------
56
- # 3) Audio preprocessing
57
  # ----------------------------
58
  def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
59
  y, _ = librosa.load(path, sr=sr)
@@ -63,12 +63,11 @@ def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
63
  return y * scalar
64
 
65
  def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
66
- nyq = 0.5*sr
67
- b,a = sps.butter(order, [low/nyq, high/nyq], btype='band')
68
- return sps.filtfilt(b,a,y)
69
 
70
  def segment(y, sr=SR, win=DURATION, hop=1.0):
71
- w = int(win*sr); h = int(hop*sr)
72
  if len(y) < w:
73
  y = np.pad(y, (0, w - len(y)))
74
  return [y]
@@ -83,8 +82,8 @@ def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX):
83
 
84
  def predict_segments(fp):
85
  y = load_and_normalize(fp)
86
- y = bandpass(y)
87
- segs = segment(y)
88
  all_p = []
89
  with torch.no_grad():
90
  for seg in segs:
@@ -100,18 +99,14 @@ def predict_segments(fp):
100
  with open(DATA_PKL, "rb") as f:
101
  data = pickle.load(f)
102
  classes = data["classes"]
103
- orig_thresholds = np.array(data["thresholds"])
104
  adj_thresholds = np.array(data["adj_thresholds"])
105
 
106
- # Rebuild encoder
107
  le = LabelEncoder()
108
  le.classes_ = np.array(classes, dtype=object)
109
 
110
- # Calibrators
111
  with open(CAL_PATH, "rb") as f:
112
  calibrators = pickle.load(f)
113
 
114
- # Load backbone & model
115
  backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
116
  backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
117
  model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
@@ -122,62 +117,48 @@ model.eval()
122
  # 5) Inference logic
123
  # ----------------------------
124
  def infer(audio_path, sensitivity):
125
- # segments β†’ probabilities
126
  seg_probs = predict_segments(audio_path)
127
  agg = np.percentile(seg_probs, 90, axis=0)
128
- # calibrate
129
  calibrated = np.array([
130
  calibrators[i].transform([agg[i]])[0]
131
  for i in range(len(le.classes_))
132
  ])
133
- # adjust thresholds
134
  thresholds = adj_thresholds * sensitivity
135
  preds = calibrated > thresholds
136
 
137
- # build results
138
- results = [(le.classes_[i].replace("_"," "), round(float(calibrated[i]),3))
139
- for i, flag in enumerate(preds) if flag]
140
- if not results:
141
  return "πŸ” **No species confidently detected.**\nTry reducing the strictness."
142
 
143
- # sort and format Markdown with italics species names
144
- results.sort(key=lambda x: -x[1])
145
- md = "### βœ… Detected species:\n"
146
- for sp, p in results:
147
- md += f"- *{sp}* β€” probability: {p}\n"
148
- return md
149
 
150
  # ----------------------------
151
- # 6) Gradio Blocks interface
152
  # ----------------------------
153
- with gr.Blocks() as demo:
154
- gr.Markdown("# 🐸 RibbID – Amphibian species identifier\n")
155
- # Intro sentence about native species
 
 
 
 
 
156
  gr.Markdown(
157
- "This CNN model detects the native frog and toad species of **Catalonia** (Nort-East Spain) through ther calls."
158
  )
159
- gr.Markdown(
160
- "To start, **upload** an audio file or record a new one. Next, **select** the detection strictness in the slider, and click **submit**. Results might take time.\n"
161
- "\n"
162
- "**Detection strictness** controls how conservative the model is:\n"
163
- "- **Lower values (0.5)** = more sensitive (may include false positives).\n"
164
- "- **Higher values (1.0)** = only very confident detections (may ignore true positives)."
165
- )
166
-
167
  with gr.Row():
168
- audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3) or record live")
169
  slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
170
  label="Detection strictness")
171
-
172
- output = gr.Markdown()
173
-
174
- btn = gr.Button("Submit")
175
- btn.click(
176
- fn=infer,
177
- inputs=[audio, slider],
178
- outputs=[output],
179
- show_progress=True
180
- )
181
 
182
  if __name__ == "__main__":
183
- demo.launch(share=False)
 
53
  return self.classifier(x)
54
 
55
  # ----------------------------
56
+ # 3) Audio preprocessing functions
57
  # ----------------------------
58
  def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
59
  y, _ = librosa.load(path, sr=sr)
 
63
  return y * scalar
64
 
65
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Band-pass filter an audio signal with a Butterworth filter.

    Parameters
    ----------
    y : array-like of audio samples.
    sr : sample rate in Hz (defaults to the module constant SR).
    low, high : pass-band edge frequencies in Hz (defaults FMIN/FMAX).
    order : filter order.

    Returns the filtered signal, same length as ``y``.
    """
    # Second-order sections (output='sos') are numerically stable at
    # higher orders, unlike the (b, a) transfer-function form.
    sos = sps.butter(order, [low, high], btype='band', fs=sr, output='sos')
    # Use sosfiltfilt (zero-phase, forward-backward) rather than sosfilt:
    # the pre-refactor implementation used filtfilt, so plain sosfilt
    # silently introduced group delay / phase shift into the segments
    # fed to the classifier.
    return sps.sosfiltfilt(sos, y)
 
68
 
69
  def segment(y, sr=SR, win=DURATION, hop=1.0):
70
+ w, h = int(win*sr), int(hop*sr)
71
  if len(y) < w:
72
  y = np.pad(y, (0, w - len(y)))
73
  return [y]
 
82
 
83
  def predict_segments(fp):
84
  y = load_and_normalize(fp)
85
+ y = bandpass(y, sr=SR)
86
+ segs = segment(y, sr=SR)
87
  all_p = []
88
  with torch.no_grad():
89
  for seg in segs:
 
99
  with open(DATA_PKL, "rb") as f:
100
  data = pickle.load(f)
101
  classes = data["classes"]
 
102
  adj_thresholds = np.array(data["adj_thresholds"])
103
 
 
104
  le = LabelEncoder()
105
  le.classes_ = np.array(classes, dtype=object)
106
 
 
107
  with open(CAL_PATH, "rb") as f:
108
  calibrators = pickle.load(f)
109
 
 
110
  backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
111
  backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
112
  model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
 
117
  # 5) Inference logic
118
  # ----------------------------
119
  def infer(audio_path, sensitivity):
 
120
  seg_probs = predict_segments(audio_path)
121
  agg = np.percentile(seg_probs, 90, axis=0)
 
122
  calibrated = np.array([
123
  calibrators[i].transform([agg[i]])[0]
124
  for i in range(len(le.classes_))
125
  ])
 
126
  thresholds = adj_thresholds * sensitivity
127
  preds = calibrated > thresholds
128
 
129
+ detected = [(le.classes_[i].replace("_"," "), round(float(calibrated[i]),3))
130
+ for i, flag in enumerate(preds) if flag]
131
+ if not detected:
 
132
  return "πŸ” **No species confidently detected.**\nTry reducing the strictness."
133
 
134
+ detected.sort(key=lambda x: -x[1])
135
+ md = "<h3 style='color:#2b7a78;'>βœ… Detected Species</h3><ul>"
136
+ for sp, p in detected:
137
+ md += f"<li><em>{sp}</em> β€” probability: <strong>{p}</strong></li>"
138
+ md += "</ul>"
139
+ return gr.HTML(md)
140
 
141
  # ----------------------------
142
+ # 6) Gradio Blocks UI
143
  # ----------------------------
144
# Custom CSS for the demo: light-blue page background, Helvetica
# headings, and teal buttons matching the detected-species output.
custom_css = '''
body { background-color: #f0f8ff; }
h1, h3 { font-family: 'Helvetica Neue', sans-serif; }
.gr-button { background-color: #2b7a78 !important; color: white !important; }
'''

# Gradio Blocks UI: title + blurb, audio input and strictness slider,
# HTML output area, and a submit button wired to infer().
with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h1 style='text-align:center; color:#17252a;'>🐸 RibbID</h1>")
    gr.HTML("<p style='text-align:center;'>Detects native frog and toad species of Catalonia from audio calls.</p>")
    gr.Markdown(
        "**Strictness** controls detection sensitivity. Lower=more sensitive, higher=more conservative."
    )
    # NOTE(review): source was recovered from a diff with indentation
    # stripped; audio + slider share the Row (as in the prior layout),
    # with output and submit at Blocks level — confirm against the app.
    with gr.Row():
        audio = gr.Audio(type="filepath", label="Upload or record audio (.wav/.mp3)")
        slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
                           label="Detection strictness")
    output = gr.HTML()
    submit = gr.Button("πŸ” Identify Species")
    submit.click(fn=infer, inputs=[audio, slider], outputs=[output], show_progress=True)
 
 
 
 
 
 
 
162
 
163
if __name__ == "__main__":
    # Launch locally only; share=False avoids creating a public link.
    demo.launch(share=False)