import logging

import joblib
import numpy as np
import tensorflow_hub as hub
import torchaudio
import torchaudio.transforms as T
from scipy.stats import skew, kurtosis
|
|
|
|
|
# Models are loaded once, at import time, and shared by every call below.
# joblib.load unpickles its input, so these files must come from a trusted source.
# NOTE(review): the paths are relative to the current working directory, not to
# this module's location — confirm callers always run from the project root.
clf, label_encoder = (
    joblib.load("models/noise_classifier.pkl"),  # classifier over pooled embeddings
    joblib.load("models/label_encoder.pkl"),  # maps class codes back to label names
)

# YAMNet model from TensorFlow Hub; get_yamnet_embedding calls it on raw waveforms.
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
|
|
|
def get_yamnet_embedding(audio_path):
    """Extract a fixed-length, statistically pooled YAMNet embedding from an audio file.

    The audio is resampled to 16 kHz if needed and downmixed to mono by averaging
    channels. YAMNet returns one embedding row per frame; the rows are pooled over
    time into mean, std, min, max, skewness and kurtosis, concatenated in that order.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A 1-D NumPy array of length ``6 * D`` (``D`` = YAMNet embedding width), or
        ``None`` if loading, inference or pooling fails. Failures are logged with
        their traceback rather than raised.
    """
    try:
        waveform, sr = torchaudio.load(audio_path)

        if sr != 16000:
            waveform = T.Resample(orig_freq=sr, new_freq=16000)(waveform)

        # torchaudio returns (channels, frames); collapse to a 1-D mono signal.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0)
        else:
            waveform = waveform.squeeze(0)

        # YAMNet takes a 1-D float32 waveform; make the dtype explicit.
        waveform_np = waveform.numpy().astype(np.float32, copy=False)
        _, embeddings, _ = yamnet_model(waveform_np)

        # Convert the model's tensor output once, instead of letting each of the six
        # reductions below convert it again.
        embeddings = np.asarray(embeddings)
        if embeddings.shape[0] == 0:
            raise ValueError("YAMNet returned no embedding frames")

        # NOTE(review): scipy's skew/kurtosis yield NaN for any dimension that is
        # constant across frames (always the case for a single-frame clip). Confirm
        # the classifier tolerates NaN or that training handled it the same way.
        return np.concatenate(
            [
                np.mean(embeddings, axis=0),
                np.std(embeddings, axis=0),
                np.min(embeddings, axis=0),
                np.max(embeddings, axis=0),
                skew(embeddings, axis=0),
                kurtosis(embeddings, axis=0),
            ]
        )

    except Exception:
        # Best-effort boundary: any failure maps to None, but keep the traceback.
        logging.getLogger(__name__).exception("Failed to process %s", audio_path)
        return None
|
|
|
def classify_noise(audio_path, threshold=0.6, top_k=5):
    """Classify the noise in an audio file, rejecting low-confidence predictions.

    Args:
        audio_path: Path to the audio file; features come from ``get_yamnet_embedding``.
        threshold: Minimum probability the best class must reach. Below it the
            prediction is rejected and reported as ``"Unknown"``.
        top_k: Maximum number of ``(label, probability)`` pairs returned for an
            accepted prediction (default 5, the previously hard-coded value).

    Returns:
        A list of ``(label, probability)`` tuples in descending probability order.
        Returns ``[("Unknown", 0.0)]`` if feature extraction failed, and
        ``[("Unknown", top_prob)]`` if the best probability is below ``threshold``.
    """
    feature = get_yamnet_embedding(audio_path)
    if feature is None:
        return [("Unknown", 0.0)]

    # The classifier expects a 2-D (n_samples, n_features) input; score one sample.
    probs = clf.predict_proba(feature.reshape(1, -1))[0]

    # Column indices ordered by descending probability.
    ranked = np.argsort(probs)[::-1]
    top_prob = float(probs[ranked[0]])

    # Reject the prediction outright when even the best class is not confident enough.
    if top_prob < threshold:
        return [("Unknown", top_prob)]

    top_indices = ranked[:top_k]
    # predict_proba columns are ordered like clf.classes_ (the encoded labels the
    # classifier was fit on), not necessarily 0..n-1, so map column -> encoded
    # class -> original label name in one vectorized call.
    labels = label_encoder.inverse_transform(clf.classes_[top_indices])
    return [(label, float(probs[i])) for label, i in zip(labels, top_indices)]
|
|
|