# RibbID — app.py (Hugging Face Space entry point)
import os
import json
import pickle
import numpy as np
import torch
import torch.nn as nn
import librosa
import scipy.signal as sps
import gradio as gr
from sklearn.preprocessing import LabelEncoder
# ----------------------------
# 1) Global parameters & paths
# ----------------------------
SR = 22050                 # sample rate (Hz) used for loading and all DSP
DURATION = 4.0             # analysis window length in seconds (see segment())
HOP = 512                  # mel-spectrogram hop length in samples (see extract_log_mel())
FMIN, FMAX = 150, 4500     # band-pass / mel filterbank frequency range (Hz)
MODEL_PATH = "CNN_final.pth"                    # trained CNN weights
DATA_PKL = "label_encoder_and_thresholds.pkl"   # class names + decision thresholds
CAL_PATH = "calibrators.pkl"                    # per-class probability calibrators
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------------------------
# 2) Model definition
# ----------------------------
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: rescales each channel by a learned gate.

    Args:
        channels: number of input (and output) feature channels.
        red: reduction factor for the bottleneck 1x1 convolutions.
    """

    def __init__(self, channels, red=16):
        super().__init__()
        bottleneck = channels // red
        # Squeeze (global average pool) then excite (two 1x1 convs + sigmoid).
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, bottleneck, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Gates lie in (0, 1) and are broadcast over the spatial dimensions.
        gates = self.fc(x)
        return x * gates
class EfficientNetSE(nn.Module):
    """EfficientNet backbone followed by an SE block and a linear classifier.

    Args:
        bbone: backbone module exposing a ``features`` sub-module; assumed to
            emit 1280 feature channels (EfficientNet-B0) — confirmed by the
            hard-coded 1280 below.
        num_classes: size of the output logit vector.
        drop: dropout probability applied before the final linear layer.
    """

    def __init__(self, bbone, num_classes, drop=0.3):
        super().__init__()
        self.backbone = bbone
        self.se = SEBlock(1280)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(1280, num_classes),
        )

    def forward(self, x):
        feats = self.backbone.features(x)
        feats = self.se(feats)
        pooled = self.pool(feats).flatten(1)
        return self.classifier(pooled)
# ----------------------------
# 3) Audio preprocessing
# ----------------------------
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load audio at a fixed sample rate, remove DC offset, and RMS-normalize.

    Args:
        path: path to the audio file (any format librosa can read).
        sr: target sample rate for resampling on load.
        target_dBFS: desired RMS level in dB relative to full scale.

    Returns:
        1-D float array whose RMS matches ``target_dBFS``.
    """
    samples, _ = librosa.load(path, sr=sr)
    samples = samples - samples.mean()
    # Epsilon keeps the gain finite on perfectly silent input.
    rms = np.sqrt(np.mean(samples ** 2)) + 1e-9
    gain = 10 ** (target_dBFS / 20) / rms
    return samples * gain
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Apply a zero-phase Butterworth band-pass filter.

    Args:
        y: input signal (1-D array).
        sr: sample rate in Hz.
        low: lower pass-band edge in Hz.
        high: upper pass-band edge in Hz.
        order: Butterworth filter order.

    Returns:
        Filtered signal, same length as ``y``; filtfilt runs the filter
        forward and backward so no phase shift is introduced.
    """
    nyquist = sr / 2.0
    b, a = sps.butter(order, [low / nyquist, high / nyquist], btype='band')
    return sps.filtfilt(b, a, y)
def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Slice a signal into fixed-length, overlapping analysis windows.

    Args:
        y: input signal (1-D array).
        sr: sample rate in Hz.
        win: window length in seconds.
        hop: hop between consecutive window starts in seconds.

    Returns:
        List of windows, each exactly ``int(win * sr)`` samples long. A
        signal shorter than one window is zero-padded and returned as a
        single window.
    """
    win_len = int(win * sr)
    hop_len = int(hop * sr)
    if len(y) < win_len:
        return [np.pad(y, (0, win_len - len(y)))]
    starts = range(0, len(y) - win_len + 1, hop_len)
    return [y[s:s + win_len] for s in starts]
def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-compressed mel spectrogram of a signal.

    NOTE(review): despite the name, this applies librosa's PCEN
    (per-channel energy normalization), not a plain log transform. The name
    is kept because other code in this file calls it.

    Args:
        y: input signal window.
        sr: sample rate in Hz.
        n_mels: number of mel bands.
        hop_length: hop length in samples.
        fmin: lowest filterbank frequency in Hz.
        fmax: highest filterbank frequency in Hz.

    Returns:
        2-D array of shape (n_mels, frames) with PCEN-normalized energies.
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels,
        hop_length=hop_length, fmin=fmin, fmax=fmax,
        power=1.0,  # magnitude (not power) spectrogram, as PCEN expects
    )
    # Scale into the int32-like range assumed by librosa.pcen's default
    # parameters (its docs recommend 2**31 for float input).
    return librosa.pcen(mel * (2 ** 31))
