| | import torch |
| | import numpy as np |
| | import librosa |
| | from scipy.ndimage import zoom |
| | from model_architecture import load_model |
| |
|
def preprocess_audio(audio_path, sr=250000, nfft=0.0032, overlap=0.0028,
                     freq_range=(40000, 100000)):
    """
    Preprocess an audio file into a normalized 64x64 spectrogram patch.

    The recording is loaded at ``sr``, converted to a dB-scaled magnitude
    STFT, cropped to ``freq_range``, resized to 64x64 with linear
    interpolation, and standardized to zero mean and unit standard deviation.

    Args:
        audio_path: Path to .wav file
        sr: Sample rate in Hz passed to ``librosa.load`` (250 kHz for ultrasonic)
        nfft: FFT window size in seconds
        overlap: Overlap between windows in seconds; must be smaller than
            ``nfft`` by at least one sample so the hop is positive
        freq_range: (low, high) frequency range to extract in Hz, inclusive

    Returns:
        torch.Tensor: Preprocessed float32 spectrogram (1, 1, 64, 64)

    Raises:
        ValueError: If ``sr``, the FFT window, or the hop length is not
            positive, if ``freq_range`` is inverted, or if the cropped
            spectrogram contains no frequency bins or no time frames.
    """
    patch_size = 64
    low_hz, high_hz = freq_range[0], freq_range[1]

    # Validate everything that can be checked before touching the file, so
    # bad parameters fail fast with a clear message instead of surfacing as
    # an obscure STFT error or a ZeroDivisionError further down.
    if sr <= 0:
        raise ValueError(f"sr must be positive, got {sr!r}")
    if low_hz > high_hz:
        raise ValueError(
            f"freq_range must be (low, high) with low <= high, got {freq_range!r}"
        )

    # Convert the window and hop from seconds to whole samples; int()
    # truncates toward zero.
    nfft_samples = int(nfft * sr)
    hop_length = int((nfft - overlap) * sr)
    if nfft_samples < 1:
        raise ValueError(
            f"nfft={nfft!r} s at sr={sr!r} Hz gives a window of "
            f"{nfft_samples} samples; it must be at least 1"
        )
    if hop_length < 1:
        raise ValueError(
            f"overlap={overlap!r} s must be smaller than nfft={nfft!r} s by "
            f"at least one sample (hop length would be {hop_length})"
        )

    audio, _ = librosa.load(audio_path, sr=sr)

    # Magnitude spectrogram in dB, referenced to its own maximum.
    spec = librosa.stft(audio, n_fft=nfft_samples, hop_length=hop_length)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)

    # Keep only the rows whose bin centre frequency lies inside the band.
    freqs = librosa.fft_frequencies(sr=sr, n_fft=nfft_samples)
    freq_mask = (freqs >= low_hz) & (freqs <= high_hz)
    spec_db = spec_db[freq_mask, :]

    if spec_db.size == 0:
        raise ValueError(
            f"Spectrogram is empty after cropping to freq_range={freq_range!r} "
            f"(shape {spec_db.shape}); check the band against sr={sr!r} and "
            f"that the audio is not empty"
        )

    # Resize both axes to the fixed patch size with linear interpolation.
    zoom_factors = (patch_size / spec_db.shape[0], patch_size / spec_db.shape[1])
    spec_resized = zoom(spec_db, zoom_factors, order=1)

    # Standardize; the epsilon avoids division by zero on a constant patch.
    spec_resized = (spec_resized - np.mean(spec_resized)) / (np.std(spec_resized) + 1e-8)

    # Add batch and channel dimensions: (64, 64) -> (1, 1, 64, 64).
    return torch.as_tensor(spec_resized, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
| |
|
def predict(audio_path, model):
    """
    Predict if audio contains USV

    The file is preprocessed with ``preprocess_audio`` (default settings),
    passed through ``model`` without gradient tracking, and the two output
    scores are turned into probabilities with a softmax. Output index 0 is
    reported as noise and index 1 as USV.

    Args:
        audio_path: Path to .wav file
        model: Loaded USVDetectorCNN model

    Returns:
        dict: Prediction results with keys
            'is_usv' (bool): True when index 1 has the highest score
            'confidence' (float): probability of the predicted class
            'usv_probability' (float): probability at index 1
            'noise_probability' (float): probability at index 0
    """
    spec_tensor = preprocess_audio(audio_path)

    # Layers such as dropout and batch-norm behave differently in training
    # mode, which would make inference stochastic. Switch to evaluation mode
    # for the forward pass and restore the caller's mode afterwards so this
    # function leaves no lasting side effect on the model.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            output = model(spec_tensor)
            probabilities = torch.softmax(output, dim=1)
            prediction = torch.argmax(output, dim=1).item()
    finally:
        if was_training:
            model.train(True)

    return {
        'is_usv': prediction == 1,
        'confidence': probabilities[0][prediction].item(),
        'usv_probability': probabilities[0][1].item(),
        'noise_probability': probabilities[0][0].item()
    }
| |
|
| | |
if __name__ == "__main__":
    # Example run: load the model from 'final_usv_model.pth', classify one
    # recording, and print the decision with the per-class probabilities.
    result = predict('test_audio.wav', load_model('final_usv_model.pth'))

    report_lines = (
        f"USV Detected: {result['is_usv']}",
        f"Confidence: {result['confidence']:.2%}",
        f"USV Probability: {result['usv_probability']:.2%}",
        f"Noise Probability: {result['noise_probability']:.2%}",
    )
    for line in report_lines:
        print(line)
| |
|