Calotriton committed on
Commit
08175df
·
verified ·
1 Parent(s): 3d23449

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -76
app.py CHANGED
@@ -1,98 +1,180 @@
1
- import gradio as gr
2
- import torch
3
- import torchaudio
4
- import numpy as np
5
- import pickle
6
  import json
7
- from sklearn.isotonic import IsotonicRegression
8
- from model import EfficientNetSE, load_and_normalize, bandpass, segment, extract_log_mel
9
-
10
- # -------------------- Load resources --------------------
11
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
-
13
- # Model
14
- model = EfficientNetSE()
15
- model.load_state_dict(torch.load("cnn_final.pth", map_location=DEVICE))
16
- model.eval()
17
- model.to(DEVICE)
18
-
19
- # Label encoder, thresholds, calibrators
20
- with open("label_encoder_and_thresholds.pkl", "rb") as f:
21
- data = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- classes = data["classes"]
24
- thresholds = data["thresholds"]
25
- adj_thresholds = data["adj_thresholds"]
26
- calibrators = data["calibrators"]
27
-
28
- # -------------------- Inference --------------------
29
- def predict(audio_path, override_max=1.0):
30
- # Load and preprocess
31
- y = load_and_normalize(audio_path)
32
- y = bandpass(y, sr=32000)
33
- segments = segment(y, sr=32000)
34
-
35
- if len(segments) == 0:
36
- return "⚠️ No usable segments found in the audio file."
37
-
38
- segment_preds = []
39
  with torch.no_grad():
40
- for seg in segments:
41
  mel = extract_log_mel(seg)
42
- inp = torch.tensor(mel[None, None], dtype=torch.float32).to(DEVICE)
43
  out = model(inp)
44
- prob = torch.sigmoid(out).cpu().numpy()[0]
45
- segment_preds.append(prob)
46
 
