# -*- coding: utf-8 -*-
"""
Command-line inference for a multi-label audio classifier.

An audio file is normalized, band-pass filtered, cut into overlapping
windows, converted to PCEN-scaled mel spectrograms and scored window by
window with an EfficientNet-B0 + SE-block network.  Per-window
probabilities are then aggregated into file-level predictions.

Created on Mon Jun 30 17:06:08 2025

@author: User
"""

import torch
import torch.nn as nn
import numpy as np
import librosa
import joblib
import pickle
from pathlib import Path
from sklearn.isotonic import IsotonicRegression
import argparse

# ==== CONFIGURATION ====
SR = 22050                    # sample rate (Hz) passed to librosa.load
DURATION = 4.0                # analysis window length in seconds
SAMPLES = int(SR * DURATION)  # samples per analysis window
BANDS = 128                   # number of mel bands
HOP = 512                     # STFT hop length in samples
FMIN, FMAX = 150, 4500        # band-pass edges and mel frequency range (Hz)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ==== MODEL ====
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of its input.

    Args:
        channels: Number of channels of the incoming feature map.
        red: Reduction factor of the bottleneck (``channels // red`` hidden
            channels, never fewer than one).
    """

    def __init__(self, channels, red=16):
        super().__init__()
        # Clamp so that ``channels < red`` does not create a 0-channel conv.
        hidden = max(1, channels // red)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate in (0, 1)."""
        return x * self.fc(x)


class EfficientNetSE(nn.Module):
    """Backbone feature extractor followed by an SE gate and a linear head.

    Args:
        backbone: Module exposing a ``features`` sub-module whose output has
            ``feat_channels`` channels.
        num_classes: Number of output logits.
        drop: Dropout probability applied before the linear layer.
        feat_channels: Channel count produced by ``backbone.features``
            (1280 by default, the value this script was written for).
    """

    def __init__(self, backbone, num_classes, drop=0.3, feat_channels=1280):
        super().__init__()
        self.backbone = backbone
        self.se = SEBlock(feat_channels)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(feat_channels, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape ``(batch, num_classes)``."""
        x = self.backbone.features(x)
        x = self.se(x)
        x = self.pool(x).flatten(1)
        return self.classifier(x)


# ==== PREPROCESSING ====
def load_and_normalize(path, sr=SR, target_dBFS=-20.0):
    """Load an audio file, remove its DC offset and normalize its RMS level.

    Args:
        path: Audio file to read with ``librosa.load``.
        sr: Sample rate requested from ``librosa.load``.
        target_dBFS: RMS level of the returned signal, in dB relative to
            full scale.

    Returns:
        1-D array holding the zero-mean, level-normalized signal.

    Raises:
        ValueError: If the file contains no samples.
    """
    y, _ = librosa.load(path, sr=sr)
    if y.size == 0:
        # Without this the mean/RMS below would be NaN and the failure would
        # only surface later as an opaque filter error.
        raise ValueError(f"Audio file contains no samples: {path}")
    y = y - np.mean(y)
    # The epsilon keeps the gain finite for an all-zero (silent) signal.
    rms = np.sqrt(np.mean(y ** 2)) + 1e-9
    scalar = (10 ** (target_dBFS / 20)) / rms
    return y * scalar


def bandpass(y, sr=SR, low=FMIN, high=FMAX, order=6):
    """Apply a zero-phase Butterworth band-pass filter to ``y``.

    Args:
        y: 1-D signal.
        sr: Sample rate of ``y`` in Hz.
        low: Lower cut-off frequency in Hz.
        high: Upper cut-off frequency in Hz.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal (forward-backward filtering via ``filtfilt``).
    """
    from scipy.signal import butter, filtfilt

    nyq = 0.5 * sr
    # NOTE(review): SciPy recommends the SOS form (butter(..., output="sos")
    # with sosfiltfilt) for numerical robustness of band-pass designs.  The
    # (b, a) form is kept on purpose so that inference preprocessing stays
    # numerically identical to what the trained model presumably saw --
    # confirm against the training pipeline before changing it.
    b, a = butter(order, [low / nyq, high / nyq], btype='band')
    return filtfilt(b, a, y)


def segment(y, sr=SR, win=DURATION, hop=1.0):
    """Cut ``y`` into fixed-length, possibly overlapping windows.

    Args:
        y: 1-D signal.
        sr: Sample rate of ``y`` in Hz.
        win: Window length in seconds.
        hop: Offset between consecutive window starts, in seconds.

    Returns:
        List of arrays of ``int(win * sr)`` samples each.  A signal shorter
        than one window is zero-padded at the end and returned as a single
        window.  For longer signals only full windows are produced, so any
        samples after the last full window are not covered.
    """
    w = int(win * sr)
    h = int(hop * sr)
    if len(y) < w:
        return [np.pad(y, (0, w - len(y)))]
    return [y[start:start + w] for start in range(0, len(y) - w + 1, h)]


def extract_log_mel(y, sr=SR, n_mels=BANDS, hop_length=HOP, fmin=FMIN, fmax=FMAX):
    """Compute a PCEN-scaled mel spectrogram of ``y``.

    Despite the function name, the compression applied is PCEN
    (``librosa.pcen``), not a logarithm.

    Args:
        y: 1-D signal.
        sr: Sample rate of ``y`` in Hz.
        n_mels: Number of mel bands.
        hop_length: STFT hop length in samples.
        fmin: Lowest mel frequency in Hz.
        fmax: Highest mel frequency in Hz.

    Returns:
        2-D array with ``n_mels`` rows, one column per frame.
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels, hop_length=hop_length,
        fmin=fmin, fmax=fmax, power=1.0)
    # Forward sr/hop_length so PCEN's time constant is derived from the same
    # framing as the spectrogram; without this, librosa.pcen silently falls
    # back to its own defaults whenever a caller overrides them here.
    pcen = librosa.pcen(mel * (2 ** 31), sr=sr, hop_length=hop_length)
    return pcen


