# Source: uploaded by sulstice2 -- "Upload 5 files" (commit 8d28a33, verified).
import torch
import numpy as np
import librosa
from scipy.ndimage import zoom
from model_architecture import load_model
def preprocess_audio(audio_path, sr=250000, nfft=0.0032, overlap=0.0028,
                     freq_range=(40000, 100000), output_shape=(64, 64)):
    '''
    Preprocess audio file into a normalized spectrogram patch.

    The waveform is loaded at ``sr``, transformed with an STFT, converted to
    dB relative to the spectrogram's own peak (``ref=np.max``), cropped to
    ``freq_range``, linearly resized to ``output_shape`` and z-score
    normalized.

    Args:
        audio_path: Path to .wav file
        sr: Sample rate the audio is loaded at (250 kHz for ultrasonic)
        nfft: FFT window size in seconds (converted to samples via ``sr``)
        overlap: Overlap between windows in seconds; must be smaller than
            ``nfft`` so the hop length is positive
        freq_range: (low, high) frequency range to keep, in Hz, inclusive
        output_shape: (rows, cols) of the resized patch; defaults to (64, 64)

    Returns:
        torch.Tensor: Preprocessed spectrogram of shape (1, 1, *output_shape),
        (1, 1, 64, 64) with the defaults.

    Raises:
        ValueError: if ``nfft``/``overlap`` give a non-positive window or hop,
            if ``freq_range`` is reversed, if ``output_shape`` is not
            positive, if no FFT bin falls inside ``freq_range``, or if the
            STFT produces no frames.
    '''
    # Round rather than truncate when converting seconds -> samples: float
    # products such as 0.29 * 100 == 28.999999999999996 would otherwise
    # silently lose a sample. Defaults still give 800 / 100 samples.
    nfft_samples = int(round(nfft * sr))
    hop_length = int(round((nfft - overlap) * sr))
    if nfft_samples <= 0:
        raise ValueError(f"nfft={nfft}s at sr={sr} gives a non-positive FFT size")
    if hop_length <= 0:
        raise ValueError(
            f"overlap ({overlap}s) must be smaller than nfft ({nfft}s); "
            f"hop length would be {hop_length} samples")
    low_hz, high_hz = freq_range
    if low_hz > high_hz:
        raise ValueError(f"freq_range must be (low, high), got {freq_range}")
    out_rows, out_cols = output_shape
    if out_rows <= 0 or out_cols <= 0:
        raise ValueError(f"output_shape must be positive, got {output_shape}")

    # Load audio
    audio, _ = librosa.load(audio_path, sr=sr)

    # Generate spectrogram (dB relative to the loudest bin)
    spec = librosa.stft(audio, n_fft=nfft_samples, hop_length=hop_length)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)

    # Filter to USV frequency range
    freqs = librosa.fft_frequencies(sr=sr, n_fft=nfft_samples)
    freq_mask = (freqs >= low_hz) & (freqs <= high_hz)
    spec_db = spec_db[freq_mask, :]

    # Guard the zoom-factor divisions below.
    if spec_db.shape[0] == 0:
        raise ValueError(
            f"no FFT bins fall inside freq_range={freq_range} "
            f"(sr={sr}, n_fft={nfft_samples})")
    if spec_db.shape[1] == 0:
        raise ValueError(f"audio in {audio_path!r} produced no STFT frames")

    # Resize to output_shape; scipy rounds in_shape * factor, so the result
    # is exactly (out_rows, out_cols).
    zoom_factors = (out_rows / spec_db.shape[0], out_cols / spec_db.shape[1])
    spec_resized = zoom(spec_db, zoom_factors, order=1)

    # Normalize (epsilon keeps a constant patch from dividing by zero)
    spec_resized = (spec_resized - np.mean(spec_resized)) / (np.std(spec_resized) + 1e-8)

    # Convert to tensor with batch and channel dimensions
    return torch.FloatTensor(spec_resized).unsqueeze(0).unsqueeze(0)
def predict(audio_path, model):
    '''
    Classify one audio file as USV or noise.

    Args:
        audio_path: Path to .wav file; converted with ``preprocess_audio``
            using its default parameters.
        model: Loaded USVDetectorCNN model. It is called on the preprocessed
            tensor, and its output is treated as per-class scores with
            class 0 = noise and class 1 = USV (presumably shape (1, 2) --
            confirm against the model definition).

    Returns:
        dict: ``is_usv`` (True when class 1 has the highest score),
        ``confidence`` (softmax probability of the winning class),
        ``usv_probability`` and ``noise_probability`` (softmax
        probabilities of class 1 and class 0).
    '''
    batch = preprocess_audio(audio_path)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = model(batch)
        probs = torch.softmax(scores, dim=1)[0]
        winner = torch.argmax(scores, dim=1).item()

    noise_prob = probs[0].item()
    usv_prob = probs[1].item()
    return {
        'is_usv': winner == 1,
        'confidence': probs[winner].item(),
        'usv_probability': usv_prob,
        'noise_probability': noise_prob,
    }
# Example usage: score a single file with the trained checkpoint.
if __name__ == "__main__":
    # The checkpoint is loaded first, then the file is classified.
    result = predict('test_audio.wav', load_model('final_usv_model.pth'))

    # Report the decision followed by the class probabilities.
    print(f"USV Detected: {result['is_usv']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"USV Probability: {result['usv_probability']:.2%}")
    print(f"Noise Probability: {result['noise_probability']:.2%}")