File size: 5,935 Bytes
08175df
3d23449
08175df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8064545
08175df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf8b74
08175df
 
 
 
 
 
 
 
 
9bf8b74
 
 
08175df
 
9bf8b74
08175df
 
 
 
 
 
 
 
 
 
 
845cf99
08175df
 
9bf8b74
 
08175df
845cf99
08175df
845cf99
08175df
3d23449
08175df
 
d7789b4
08175df
 
 
 
 
 
9bf8b74
08175df
 
9bf8b74
08175df
 
 
9bf8b74
08175df
 
 
9bf8b74
08175df
 
 
 
 
845cf99
08175df
 
 
 
9bf8b74
08175df
 
9bf8b74
a747133
 
08175df
a747133
9bf8b74
08175df
 
 
9bf8b74
 
 
 
08175df
 
9bf8b74
 
 
 
 
 
08175df
 
9bf8b74
08175df
9bf8b74
7f4460d
9bf8b74
c8adc9b
9bf8b74
c8adc9b
9bf8b74
 
 
 
 
 
 
 
3d23449
9bf8b74
08175df
 
9bf8b74
 
 
 
 
 
 
 
 
 
2b1c1bc
 
38503c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import json
import pickle
import numpy as np
import torch
import torch.nn as nn
import librosa
import scipy.signal as sps
import gradio as gr
from sklearn.preprocessing import LabelEncoder

# ----------------------------
# 1) Global parameters & paths
# ----------------------------
SR          = 22050        # sample rate (Hz) all audio is resampled to
DURATION    = 4.0          # analysis window length in seconds
HOP         = 512          # hop length in samples for the mel spectrogram
FMIN, FMAX  = 150, 4500    # frequency band (Hz) for both band-pass filter and mel range
MODEL_PATH  = "CNN_final.pth"                     # trained model weights
DATA_PKL    = "label_encoder_and_thresholds.pkl"  # class names + per-class thresholds
CAL_PATH    = "calibrators.pkl"                   # per-class probability calibrators
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available

# ----------------------------
# 2) Model definition
# ----------------------------
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: rescales each channel of the input
    feature map by a learned gate computed from globally pooled features.
    """

    def __init__(self, channels, red=16):
        super().__init__()
        squeezed = channels // red
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),            # squeeze: (B, C, H, W) -> (B, C, 1, 1)
            nn.Conv2d(channels, squeezed, 1),   # bottleneck reduction by `red`
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),   # expand back to `channels`
            nn.Sigmoid(),                       # per-channel gate in (0, 1)
        )

    def forward(self, x):
        """Return `x` scaled channel-wise by its attention gate."""
        gate = self.fc(x)
        return x * gate

class EfficientNetSE(nn.Module):
    """EfficientNet feature extractor with an SE attention block and a
    dropout + linear classification head producing raw logits.
    """

    def __init__(self, bbone, num_classes, drop=0.3):
        super().__init__()
        self.backbone = bbone
        self.se = SEBlock(1280)           # 1280 = EfficientNet-B0 final feature channels
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(1280, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        features = self.backbone.features(x)
        attended = self.se(features)
        pooled = self.pool(attended).flatten(1)
        return self.classifier(pooled)

# ----------------------------
# 3) Audio preprocessing
# ----------------------------
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load audio at `sr`, remove DC offset, and rescale it so its RMS level
    matches `target_dBFS` (decibels relative to full scale).
    """
    signal, _ = librosa.load(path, sr=sr)
    signal = signal - signal.mean()                   # remove DC offset
    rms_level = np.sqrt((signal ** 2).mean()) + 1e-9  # epsilon guards silent clips
    gain = (10 ** (target_dBFS / 20)) / rms_level
    return signal * gain

def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Apply a zero-phase Butterworth band-pass filter between `low` and
    `high` Hz to the signal `y`.
    """
    nyquist = 0.5 * sr
    # butter() takes critical frequencies normalized to the Nyquist frequency.
    b, a = sps.butter(order, [low / nyquist, high / nyquist], btype='band')
    # filtfilt runs the filter forward then backward: no phase distortion.
    return sps.filtfilt(b, a, y)

def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Split `y` into overlapping windows of `win` seconds taken every `hop`
    seconds. A clip shorter than one window is zero-padded and returned as a
    single segment; a trailing remainder shorter than a window is dropped.
    """
    win_len = int(win * sr)
    hop_len = int(hop * sr)
    if len(y) < win_len:
        return [np.pad(y, (0, win_len - len(y)))]
    starts = range(0, len(y) - win_len + 1, hop_len)
    return [y[start:start + win_len] for start in starts]

