Calotriton committed on
Commit
afabda4
·
verified ·
1 Parent(s): 278b999

Upload 3 files

Browse files
Files changed (3) hide show
  1. CNN_final.pth +3 -0
  2. label_encoder_and_thresholds.pkl +3 -0
  3. predict.py +147 -0
CNN_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a49aa884e76b9a2a6774cbae827ef4c8b6013441a550361b0e76426cb3eb954b
3
+ size 22320011
label_encoder_and_thresholds.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:483b044e7ed702fcee8bd664b240c0edb3d77a4b071e028c213d754bcd5b5228
3
+ size 486
predict.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Jun 30 17:06:08 2025
4
+
5
+ @author: User
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import librosa
12
+ import joblib
13
+ import pickle
14
+ from pathlib import Path
15
+ from sklearn.isotonic import IsotonicRegression
16
+ import argparse
17
+
18
# ==== CONFIGURATION ====
# Sample rate (Hz) used as the default for loading, filtering and features.
SR = 22050
# Default analysis-window length in seconds (see ``segment``).
DURATION = 4.0
# Number of samples in one DURATION-second window at SR.
SAMPLES = int(DURATION * SR)
# Default number of mel bands.
BANDS = 128
# Default hop length (in samples) between spectrogram frames.
HOP = 512
# Default lower/upper frequency limits (Hz) for the band-pass filter and
# for the mel filterbank.
FMIN = 150
FMAX = 4500
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
26
+
27
+ # ==== MODELO ====
28
class SEBlock(nn.Module):
    """Channel-gating block: rescales a feature map by per-channel weights.

    The gate is built from global average pooling, a 1x1 convolution down
    to ``channels // red`` channels, ReLU, a 1x1 convolution back up to
    ``channels`` and a sigmoid.

    Args:
        channels: Number of channels of the incoming feature map.
        red: Reduction factor for the bottleneck inside the gate.
    """

    def __init__(self, channels, red=16):
        super().__init__()
        squeezed = channels // red
        # The layer order inside ``fc`` is kept fixed so the state-dict keys
        # (``fc.1.*`` and ``fc.3.*``) stay compatible with saved checkpoints.
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its gate.

        The gate has spatial size 1x1 (it comes out of the adaptive pool),
        so the product broadcasts over the spatial dimensions of ``x``.
        """
        gate = self.fc(x)
        return x * gate
40
+
41
class EfficientNetSE(nn.Module):
    """Classifier built on a backbone that exposes a ``features`` module.

    Pipeline: ``backbone.features`` -> SE gate -> global average pool ->
    dropout -> linear layer with ``num_classes`` outputs.

    Args:
        backbone: Module whose ``features`` attribute maps the input to a
            feature map. The SE block and the linear layer are hard-wired
            to 1280 channels, so ``features`` must output 1280 channels.
        num_classes: Number of outputs of the final linear layer.
        drop: Dropout probability applied before the linear layer.
    """

    def __init__(self, backbone, num_classes, drop=0.3):
        super().__init__()
        feat_channels = 1280
        # Attribute names and order are kept so checkpoint keys still match.
        self.backbone = backbone
        self.se = SEBlock(feat_channels)
        self.pool = nn.AdaptiveAvgPool2d(1)
        head = [nn.Dropout(drop), nn.Linear(feat_channels, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the classifier output for ``x`` (no activation applied here)."""
        feats = self.se(self.backbone.features(x))
        pooled = self.pool(feats)
        # Drop the 1x1 spatial dimensions before the linear layer.
        return self.classifier(pooled.flatten(1))
