# RibbID_CNN / predict.py
# (HuggingFace upload by Calotriton — "Upload 3 files", commit afabda4, verified)
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 30 17:06:08 2025
@author: User
"""
import torch
import torch.nn as nn
import numpy as np
import librosa
import joblib
import pickle
from pathlib import Path
from sklearn.isotonic import IsotonicRegression
import argparse
# ==== CONFIGURACIÓN ====
SR = 22050  # sample rate (Hz) used when loading audio
DURATION = 4.0  # analysis window length, in seconds
SAMPLES = int(SR * DURATION)  # analysis window length, in samples
BANDS = 128  # number of mel bands in the spectrogram
HOP = 512  # STFT hop length, in samples
FMIN, FMAX = 150, 4500  # band-pass and mel frequency limits (Hz)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
# ==== MODELO ====
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel-attention block.

    Rescales every channel of the input feature map by a learned gate
    in (0, 1) computed from the globally average-pooled features.
    """

    def __init__(self, channels, red=16):
        super().__init__()
        # Squeeze (global pool) -> bottleneck conv -> expand conv -> sigmoid gate.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // red, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // red, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.fc(x)  # shape (N, C, 1, 1), broadcast over H and W
        return x * gate
class EfficientNetSE(nn.Module):
    """EfficientNet backbone with an SE attention head and linear classifier.

    Runs `backbone.features`, applies channel attention (SEBlock over the
    1280 feature channels emitted by EfficientNet-B0), global-average-pools
    the result, and classifies the 1280-dim embedding through dropout plus
    a single linear layer.
    """

    def __init__(self, backbone, num_classes, drop=0.3):
        super().__init__()
        self.backbone = backbone
        self.se = SEBlock(1280)  # EfficientNet-B0 final feature width
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(1280, num_classes),
        )

    def forward(self, x):
        feats = self.backbone.features(x)
        attended = self.se(feats)
        embedding = self.pool(attended).flatten(1)
        return self.classifier(embedding)
# ==== PREPROCESADO ====
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load a mono audio file and normalize its loudness.

    Removes the DC offset, then scales the waveform so that its RMS
    level matches `target_dBFS` (decibels relative to full scale).
    Returns the scaled waveform as a 1-D float array.
    """
    y, _ = librosa.load(path, sr=sr)
    y = y - np.mean(y)  # remove DC offset
    rms = np.sqrt(np.mean(np.square(y))) + 1e-9  # epsilon avoids division by zero on silence
    gain = (10.0 ** (target_dBFS / 20.0)) / rms
    return y * gain
def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    y : np.ndarray
        Mono audio signal.
    sr : int
        Sample rate in Hz.
    low, high : float
        Pass-band edges in Hz.
    order : int
        Butterworth filter order.

    Returns
    -------
    np.ndarray
        Filtered signal, same length as `y`.

    Notes
    -----
    Uses second-order sections (SOS) rather than transfer-function
    (b, a) coefficients: a high-order filter with a very low normalized
    cutoff (150 Hz at a 22.05 kHz rate) is numerically unstable in
    (b, a) form, which SciPy's docs warn about. `sosfiltfilt` filters
    forward and backward, so the result is zero-phase like the
    original `filtfilt` call.
    """
    from scipy.signal import butter, sosfiltfilt
    nyq = 0.5 * sr
    sos = butter(order, [low / nyq, high / nyq], btype='band', output='sos')
    return sosfiltfilt(sos, y)
def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Split a waveform into fixed-length analysis windows.

    `win` and `hop` are given in seconds. A signal shorter than one
    window is right-padded with zeros and returned as a single segment.
    Returns a list of arrays, each of length int(win * sr).
    """
    win_len = int(win * sr)
    hop_len = int(hop * sr)
    if len(y) < win_len:
        padded = np.pad(y, (0, win_len - len(y)))
        return [padded]
    starts = range(0, len(y) - win_len + 1, hop_len)
    return [y[s:s + win_len] for s in starts]
def extract_log_mel(y, sr=SR, n_mels=BANDS, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-compressed mel spectrogram of a waveform.

    The magnitude mel spectrogram (power=1.0) is scaled by 2**31 to the
    integer-PCM-like range expected by `librosa.pcen` and compressed with
    per-channel energy normalization. Despite the function name, the
    output is PCEN-scaled rather than log-scaled.
    """
    mel_spec = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=n_mels,
        hop_length=hop_length,
        fmin=fmin,
        fmax=fmax,
        power=1.0,
    )
    return librosa.pcen(mel_spec * (2 ** 31))