47
- segment_preds = np.array(segment_preds)
48
- agg = np.percentile(segment_preds, 90, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
 
 
 
 
 
 
 
 
50
  calibrated = np.array([
51
  calibrators[i].transform([agg[i]])[0]
52
- for i in range(len(classes))
53
  ])
54
-
55
- final = {}
56
- for i, sp in enumerate(classes):
57
- threshold = min(adj_thresholds[i], override_max)
58
- detected = calibrated[i] > threshold
59
- if detected:
60
- final[sp] = f"{calibrated[i]:.2f}"
61
-
62
- if not final:
63
- return "🔍 No species confidently detected."
64
-
65
- result_md = "### ✅ Detected species:\n"
66
- for sp, prob in final.items():
67
- result_md += f"- **{sp}**: {prob}\n"
68
- return result_md
69
-
70
- # -------------------- Interface --------------------
 
 
 
71
  with gr.Blocks() as demo:
72
- gr.Markdown("# 🐸 RibbID – European Frog Call Identifier")
73
  gr.Markdown(
74
- "Upload a recording of frog calls and RibbID will identify the species present.\n\n"
75
- "**Detection strictness** controls how confident the model must be to report a detection:\n"
76
- "- Lower = more sensitive (can include false positives)\n"
77
- "- Higher = more conservative (only very confident predictions shown)"
78
  )
79
 
80
  with gr.Row():
81
- audio_input = gr.Audio(type="filepath", label="Upload your audio (WAV/MP3)")
82
- slider = gr.Slider(minimum=0.5, maximum=1.0, value=1.0, step=0.01, label="Detection strictness")
 
83
 
84
- status = gr.Markdown("") # Spinner text
85
  output = gr.Markdown()
86
 
87
- def wrapped_predict(audio_path, slider_value):
88
  status.update("⏳ Processing...")
89
- result = predict(audio_path, override_max=slider_value)
90
- status.update("") # clear
91
- return result
92
 
93
- submit_btn = gr.Button("Submit")
94
- submit_btn.click(fn=wrapped_predict, inputs=[audio_input, slider], outputs=[output])
95
 
96
- # -------------------- Launch --------------------
97
  if __name__ == "__main__":
98
- demo.launch()
 
1
+ import os
 
 
 
 
2
  import json
3
+ import pickle
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import librosa
8
+ import scipy.signal as sps
9
+ import gradio as gr
10
+ from sklearn.preprocessing import LabelEncoder
11
+
12
# ----------------------------
# 1) Global parameters & paths
# ----------------------------
# All audio is resampled to this rate before filtering / feature extraction.
SR = 22050
# Analysis window length in seconds for segmentation.
DURATION = 4.0
# STFT hop length in samples for the mel spectrogram.
HOP = 512
# Frequency band (Hz) used both for the band-pass filter and the mel range;
# presumably tuned to amphibian call frequencies — TODO confirm with training setup.
FMIN, FMAX = 150, 4500
MODEL_PATH = "cnn_final.pth"                   # trained CNN checkpoint
DATA_PKL = "label_encoder_and_thresholds.pkl"  # class list + decision thresholds
CAL_PATH = "calibrators.pkl"                   # per-class probability calibrators
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ # ----------------------------
25
+ # 2) Model definition
26
+ # ----------------------------
27
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention: rescales each channel by a
    learned sigmoid gate computed from its global average.

    NOTE: the attribute name ``fc`` and the layer order inside the Sequential
    are load-bearing — they form the state_dict keys of the saved checkpoint.
    """

    def __init__(self, channels, red=16):
        super().__init__()
        squeezed = channels // red  # bottleneck width of the excitation MLP
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.fc(x)  # (N, C, 1, 1) per-channel weights in (0, 1)
        return x * gate


class EfficientNetSE(nn.Module):
    """EfficientNet feature extractor + SE attention + linear classifier head.

    The backbone must expose a ``features`` module producing 1280 channels
    (EfficientNet-B0 convention). Attribute names match the checkpoint's
    state_dict keys and must not be renamed.
    """

    def __init__(self, bbone, num_classes, drop=0.3):
        super().__init__()
        self.backbone = bbone
        self.se = SEBlock(1280)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(1280, num_classes),
        )

    def forward(self, x):
        feats = self.backbone.features(x)
        attended = self.se(feats)
        pooled = self.pool(attended).flatten(1)
        return self.classifier(pooled)  # raw logits; callers apply sigmoid
54
+
55
+ # ----------------------------
56
+ # 3) Audio preprocessing
57
+ # ----------------------------
58
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load an audio file at ``sr`` Hz, remove DC offset, and RMS-normalize
    the signal so its level matches ``target_dBFS``."""
    signal, _ = librosa.load(path, sr=sr)
    signal = signal - np.mean(signal)  # strip DC offset
    # Epsilon keeps the gain finite on (near-)silent recordings.
    rms = np.sqrt(np.mean(signal ** 2)) + 1e-9
    gain = (10 ** (target_dBFS / 20)) / rms
    return signal * gain
64
+
65
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Zero-phase Butterworth band-pass filter restricting ``y`` to
    the [low, high] Hz band (filtfilt avoids phase distortion)."""
    nyquist = 0.5 * sr
    critical = [low / nyquist, high / nyquist]
    b, a = sps.butter(order, critical, btype='band')
    return sps.filtfilt(b, a, y)
69
+
70
def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Split ``y`` into overlapping windows of ``win`` seconds taken every
    ``hop`` seconds.

    Audio shorter than one window is zero-padded and returned as a single
    segment, so the result is always non-empty.
    """
    win_len = int(win * sr)
    hop_len = int(hop * sr)
    if len(y) < win_len:
        padded = np.pad(y, (0, win_len - len(y)))
        return [padded]
    starts = range(0, len(y) - win_len + 1, hop_len)
    return [y[s:s + win_len] for s in starts]
76
+
77
def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-compressed mel spectrogram of ``y``.

    NOTE(review): despite the name this applies PCEN, not a log transform;
    the name is kept because the checkpoint/training pipeline uses it.
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels, hop_length=hop_length,
        fmin=fmin, fmax=fmax, power=1.0,
    )
    # 2**31 scaling — presumably to mimic 32-bit PCM magnitudes that PCEN's
    # default parameters assume; TODO confirm against training code.
    return librosa.pcen(mel * (2 ** 31))
83
 
84
def predict_segments(fp):
    """Run the CNN over every segment of the file at ``fp``.

    Returns an (n_segments, n_classes) array of per-segment sigmoid
    probabilities.
    """
    audio = bandpass(load_and_normalize(fp))
    probs = []
    with torch.no_grad():
        for chunk in segment(audio):
            features = extract_log_mel(chunk)
            # Add batch and channel dims: (1, 1, n_mels, frames).
            batch = torch.tensor(features[None, None], dtype=torch.float32).to(DEVICE)
            logits = model(batch)
            probs.append(torch.sigmoid(logits).cpu().numpy()[0])
    return np.vstack(probs)