56
+
57
+ # ==== PREPROCESADO ====
58
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load an audio file, remove its DC offset and scale it to a target RMS.

    Args:
        path: Audio file to read with ``librosa.load``.
        sr: Sample rate requested from ``librosa.load``.
        target_dBFS: Desired RMS level in dB relative to an amplitude of 1.0.

    Returns:
        The mean-removed waveform, scaled so its RMS is approximately
        ``10 ** (target_dBFS / 20)``.
    """
    signal, _ = librosa.load(path, sr=sr)
    centered = signal - np.mean(signal)
    # The small epsilon keeps the division finite for all-zero input.
    rms = np.sqrt(np.mean(centered ** 2)) + 1e-9
    gain = (10 ** (target_dBFS / 20)) / rms
    return centered * gain
64
+
65
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Band-pass ``y`` with a Butterworth filter applied forward and backward.

    Args:
        y: Signal to filter.
        sr: Sample rate of ``y`` in Hz.
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal as returned by ``scipy.signal.filtfilt``.
    """
    from scipy.signal import butter, filtfilt

    nyquist = sr * 0.5
    # ``butter`` expects cutoffs as fractions of the Nyquist frequency.
    band_edges = [low / nyquist, high / nyquist]
    numer, denom = butter(order, band_edges, btype='band')
    return filtfilt(numer, denom, y)
70
+
71
def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Cut ``y`` into fixed-length windows taken every ``hop`` seconds.

    Args:
        y: Sample sequence (assumed 1-D).
        sr: Samples per second, used to convert ``win``/``hop`` to samples.
        win: Window length in seconds.
        hop: Step between consecutive window starts in seconds.

    Returns:
        A list of windows of ``int(win * sr)`` samples each. Input shorter
        than one window is zero-padded at the end and returned as a single
        window. Samples after the last complete window are not returned.
    """
    win_len = int(win * sr)
    step = int(hop * sr)
    total = len(y)

    if total < win_len:
        return [np.pad(y, (0, win_len - total))]

    windows = []
    for start in range(0, total - win_len + 1, step):
        windows.append(y[start:start + win_len])
    return windows
78
+
79
def extract_log_mel(y, sr=SR, n_mels=BANDS, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-normalised mel spectrogram for one audio segment.

    Despite the function name, no logarithm is taken: the magnitude
    (``power=1.0``) mel spectrogram is passed through ``librosa.pcen``.

    Args:
        y: Audio samples of the segment.
        sr: Sample rate of ``y`` in Hz.
        n_mels: Number of mel bands.
        hop_length: Hop between spectrogram frames, in samples.
        fmin: Lowest frequency (Hz) of the mel filterbank.
        fmax: Highest frequency (Hz) of the mel filterbank.

    Returns:
        The PCEN output computed by ``librosa.pcen`` from the mel spectrogram.
    """
    mel = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=n_mels,
        hop_length=hop_length,
        fmin=fmin,
        fmax=fmax,
        power=1.0,
    )
    # The 2**31 factor follows the librosa PCEN documentation for floating
    # point audio. ``sr`` and ``hop_length`` are forwarded so PCEN's temporal
    # smoothing matches the spectrogram's real frame rate; with the default
    # arguments (22050 / 512) this equals librosa's own defaults, so the
    # output is unchanged.
    return librosa.pcen(mel * (2 ** 31), sr=sr, hop_length=hop_length)
84
+
85
+ # ==== PREDICCIÓN SEGMENTADA ====
86
def predict_segments(file_path, model):
    """Run the model on every window of an audio file.

    The file is loaded and level-normalised, band-pass filtered, split into
    windows, and each window is converted to features and passed through
    ``model`` one at a time. The model is switched to eval mode and no
    gradients are recorded.

    Args:
        file_path: Path of the audio file to analyse.
        model: Network to evaluate; inputs are moved to ``DEVICE``, so the
            model is expected to live there already.

    Returns:
        A NumPy array with one row per window containing the sigmoid of the
        model output for that window.
    """
    waveform = bandpass(load_and_normalize(file_path), SR)
    windows = segment(waveform, SR)

    model.eval()
    per_window = []
    with torch.no_grad():
        for chunk in windows:
            features = extract_log_mel(chunk)
            # Prepend batch and channel axes before handing it to the model.
            batch = torch.tensor(features[None, None], dtype=torch.float32).to(DEVICE)
            scores = torch.sigmoid(model(batch))
            # Keep the single item of the batch as a NumPy vector.
            per_window.append(scores.cpu().numpy()[0])
    return np.array(per_window)