# ==== PREDICCIÓN SEGMENTADA ====
def predict_segments(file_path, model):
    """Run the CNN over every analysis window of an audio file.

    Preprocessing pipeline: loudness normalization, band-pass filtering,
    fixed-length windowing, and PCEN mel-spectrogram extraction per
    window. Returns the per-window sigmoid probabilities as an array of
    shape (n_segments, n_classes).
    """
    waveform = load_and_normalize(file_path)
    waveform = bandpass(waveform, SR)
    windows = segment(waveform, SR)
    model.eval()  # disable dropout for inference
    seg_probs = []
    with torch.no_grad():
        for window in windows:
            features = extract_log_mel(window)
            # Add batch and channel dims: (1, 1, n_mels, n_frames).
            batch = torch.tensor(features[None, None], dtype=torch.float32).to(DEVICE)
            logits = model(batch)
            seg_probs.append(torch.sigmoid(logits).cpu().numpy()[0])
    return np.array(seg_probs)
# ==== ESTRATEGIA HÍBRIDA DE PREDICCIÓN ====
def predict_file_with_hybrid_strategy(file_path, model, thresholds, label_encoder, override_max=0.9):
    """Detect species in a file with a hybrid mean/max decision rule.

    A species is reported when either its mean probability across all
    segments exceeds a sensitized per-class threshold (the stored
    threshold minus 0.15), or its maximum single-segment probability
    exceeds `override_max`.

    Returns (predicted_labels, mean_probs, max_probs, all_segment_probs).
    """
    seg_probs = predict_segments(file_path, model)
    mean_probs = seg_probs.mean(axis=0)
    max_probs = seg_probs.max(axis=0)
    relaxed = [t - 0.15 for t in thresholds]
    detected = [
        species
        for idx, species in enumerate(label_encoder.classes_)
        if mean_probs[idx] > relaxed[idx] or max_probs[idx] > override_max
    ]
    return detected, mean_probs, max_probs, seg_probs
# ==== MAIN ====
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Ruta al archivo de audio (.wav)")
    parser.add_argument("--model", default="CNN_final.pth", help="Ruta al modelo CNN .pth")
    parser.add_argument("--meta", default="label_encoder_and_thresholds.pkl", help="Pickle con encoder y thresholds")
    args = parser.parse_args()

    # Load the label encoder and per-class decision thresholds.
    # NOTE(review): pickle.load executes arbitrary code — only load trusted files.
    with open(args.meta, "rb") as f:
        meta = pickle.load(f)
    label_encoder = meta["label_encoder"]
    thresholds = meta["thresholds"]

    # Rebuild the architecture: EfficientNet-B0 adapted to 1-channel
    # (spectrogram) input, then restore the trained weights.
    from torchvision import models
    backbone = models.efficientnet_b0(weights=None)
    backbone.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    model = EfficientNetSE(backbone, num_classes=len(label_encoder.classes_))
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model.to(DEVICE)

    # Run inference and print the per-species report.
    file_path = args.audio_file
    preds, mean_probs, max_probs, probs_all = predict_file_with_hybrid_strategy(
        file_path, model, thresholds, label_encoder
    )
    print(f"\n Archivo: {file_path}")
    print(f"Especies detectadas: {', '.join(preds)}\n")
    print("📊 Probabilidades por especie:")
    for i, sp in enumerate(label_encoder.classes_):
        print(f" {sp:<25} → mean: {mean_probs[i]:.2f}, max: {max_probs[i]:.2f}")