import os import json import pickle import numpy as np import torch import torch.nn as nn import librosa import scipy.signal as sps import gradio as gr from sklearn.preprocessing import LabelEncoder # ---------------------------- # 1) Global parameters & paths # ---------------------------- SR = 22050 DURATION = 4.0 HOP = 512 FMIN, FMAX = 150, 4500 MODEL_PATH = "CNN_final.pth" DATA_PKL = "label_encoder_and_thresholds.pkl" CAL_PATH = "calibrators.pkl" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ---------------------------- # 2) Model definition # ---------------------------- class SEBlock(nn.Module): def __init__(self, channels, red=16): super().__init__() self.fc = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, channels//red, 1), nn.ReLU(inplace=True), nn.Conv2d(channels//red, channels, 1), nn.Sigmoid() ) def forward(self, x): return x * self.fc(x) class EfficientNetSE(nn.Module): def __init__(self, bbone, num_classes, drop=0.3): super().__init__() self.backbone = bbone self.se = SEBlock(1280) self.pool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Sequential( nn.Dropout(drop), nn.Linear(1280, num_classes) ) def forward(self, x): x = self.backbone.features(x) x = self.se(x) x = self.pool(x).flatten(1) return self.classifier(x) # ---------------------------- # 3) Audio preprocessing # ---------------------------- def load_and_normalize(path, sr=SR, target_dBFS=-20.0): y, _ = librosa.load(path, sr=sr) y = y - np.mean(y) rms = np.sqrt(np.mean(y**2)) + 1e-9 scalar = (10**(target_dBFS/20)) / rms return y * scalar def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6): nyq = 0.5*sr b,a = sps.butter(order, [low/nyq, high/nyq], btype='band') return sps.filtfilt(b,a,y) def segment(y, sr=SR, win=DURATION, hop=1.0): w = int(win*sr); h = int(hop*sr) if len(y) < w: y = np.pad(y, (0, w - len(y))) return [y] return [y[i:i+w] for i in range(0, len(y)-w+1, h)] def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX): 
    # Mel spectrogram with power=1.0 (magnitude, not power).
    # NOTE(review): despite the function name, this applies librosa.pcen
    # (per-channel energy normalization), not a plain log transform; the
    # 2**31 scaling maps the float signal into the integer-like range
    # pcen's default gain constants assume — confirm against training code.
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels, hop_length=hop_length,
        fmin=fmin, fmax=fmax, power=1.0
    )
    return librosa.pcen(mel * (2**31))


def predict_segments(fp):
    """Run the model over every windowed segment of the audio file at `fp`.

    Returns an (n_segments, n_classes) numpy array of per-class sigmoid
    scores (multi-label, each in [0, 1]).
    """
    y = load_and_normalize(fp)
    y = bandpass(y)
    segs = segment(y)
    all_p = []
    with torch.no_grad():
        for seg in segs:
            mel = extract_log_mel(seg)
            # prepend batch and channel dims: (1, 1, n_mels, frames)
            inp = torch.tensor(mel[None,None], dtype=torch.float32).to(DEVICE)
            out = model(inp)
            all_p.append(torch.sigmoid(out).cpu().numpy()[0])
    return np.vstack(all_p)

# ----------------------------
# 4) Load artifacts
# ----------------------------
# NOTE(review): pickle.load executes arbitrary code on load — these must
# only ever be the artifacts shipped with the app, never user uploads.
with open(DATA_PKL, "rb") as f:
    data = pickle.load(f)
classes = data["classes"]
orig_thresholds = np.array(data["thresholds"])
adj_thresholds = np.array(data["adj_thresholds"])

# Rebuild encoder from the stored class list (avoids pickling sklearn objects).
le = LabelEncoder()
le.classes_ = np.array(classes, dtype=object)

# Per-class calibrators; each exposes .transform() — presumably
# isotonic/Platt score calibration, verify against the training pipeline.
with open(CAL_PATH, "rb") as f:
    calibrators = pickle.load(f)

# Load backbone & model. The stem conv is replaced with a 1-channel
# version (mono spectrogram input); its ImageNet weights are discarded,
# which is fine because the full state dict is loaded right after.
backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# ----------------------------
# 5) Inference logic
# ----------------------------
def infer(audio_path, sensitivity):
    """Gradio callback: detect species present in `audio_path`.

    `sensitivity` (0.5–1.0 from the UI slider) scales the per-class
    decision thresholds; higher values make detection stricter.
    Returns a Markdown string of detections.
    """
    # segments → probabilities; aggregate per class with the 90th
    # percentile so a species only needs strong scores in some segments
    seg_probs = predict_segments(audio_path)
    agg = np.percentile(seg_probs, 90, axis=0)
    # calibrate each class score with its fitted calibrator
    calibrated = np.array([
        calibrators[i].transform([agg[i]])[0]
        for i in range(len(le.classes_))
    ])
    # adjust thresholds by the user-selected strictness
    thresholds = adj_thresholds * sensitivity
    preds = calibrated > thresholds
    # build (display name, rounded probability) pairs for flagged classes
    results = [(le.classes_[i].replace("_"," "), round(float(calibrated[i]),3))
               for i, flag in enumerate(preds) if flag]
    if not results:
        return "🔍 **No species confidently detected.**\nTry reducing the strictness."
# sort and format Markdown with italics species names results.sort(key=lambda x: -x[1]) md = "### ✅ Detected species:\n" for sp, p in results: md += f"- *{sp}* — probability: {p}\n" return md # ---------------------------- # 6) Gradio Blocks interface # ---------------------------- with gr.Blocks() as demo: gr.Markdown("# 🐸 RibbID – Amphibian species acoustic identifier\n") # Intro sentence about native species gr.Markdown( "This CNN model detects the native frog and toad species of Catalonia (Northern Spain) through ther calls." ) gr.Markdown( "To start, upload an audio file or record a new one. Next, select the detection strictness in the slider, and click submit. Results might take time.\n" "\n" "**Detection strictness** controls how conservative the model is:\n" "- Lower values (0.5) = more sensitive (may include false positives).\n" "- Higher values (1.0) = only very confident detections." ) with gr.Row(): audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3) or record live") slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05, label="Detection strictness") output = gr.Markdown() btn = gr.Button("Submit") btn.click( fn=infer, inputs=[audio, slider], outputs=[output], show_progress=True ) if __name__ == "__main__": demo.launch(share=False)