99
+
100
+ # ==== ESTRATEGIA HÍBRIDA DE PREDICCIÓN ====
101
def predict_file_with_hybrid_strategy(file_path, model, thresholds, label_encoder,
                                      override_max=0.9, sensitivity_margin=0.15):
    """Decide which classes are present in a file from per-window scores.

    A class is reported when either of these holds:

    * its mean probability over all windows exceeds its threshold lowered
      by ``sensitivity_margin``; or
    * its maximum probability over all windows exceeds ``override_max``.

    Args:
        file_path: Path of the audio file to analyse.
        model: Network passed on to ``predict_segments``.
        thresholds: Iterable of per-class thresholds, indexed in the same
            order as ``label_encoder.classes_`` (assumed; confirm against
            the code that produced them).
        label_encoder: Object whose ``classes_`` attribute lists the class
            labels in model-output order.
        override_max: Peak probability above which a class is reported
            regardless of its mean.
        sensitivity_margin: Amount subtracted from every threshold before
            comparing it with the mean probability. The default of 0.15 is
            the value that was previously hard-coded.

    Returns:
        Tuple ``(preds, mean_probs, max_probs, probs)``: the list of reported
        labels, the per-class mean and maximum over windows, and the raw
        per-window probability array.
    """
    probs = predict_segments(file_path, model)
    mean_probs = probs.mean(axis=0)
    max_probs = probs.max(axis=0)
    sensitive_thresh = [t - sensitivity_margin for t in thresholds]

    preds = [
        sp
        for i, sp in enumerate(label_encoder.classes_)
        if mean_probs[i] > sensitive_thresh[i] or max_probs[i] > override_max
    ]
    return preds, mean_probs, max_probs, probs
112
+
113
+ # ==== MAIN ====
114
def main():
    """Command-line entry point: classify one audio file and print the result.

    Parses the arguments, loads the label encoder and thresholds from the
    metadata pickle, rebuilds the network and loads its weights, runs the
    hybrid prediction strategy and prints the detected classes together with
    the per-class mean and maximum probabilities.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Ruta al archivo de audio (.wav)")
    parser.add_argument("--model", default="CNN_final.pth", help="Ruta al modelo CNN .pth")
    parser.add_argument("--meta", default="label_encoder_and_thresholds.pkl", help="Pickle con encoder y thresholds")
    args = parser.parse_args()

    # Load metadata (label encoder and thresholds).
    # SECURITY: unpickling can execute arbitrary code contained in the file.
    # Only pass a --meta file that comes from a trusted source.
    with open(args.meta, "rb") as f:
        meta = pickle.load(f)

    label_encoder = meta["label_encoder"]
    thresholds = meta["thresholds"]

    # Rebuild the architecture, then load the trained weights.
    # torchvision is imported here so it is only required when run as a script.
    from torchvision import models
    backbone = models.efficientnet_b0(weights=None)
    # Replace the first convolution so the network accepts 1-channel input.
    backbone.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    model = EfficientNetSE(backbone, num_classes=len(label_encoder.classes_))
    # NOTE(review): torch.load also unpickles; on PyTorch versions that
    # support it, consider passing weights_only=True for untrusted files.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model.to(DEVICE)

    # Run the prediction.
    file_path = args.audio_file
    preds, mean_probs, max_probs, probs_all = predict_file_with_hybrid_strategy(
        file_path, model, thresholds, label_encoder
    )

    # Say so explicitly when nothing was detected instead of printing an
    # empty field.
    detected = ', '.join(preds) if preds else "ninguna"

    print(f"\n Archivo: {file_path}")
    print(f"Especies detectadas: {detected}\n")

    print("📊 Probabilidades por especie:")
    for i, sp in enumerate(label_encoder.classes_):
        print(f" {sp:<25} → mean: {mean_probs[i]:.2f}, max: {max_probs[i]:.2f}")


# ==== MAIN ====
if __name__ == "__main__":
    main()