import torch
import numpy as np
import librosa
from scipy.ndimage import zoom

from model_architecture import load_model


def preprocess_audio(audio_path, sr=250000, nfft=0.0032, overlap=0.0028,
                     freq_range=(40000, 100000), target_shape=(64, 64)):
    """Preprocess an audio file into a normalized spectrogram patch.

    The file is loaded at ``sr``, turned into a dB-scaled STFT magnitude
    spectrogram (0 dB at the spectrogram's maximum), cropped to
    ``freq_range``, linearly resized to ``target_shape`` and z-score
    normalized.

    Args:
        audio_path: Path to .wav file.
        sr: Sample rate the audio is loaded at (250 kHz for ultrasonic).
        nfft: FFT window size in seconds.
        overlap: Overlap between consecutive windows in seconds.
        freq_range: Inclusive (low, high) frequency range to keep, in Hz.
        target_shape: (frequency, time) size of the output patch.
            Defaults to (64, 64).

    Returns:
        torch.Tensor: Float32 spectrogram of shape (1, 1, *target_shape),
        i.e. (1, 1, 64, 64) with the defaults.

    Raises:
        ValueError: If no STFT frequency bin lies inside ``freq_range``
            (e.g. the range is above the Nyquist frequency ``sr / 2``).
    """
    # Load audio (librosa resamples it to ``sr``).
    audio, _ = librosa.load(audio_path, sr=sr)

    # Convert window/hop durations to sample counts. Round instead of
    # truncating with int() so floating-point error (e.g. 99.999...) cannot
    # silently drop a sample; the defaults give 800 and 100 either way.
    nfft_samples = int(round(nfft * sr))
    hop_length = int(round((nfft - overlap) * sr))

    # Magnitude spectrogram in dB, referenced to its own maximum.
    spec = librosa.stft(audio, n_fft=nfft_samples, hop_length=hop_length)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)

    # Filter to the USV frequency range (rows are frequency bins).
    freqs = librosa.fft_frequencies(sr=sr, n_fft=nfft_samples)
    freq_mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
    if not freq_mask.any():
        # Without this check the resize below divides by zero.
        raise ValueError(
            f"No STFT bins fall inside freq_range={freq_range} Hz "
            f"(sr={sr} Hz gives a Nyquist frequency of {sr / 2} Hz)"
        )
    spec_db = spec_db[freq_mask, :]

    # Resize to target_shape with first-order (linear) spline interpolation.
    zoom_factors = tuple(
        target / current for target, current in zip(target_shape, spec_db.shape)
    )
    spec_resized = zoom(spec_db, zoom_factors, order=1)

    # Z-score normalize; the epsilon avoids dividing by zero when the
    # spectrogram is constant (e.g. silent input).
    spec_resized = (spec_resized - np.mean(spec_resized)) / (
        np.std(spec_resized) + 1e-8
    )

    # Convert to a float32 tensor and add batch and channel dims: (1, 1, H, W).
    return torch.as_tensor(spec_resized, dtype=torch.float32).unsqueeze(0).unsqueeze(0)


def _model_device(model):
    """Return the device of ``model``'s first parameter, or None if unknown."""
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return None
    first_param = next(iter(parameters()), None)
    return None if first_param is None else first_param.device


def predict(audio_path, model):
    """Predict whether an audio file contains a USV.

    Args:
        audio_path: Path to .wav file.
        model: Loaded USVDetectorCNN model. As read below, its output is
            treated as per-class scores along dim 1, with class 0 = noise
            and class 1 = USV.

    Returns:
        dict: Prediction results with keys
            ``'is_usv'`` (bool, predicted class is 1),
            ``'confidence'`` (float, softmax probability of the predicted class),
            ``'usv_probability'`` (float, softmax probability of class 1),
            ``'noise_probability'`` (float, softmax probability of class 0).
    """
    # Preprocess
    spec_tensor = preprocess_audio(audio_path)

    # Place the input on the same device as the model's weights so a model
    # that lives on a GPU does not fail with a device-mismatch error.
    device = _model_device(model)
    if device is not None:
        spec_tensor = spec_tensor.to(device)

    # NOTE(review): model.eval() is not called here; if load_model does not
    # already put the model in eval mode, dropout/batch-norm layers would run
    # in training mode. Confirm against model_architecture.load_model.
    with torch.no_grad():
        output = model(spec_tensor)
        probabilities = torch.softmax(output, dim=1)
        prediction = torch.argmax(output, dim=1).item()

    return {
        'is_usv': prediction == 1,
        'confidence': probabilities[0][prediction].item(),
        'usv_probability': probabilities[0][1].item(),
        'noise_probability': probabilities[0][0].item()
    }


# Example usage
if __name__ == "__main__":
    # Load model
    model = load_model('final_usv_model.pth')

    # Predict
    result = predict('test_audio.wav', model)
    print(f"USV Detected: {result['is_usv']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"USV Probability: {result['usv_probability']:.2%}")
    print(f"Noise Probability: {result['noise_probability']:.2%}")