Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import librosa | |
| import scipy.signal as sps | |
| import gradio as gr | |
| from sklearn.preprocessing import LabelEncoder | |
# ----------------------------
# 1) Global parameters & paths
# ----------------------------
SR = 22050              # sample rate (Hz) used for loading/resampling all audio
DURATION = 4.0          # analysis window length in seconds
HOP = 512               # hop length in samples for the mel spectrogram
FMIN, FMAX = 150, 4500  # band-pass / mel frequency range (Hz)
MODEL_PATH = "CNN_final.pth"                    # fine-tuned CNN weights
DATA_PKL = "label_encoder_and_thresholds.pkl"   # class list + decision thresholds
CAL_PATH = "calibrators.pkl"                    # per-class score calibrators
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| # ---------------------------- | |
| # 2) Model definition | |
| # ---------------------------- | |
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate: rescales each channel by a learned weight in (0, 1)."""

    def __init__(self, channels, red=16):
        super().__init__()
        squeezed = channels // red
        # squeeze to 1x1 -> bottleneck -> expand back -> sigmoid gate
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        ]
        # attribute name "fc" kept so saved state-dict keys still match
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        gate = self.fc(x)
        return x * gate
class EfficientNetSE(nn.Module):
    """EfficientNet feature extractor followed by an SE gate and a linear head.

    `bbone` must expose a `.features` module producing 1280 feature maps
    (EfficientNet-B0's final channel count).
    """

    def __init__(self, bbone, num_classes, drop=0.3):
        super().__init__()
        # attribute names preserved so fine-tuned state-dict keys still match
        self.backbone = bbone
        self.se = SEBlock(1280)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(1280, num_classes),
        )

    def forward(self, x):
        feats = self.se(self.backbone.features(x))
        pooled = self.pool(feats).flatten(1)
        return self.classifier(pooled)
| # ---------------------------- | |
| # 3) Audio preprocessing | |
| # ---------------------------- | |
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load mono audio at `sr`, remove DC offset, and scale to a target RMS level."""
    signal, _ = librosa.load(path, sr=sr)
    signal = signal - np.mean(signal)
    # gain that brings the RMS level to target_dBFS; epsilon guards silent input
    rms = np.sqrt(np.mean(signal ** 2)) + 1e-9
    gain = (10 ** (target_dBFS / 20)) / rms
    return signal * gain
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Zero-phase Butterworth band-pass of `y` between `low` and `high` Hz.

    Uses second-order sections (``output='sos'`` + ``sosfiltfilt``) instead of
    transfer-function (b, a) coefficients: at order 6 the (b, a) form can be
    numerically unstable, which SciPy's `butter` documentation warns about;
    SOS filtering avoids that while keeping the same zero-phase response.
    """
    nyq = 0.5 * sr
    sos = sps.butter(order, [low / nyq, high / nyq], btype='band', output='sos')
    return sps.sosfiltfilt(sos, y)
def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Slice `y` into windows of `win` seconds taken every `hop` seconds.

    Audio shorter than one window is zero-padded and returned as a single
    segment; otherwise any tail shorter than `win` is dropped.
    """
    win_len = int(win * sr)
    step = int(hop * sr)
    if len(y) < win_len:
        return [np.pad(y, (0, win_len - len(y)))]
    starts = range(0, len(y) - win_len + 1, step)
    return [y[s:s + win_len] for s in starts]
def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Magnitude mel spectrogram compressed with PCEN.

    NOTE(review): despite the name, this applies ``librosa.pcen`` rather than
    a log transform. The ``2**31`` scaling treats the float signal as raw
    int32 amplitude, as in the librosa PCEN examples.
    """
    mel_spec = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=n_mels,
        hop_length=hop_length,
        fmin=fmin,
        fmax=fmax,
        power=1.0,  # magnitude (not power) spectrogram
    )
    return librosa.pcen(mel_spec * (2 ** 31))
def predict_segments(fp):
    """Run the model over every window of audio file `fp`.

    Returns an array of shape (n_segments, n_classes) of per-segment
    sigmoid probabilities.
    """
    audio = bandpass(load_and_normalize(fp))
    probs = []
    with torch.no_grad():
        for chunk in segment(audio):
            spec = extract_log_mel(chunk)
            # add batch and channel axes: (1, 1, n_mels, n_frames)
            batch = torch.tensor(spec[None, None], dtype=torch.float32).to(DEVICE)
            logits = model(batch)
            probs.append(torch.sigmoid(logits).cpu().numpy()[0])
    return np.vstack(probs)
# ----------------------------
# 4) Load artifacts
# ----------------------------
# Class list and per-class decision thresholds produced at training time.
# NOTE(review): pickle.load is only safe on trusted, locally produced files.
with open(DATA_PKL, "rb") as f:
    data = pickle.load(f)
classes = data["classes"]
orig_thresholds = np.array(data["thresholds"])
adj_thresholds = np.array(data["adj_thresholds"])
# Rebuild encoder from the saved class list (index -> species name)
le = LabelEncoder()
le.classes_ = np.array(classes, dtype=object)
# Calibrators: one per class; each must expose .transform() (used in infer)
with open(CAL_PATH, "rb") as f:
    calibrators = pickle.load(f)
# Load backbone & model: EfficientNet-B0 from torch.hub with its first conv
# swapped to 1 input channel (mono spectrograms) before the fine-tuned
# weights are loaded; eval() disables dropout for inference.
backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()
| # ---------------------------- | |
| # 5) Inference logic | |
| # ---------------------------- | |
def infer(audio_path, sensitivity):
    """Identify species present in `audio_path` and format the result as Markdown.

    `sensitivity` scales the per-class decision thresholds (higher = stricter).
    """
    # aggregate per-segment probabilities: 90th percentile per class
    seg_probs = predict_segments(audio_path)
    agg = np.percentile(seg_probs, 90, axis=0)
    # apply the per-class calibrators to the aggregated scores
    calibrated = np.array([
        calibrators[idx].transform([agg[idx]])[0]
        for idx in range(len(le.classes_))
    ])
    # strictness slider scales the tuned thresholds
    thresholds = adj_thresholds * sensitivity
    detected = [
        (le.classes_[idx].replace("_", " "), round(float(score), 3))
        for idx, score in enumerate(calibrated)
        if score > thresholds[idx]
    ]
    if not detected:
        return "π **No species confidently detected.**\nTry reducing the strictness."
    # highest score first; format species names in italics
    detected.sort(key=lambda item: item[1], reverse=True)
    lines = ["### β Detected species:"]
    for species, score in detected:
        lines.append(f"- *{species}* β probability: {score}")
    return "\n".join(lines) + "\n"
| # ---------------------------- | |
| # 6) Gradio Blocks interface | |
| # ---------------------------- | |
# ----------------------------
# 6) Gradio Blocks interface
# ----------------------------
# UI: audio input + strictness slider -> Markdown result from infer().
# Fix: user-facing typo "ther calls" -> "their calls".
with gr.Blocks() as demo:
    gr.Markdown("# πΈ RibbID β Amphibian species acoustic identifier\n")
    # Intro sentence about native species
    gr.Markdown(
        "This CNN model detects the native frog and toad species of Catalonia (Northern Spain) through their calls."
    )
    gr.Markdown(
        "To start, upload an audio file or record a new one. Next, select the detection strictness in the slider, and click submit. Results might take time.\n"
        "\n"
        "**Detection strictness** controls how conservative the model is:\n"
        "- Lower values (0.5) = more sensitive (may include false positives).\n"
        "- Higher values (1.0) = only very confident detections."
    )
    with gr.Row():
        audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3) or record live")
        slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
                           label="Detection strictness")
    output = gr.Markdown()
    btn = gr.Button("Submit")
    btn.click(
        fn=infer,
        inputs=[audio, slider],
        outputs=[output],
        show_progress=True
    )

if __name__ == "__main__":
    # share=False: local-only server, no public Gradio tunnel
    demo.launch(share=False)