def predict_segments(fp):
    """Run the CNN on every window of an audio file.

    Args:
        fp: path to the audio file.

    Returns:
        Array of shape (num_segments, num_classes) holding per-segment
        sigmoid probabilities from the module-level ``model``.
    """
    audio = bandpass(load_and_normalize(fp))
    probs = []
    with torch.no_grad():
        for window in segment(audio):
            features = extract_log_mel(window)
            # Add batch and channel axes: (1, 1, n_mels, frames).
            batch = torch.tensor(features[None, None], dtype=torch.float32).to(DEVICE)
            logits = model(batch)
            probs.append(torch.sigmoid(logits).cpu().numpy()[0])
    return np.vstack(probs)
# ----------------------------
# 4) Load artifacts
# ----------------------------
# Label classes and per-class decision thresholds produced at training time.
# NOTE(review): pickle.load executes arbitrary code — only load artifacts
# bundled with this Space, never user-supplied files.
with open(DATA_PKL, "rb") as f:
    data = pickle.load(f)
classes = data["classes"]
orig_thresholds = np.array(data["thresholds"])      # raw training thresholds (unused below)
adj_thresholds = np.array(data["adj_thresholds"])   # deployment-adjusted thresholds used by infer()
# Rebuild encoder so class indices map back to species names.
le = LabelEncoder()
le.classes_ = np.array(classes, dtype=object)
# Calibrators: one per class, each exposing .transform() (presumably
# isotonic/Platt calibration — TODO confirm against training code).
with open(CAL_PATH, "rb") as f:
    calibrators = pickle.load(f)
# Load backbone & model: EfficientNet-B0 via torch.hub, with the stem conv
# patched from 3 RGB channels to 1 channel (mono spectrogram input).
backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()  # inference mode: disables dropout in the classifier head
# ----------------------------
# 5) Inference logic
# ----------------------------
def infer(audio_path, sensitivity):
    """Identify species in an audio file and format the result as Markdown.

    Args:
        audio_path: path to the uploaded or recorded audio file.
        sensitivity: strictness multiplier applied to the per-class
            thresholds (higher = fewer, more confident detections).

    Returns:
        Markdown string listing detected species with calibrated
        probabilities, or a "nothing detected" message.
    """
    # Aggregate per-segment probabilities with the 90th percentile, so a few
    # strong windows can carry a detection.
    seg_probs = predict_segments(audio_path)
    agg = np.percentile(seg_probs, 90, axis=0)
    # Map raw aggregated scores through the per-class calibrators.
    calibrated = np.array(
        [calibrators[idx].transform([agg[idx]])[0] for idx in range(len(le.classes_))]
    )
    # Scale thresholds by the user-chosen strictness and keep classes above.
    detected = calibrated > adj_thresholds * sensitivity
    hits = [
        (le.classes_[idx].replace("_", " "), round(float(calibrated[idx]), 3))
        for idx, flagged in enumerate(detected)
        if flagged
    ]
    if not hits:
        return "πŸ” **No species confidently detected.**\nTry reducing the strictness."
    # Highest-probability species first, rendered as an italicized list.
    hits.sort(key=lambda item: -item[1])
    body = "".join(f"- *{species}* β€” probability: {prob}\n" for species, prob in hits)
    return "### βœ… Detected species:\n" + body
# ----------------------------
# 6) Gradio Blocks interface
# ----------------------------
# Build the Gradio UI. Fixes the user-facing typo "ther calls" -> "their calls".
with gr.Blocks() as demo:
    gr.Markdown("# 🐸 RibbID – Amphibian species acoustic identifier\n")
    # Intro sentence about native species
    gr.Markdown(
        "This CNN model detects the native frog and toad species of Catalonia (Northern Spain) through their calls."
    )
    gr.Markdown(
        "To start, upload an audio file or record a new one. Next, select the detection strictness in the slider, and click submit. Results might take time.\n"
        "\n"
        "**Detection strictness** controls how conservative the model is:\n"
        "- Lower values (0.5) = more sensitive (may include false positives).\n"
        "- Higher values (1.0) = only very confident detections."
    )
    with gr.Row():
        audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3) or record live")
        slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
                           label="Detection strictness")
    output = gr.Markdown()
    btn = gr.Button("Submit")
    # Wire the button to inference; show_progress gives feedback while the
    # DSP pipeline and model forward passes run.
    btn.click(
        fn=infer,
        inputs=[audio, slider],
        outputs=[output],
        show_progress=True
    )

if __name__ == "__main__":
    demo.launch(share=False)