96
 
97
+ # ----------------------------
98
+ # 4) Load artifacts
99
+ # ----------------------------
100
# Class list and per-class decision thresholds.
# NOTE(review): pickle.load runs arbitrary code — safe only because these are
# files shipped with the app, never user input.
with open(DATA_PKL, "rb") as f:
    data = pickle.load(f)
classes = data["classes"]
# Kept alongside adj_thresholds; not referenced by this app's inference path.
orig_thresholds = np.array(data["thresholds"])
adj_thresholds = np.array(data["adj_thresholds"])

# Rebuild encoder (only classes_ is needed for index -> species-name mapping)
le = LabelEncoder()
le.classes_ = np.array(classes, dtype=object)

# Calibrators (one per class; map raw sigmoid scores to calibrated probs)
with open(CAL_PATH, "rb") as f:
    calibrators = pickle.load(f)

# Load backbone & model
# NOTE(review): pretrained=True downloads ImageNet weights that are then fully
# replaced by load_state_dict below (strict=True); also the `pretrained` kwarg
# is deprecated in newer torchvision — consider weights=None.
backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
# Swap the stem conv to accept 1-channel (mono spectrogram) input.
backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()
120
 
121
+ # ----------------------------
122
+ # 5) Inference logic
123
+ # ----------------------------
124
def infer(audio_path, sensitivity):
    """Identify species present in an uploaded recording.

    Parameters
    ----------
    audio_path : str or None
        Filesystem path supplied by Gradio; ``None`` when the user submits
        without uploading a file.
    sensitivity : float
        Strictness multiplier (slider range 0.5–1.0) applied to the
        per-class adjusted thresholds.

    Returns
    -------
    str
        Markdown-formatted detection report.
    """
    # BUG FIX: Gradio passes None when no file was uploaded; previously this
    # crashed inside librosa with an opaque traceback.
    if audio_path is None:
        return "⚠️ Please upload an audio file first."

    # Per-segment probabilities, aggregated with the 90th percentile so a
    # species only needs to be clear in a few segments to score high.
    seg_probs = predict_segments(audio_path)
    agg = np.percentile(seg_probs, 90, axis=0)

    # Map raw aggregated scores through the per-class calibrators.
    calibrated = np.array([
        calibrators[i].transform([agg[i]])[0]
        for i in range(len(le.classes_))
    ])

    # Scale thresholds by the user-chosen strictness and flag detections.
    thresholds = adj_thresholds * sensitivity
    preds = calibrated > thresholds

    results = [(le.classes_[i].replace("_", " "), round(float(calibrated[i]), 3))
               for i, flag in enumerate(preds) if flag]
    if not results:
        return "🔍 **No species confidently detected.**\nTry reducing the strictness."

    # Most confident species first.
    results.sort(key=lambda x: -x[1])
    md = "### Detected species:\n"
    for sp, p in results:
        md += f"- **{sp}** — probability: {p}\n"
    return md
149
+
150
+ # ----------------------------
151
+ # 6) Gradio Blocks interface
152
+ # ----------------------------
153
# ----------------------------
# 6) Gradio Blocks interface
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🐸 RibbID – Amphibian Call Identifier\n")
    gr.Markdown(
        "**Detection strictness** controls how conservative the model is:\n\n"
        "- Lower values (0.5) = more sensitive (may include false positives).\n"
        "- Higher values (1.0) = only very confident detections."
    )

    with gr.Row():
        audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3)")
        slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
                           label="Detection strictness")

    status = gr.Markdown("")  # transient "processing" indicator
    output = gr.Markdown()

    # BUG FIX: calling component.update(...) imperatively inside a handler has
    # no effect in Gradio — updates must be returned (or yielded) as outputs.
    # A generator handler lets us show the spinner first, then swap in the
    # result, with `status` wired in as an output component.
    def wrapped(audio_path, strictness):
        yield "⏳ Processing...", ""          # show spinner, clear old result
        res = infer(audio_path, strictness)
        yield "", res                         # clear spinner, show result

    btn = gr.Button("Submit")
    btn.click(fn=wrapped, inputs=[audio, slider], outputs=[status, output])

# launch without share link
if __name__ == "__main__":
    demo.launch(share=False)