# ==== SEGMENT-WISE PREDICTION ====
def predict_segments(file_path, model):
    """Score every analysis window of an audio file.

    The file is loaded, normalized, band-pass filtered and segmented with
    the module-level defaults.  The model is switched to eval mode as a
    side effect.

    Args:
        file_path: Audio file to analyse.
        model: Network returning one logit per class for an input of shape
            ``(1, 1, n_mels, frames)``.

    Returns:
        Array of shape ``(n_segments, n_classes)`` with sigmoid
        probabilities.
    """
    y = load_and_normalize(file_path)
    y = bandpass(y, SR)
    segments = segment(y, SR)

    # Feed inputs on whatever device the model actually lives on; fall back
    # to the module default for a parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    all_probs = []
    model.eval()
    with torch.no_grad():
        for seg in segments:
            mel = extract_log_mel(seg)
            # Add batch and channel axes in front of (n_mels, frames).
            inp = torch.tensor(mel[None, None], dtype=torch.float32).to(device)
            probs = torch.sigmoid(model(inp)).cpu().numpy()[0]
            all_probs.append(probs)
    return np.array(all_probs)


# ==== HYBRID PREDICTION STRATEGY ====
def predict_file_with_hybrid_strategy(file_path, model, thresholds, label_encoder,
                                      override_max=0.9, margin=0.15):
    """Aggregate per-segment probabilities into file-level predictions.

    A class is reported when its mean probability over all segments exceeds
    its threshold lowered by ``margin``, or when any single segment exceeds
    ``override_max`` for that class.

    Args:
        file_path: Audio file to analyse.
        model: Network passed on to ``predict_segments``.
        thresholds: Per-class decision thresholds, indexed in the same order
            as ``label_encoder.classes_``.
        label_encoder: Object exposing ``classes_`` (the class labels).
        override_max: Single-segment probability above which a class is
            reported regardless of its mean.
        margin: Amount subtracted from every threshold to make the mean-based
            rule more sensitive.

    Returns:
        Tuple ``(preds, mean_probs, max_probs, probs)``: the list of detected
        class labels, the per-class mean and maximum over segments, and the
        full ``(n_segments, n_classes)`` probability array.
    """
    probs = predict_segments(file_path, model)
    mean_probs = probs.mean(axis=0)
    max_probs = probs.max(axis=0)

    sensitive_thresh = [t - margin for t in thresholds]

    preds = [
        sp
        for i, sp in enumerate(label_encoder.classes_)
        if mean_probs[i] > sensitive_thresh[i] or max_probs[i] > override_max
    ]

    return preds, mean_probs, max_probs, probs


# ==== MAIN ====
def main():
    """Parse command-line arguments, load the model and print predictions."""
    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", type=str, help="Ruta al archivo de audio (.wav)")
    parser.add_argument("--model", default="CNN_final.pth", help="Ruta al modelo CNN .pth")
    parser.add_argument("--meta", default="label_encoder_and_thresholds.pkl",
                        help="Pickle con encoder y thresholds")
    args = parser.parse_args()

    # Load metadata (label encoder, thresholds, and calibrators if you also
    # want to apply them).
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --meta at files from a trusted source.
    with open(args.meta, "rb") as f:
        meta = pickle.load(f)
    label_encoder = meta["label_encoder"]
    thresholds = meta["thresholds"]

    # Load model
    from torchvision import models

    backbone = models.efficientnet_b0(weights=None)
    # Replace the stem convolution so the network accepts 1-channel input.
    backbone.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    model = EfficientNetSE(backbone, num_classes=len(label_encoder.classes_))
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted
    # files (consider weights_only=True where the installed torch supports it).
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model.to(DEVICE)

    # Run prediction
    file_path = args.audio_file
    preds, mean_probs, max_probs, probs_all = predict_file_with_hybrid_strategy(
        file_path, model, thresholds, label_encoder
    )

    print(f"\n Archivo: {file_path}")
    print(f"Especies detectadas: {', '.join(preds)}\n")
    print("📊 Probabilidades por especie:")
    for i, sp in enumerate(label_encoder.classes_):
        print(f" {sp:<25} → mean: {mean_probs[i]:.2f}, max: {max_probs[i]:.2f}")


if __name__ == "__main__":
    main()