def extract_log_mel(y, sr=SR, n_mels=128, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-normalized mel spectrogram of one audio segment.

    NOTE(review): despite the name, this returns PCEN (per-channel energy
    normalization), not a log-mel spectrogram — confirm the name matches the
    features used at training time.
    """
    magnitude_mel = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=n_mels,
        hop_length=hop_length,
        fmin=fmin,
        fmax=fmax,
        power=1.0,  # magnitude (not power) spectrogram
    )
    # Scale into the integer-like range librosa.pcen's default gain expects.
    return librosa.pcen(magnitude_mel * (2 ** 31))

def predict_segments(fp):
    """Run the full preprocessing + model pipeline on the audio file `fp` and
    return per-segment sigmoid probabilities, shape (n_segments, n_classes).
    """
    audio = load_and_normalize(fp)
    audio = bandpass(audio)
    probabilities = []
    with torch.no_grad():
        for chunk in segment(audio):
            features = extract_log_mel(chunk)
            # Add batch and channel axes: (n_mels, frames) -> (1, 1, n_mels, frames).
            batch = torch.tensor(features[None, None], dtype=torch.float32).to(DEVICE)
            logits = model(batch)
            probabilities.append(torch.sigmoid(logits).cpu().numpy()[0])
    return np.vstack(probabilities)

# ----------------------------
# 4) Load artifacts
# ----------------------------
# Load class names and decision thresholds saved at training time.
with open(DATA_PKL, "rb") as f:
    data = pickle.load(f)
classes        = data["classes"]
orig_thresholds = np.array(data["thresholds"])      # raw thresholds (kept for reference)
adj_thresholds = np.array(data["adj_thresholds"])   # tuned thresholds used by infer()

# Rebuild encoder
# Restore the LabelEncoder without refitting by assigning the saved class order.
le = LabelEncoder()
le.classes_ = np.array(classes, dtype=object)

# Calibrators
# Per-class probability calibrators, indexed by class id; applied in infer().
with open(CAL_PATH, "rb") as f:
    calibrators = pickle.load(f)

# Load backbone & model
# EfficientNet-B0 via torch.hub, with the stem conv swapped for a 1-channel
# version so the network accepts single-channel spectrogram input.
backbone = torch.hub.load('pytorch/vision:v0.14.0','efficientnet_b0',pretrained=True)
backbone.features[0][0] = nn.Conv2d(1,32,3,2,1,bias=False)
model = EfficientNetSE(backbone, num_classes=len(le.classes_)).to(DEVICE)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects —
# fine for a trusted local checkpoint, but never point MODEL_PATH at untrusted files.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# ----------------------------
# 5) Inference logic
# ----------------------------
def infer(audio_path, sensitivity):
    """Classify one recording and return the detections as a Markdown string.

    `sensitivity` scales the stored per-class thresholds: higher values make
    the detector stricter.
    """
    # Aggregate per-segment probabilities with the 90th percentile per class.
    seg_probs = predict_segments(audio_path)
    aggregated = np.percentile(seg_probs, 90, axis=0)

    # Apply the per-class probability calibrators.
    calibrated = np.array(
        [calibrators[idx].transform([aggregated[idx]])[0]
         for idx in range(len(le.classes_))]
    )

    # Scale thresholds by the user-selected strictness and keep classes above them.
    scaled_thresholds = adj_thresholds * sensitivity
    detected_ids = np.flatnonzero(calibrated > scaled_thresholds)

    detections = [
        (le.classes_[idx].replace("_", " "), round(float(calibrated[idx]), 3))
        for idx in detected_ids
    ]
    if not detections:
        return "πŸ” **No species confidently detected.**\nTry reducing the strictness."

    # Highest probability first; species names italicized in the Markdown output.
    detections.sort(key=lambda item: item[1], reverse=True)
    lines = ["### βœ… Detected species:"]
    for species, prob in detections:
        lines.append(f"- *{species}* β€” probability: {prob}")
    return "\n".join(lines) + "\n"

# ----------------------------
# 6) Gradio Blocks interface
# ----------------------------
# Build the Gradio UI: audio input + strictness slider -> Markdown results.
with gr.Blocks() as demo:
    gr.Markdown("# 🐸 RibbID – Amphibian species acoustic identifier\n")
    # Intro sentence about native species
    gr.Markdown(
        # Fixed typo: "ther" -> "their".
        "This CNN model detects the native frog and toad species of Catalonia (Northern Spain) through their calls."
    )
    gr.Markdown(
        "To start, upload an audio file or record a new one. Next, select the detection strictness in the slider, and click submit. Results might take time.\n"
        "\n"
        "**Detection strictness** controls how conservative the model is:\n"
        "- Lower values (0.5) = more sensitive (may include false positives).\n"
        "- Higher values (1.0) = only very confident detections."
    )

    with gr.Row():
        # `filepath` hands infer() a path on disk, which librosa.load expects.
        audio = gr.Audio(type="filepath", label="Upload audio file (.wav/.mp3) or record live")
        # Slider range matches the threshold-scaling factor used in infer().
        slider = gr.Slider(0.5, 1.0, value=1.0, step=0.05,
                           label="Detection strictness")

    output = gr.Markdown()

    btn = gr.Button("Submit")
    btn.click(
        fn=infer,
        inputs=[audio, slider],
        outputs=[output],
        show_progress=True
    )

if __name__ == "__main__":
    # Local-only launch; set share=True to expose a temporary public Gradio link.
    demo.launch(share=False)