|
|
""" |
|
|
Audio Classification using YAMNet and Custom Models |
|
|
A streamlined tool for classifying audio using pre-trained and custom models. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional, Union, Any |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import librosa |
|
|
import resampy |
|
|
import soundfile as sf |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.models import load_model |
|
|
|
|
|
|
|
|
# Module-wide logging: defaults to INFO here; main() re-adjusts the level
# based on the --verbose/--debug CLI flags.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Quiet TensorFlow: raise its Python-side logger to ERROR, and suppress
# C++ backend INFO/WARNING messages (TF_CPP_MIN_LOG_LEVEL level '2').
tf.get_logger().setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class YAMNetParams:
    """Parameters for YAMNet model.

    Defaults match the YAMNet reference implementation (16 kHz input,
    64 mel bands, 0.96 s patches, 521 AudioSet classes). Frozen so a
    shared instance cannot be mutated after construction.
    """

    # Audio front-end: expected input rate (Hz) and STFT framing (seconds).
    sample_rate: float = 16000.0
    stft_window_seconds: float = 0.025
    stft_hop_seconds: float = 0.010

    # Log-mel spectrogram settings.
    mel_bands: int = 64
    mel_min_hz: float = 125.0
    mel_max_hz: float = 7500.0
    log_offset: float = 0.001  # stabilizer constant (per YAMNet reference params)

    # Patch extraction: window length and stride (seconds) fed to the classifier.
    patch_window_seconds: float = 0.96
    patch_hop_seconds: float = 0.48

    # Classifier architecture settings.
    num_classes: int = 521
    conv_padding: str = 'same'
    batchnorm_center: bool = True
    batchnorm_scale: bool = False
    batchnorm_epsilon: float = 1e-4
    classifier_activation: str = 'sigmoid'
    tflite_compatible: bool = True

    @property
    def patch_frames(self) -> int:
        """Calculate number of frames per patch (window length / hop length)."""
        return int(round(self.patch_window_seconds / self.stft_hop_seconds))

    @property
    def patch_bands(self) -> int:
        """Get number of mel bands (alias of ``mel_bands``)."""
        return self.mel_bands
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Configuration for models and processing parameters."""

    # Mandatory model/label paths (made absolute in __post_init__).
    yamnet_model_path: str
    yamnet_classes_path: str
    # Optional custom model and its label file.
    model_path: Optional[str] = None
    custom_classes_path: Optional[str] = None
    # Where results are written.
    output_dir: str = "results"
    output_file: str = "classification.txt"

    # Segmentation and scoring parameters.
    window_length: int = 10
    hop_length: int = 1
    custom_weight_factor: float = 5.0
    top_k: int = 10

    # YAMNet labels that are ignored when combining predictions.
    excluded_classes: List[str] = field(default_factory=lambda: ["Vehicle"])

    def __post_init__(self):
        """Convert paths to absolute paths and ensure output directory exists."""
        self.yamnet_model_path = os.path.abspath(self.yamnet_model_path)
        self.yamnet_classes_path = os.path.abspath(self.yamnet_classes_path)

        # Optional paths are only normalized when actually provided.
        for attr in ('model_path', 'custom_classes_path'):
            value = getattr(self, attr)
            if value:
                setattr(self, attr, os.path.abspath(value))

        os.makedirs(Path(self.output_dir), exist_ok=True)

    @property
    def output_path(self) -> str:
        """Get full path to output file."""
        return os.path.join(self.output_dir, self.output_file)

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> 'Config':
        """Create config from command line arguments."""
        # Split --output into directory and file name, falling back to defaults.
        directory = os.path.dirname(args.output) or "results"
        filename = os.path.basename(args.output) or "classification.txt"

        # Optional model files are silently dropped when missing on disk.
        model = args.model if os.path.exists(args.model) else None
        custom = args.custom_classes if os.path.exists(args.custom_classes) else None

        return cls(
            yamnet_model_path=args.yamnet_model,
            yamnet_classes_path=args.yamnet_classes,
            model_path=model,
            custom_classes_path=custom,
            output_dir=directory,
            output_file=filename,
            window_length=args.window,
            hop_length=args.hop,
            custom_weight_factor=args.weight
        )
|
|
|
|
|
|
|
|
class AudioClassifier:
    """Audio classification using YAMNet and custom models.

    A single YAMNet forward pass per analysis window supplies both the
    per-frame class scores (generic predictions) and the embeddings that
    feed the optional custom classifier.
    """

    def __init__(self, config: "Config"):
        """Initialize classifier with configuration and load all models.

        Args:
            config: Paths, window/hop sizes and weighting parameters.

        Raises:
            Exception: Propagated when the mandatory YAMNet model cannot be
                loaded; custom-model failures are only logged.
        """
        self.config = config
        self.params = YAMNetParams()

        # Models and label lists are populated by _load_models().
        self.yamnet_model = None
        self.model = None
        self.yamnet_classes = []
        self.custom_classes = []

        self._load_models()

    def _load_models(self) -> None:
        """Load YAMNet and custom models.

        YAMNet is mandatory: any failure is re-raised. The custom model is
        optional: on failure a warning is logged and classification falls
        back to YAMNet only.
        """
        try:
            from yamnet import yamnet_frames_model, class_names

            logger.info(f"Loading YAMNet model from {self.config.yamnet_model_path}")
            self.yamnet_model = yamnet_frames_model(self.params)
            self.yamnet_model.load_weights(self.config.yamnet_model_path)

            logger.info(f"Loading YAMNet classes from {self.config.yamnet_classes_path}")
            self.yamnet_classes = class_names(self.config.yamnet_classes_path)

        except ImportError:
            logger.error("YAMNet module not found. Please install it or provide correct path.")
            raise
        except Exception as e:
            logger.error(f"Failed to load YAMNet model: {e}")
            raise

        if self.config.model_path:
            try:
                logger.info(f"Loading custom model from {self.config.model_path}")
                self.model = load_model(self.config.model_path)

                if self.config.custom_classes_path:
                    logger.info(f"Loading custom classes from {self.config.custom_classes_path}")
                    # NOTE(review): allow_pickle=True is only safe for trusted label files.
                    self.custom_classes = np.load(self.config.custom_classes_path, allow_pickle=True)

            except Exception as e:
                logger.warning(f"Failed to load custom model: {e}")
                logger.warning("Continuing with YAMNet model only.")
                self.model = None
                self.custom_classes = []

    def classify_file(self, audio_path: str) -> Dict[str, Any]:
        """Classify audio file and return results.

        Args:
            audio_path: Path to the audio file to classify.

        Returns:
            Dict with keys 'aggregated_predictions', 'dominant_label',
            'dominant_score' and 'dominant_score_percentage'.

        Raises:
            FileNotFoundError: If audio_path does not exist.
        """
        logger.info(f"Processing audio file: {audio_path}")

        waveform, sr = self._load_audio(audio_path)

        logger.info("Processing audio segments...")
        segments_results = self._process_audio_segments(waveform, sr)

        logger.info("Aggregating results...")
        final_results = self._aggregate_results(segments_results)

        # output_path is always a non-empty join of dir + file name.
        if self.config.output_path:
            self._save_results(final_results)

        return final_results

    def _load_audio(self, file_path: str) -> Tuple[np.ndarray, int]:
        """Load and preprocess audio file.

        Returns:
            Mono float32 waveform scaled to [-1, 1) and resampled to the
            YAMNet sample rate, plus that rate (Hz).

        Raises:
            FileNotFoundError: If file_path does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        logger.info(f"Loading audio from {file_path}")
        wav_data, sr = sf.read(file_path, dtype=np.int16)

        # Scale int16 PCM samples to [-1.0, 1.0).
        waveform = wav_data / 32768.0
        waveform = waveform.astype('float32')

        # Average channels to mono when the file is multi-channel.
        if len(waveform.shape) > 1:
            logger.info("Converting stereo audio to mono")
            waveform = np.mean(waveform, axis=1)

        if sr != self.params.sample_rate:
            logger.info(f"Resampling audio from {sr}Hz to {self.params.sample_rate}Hz")
            waveform = resampy.resample(waveform, sr, self.params.sample_rate)
            sr = int(self.params.sample_rate)

        return waveform, sr

    def _process_audio_segments(self, waveform: np.ndarray, sr: int) -> List[Dict[str, Any]]:
        """Process audio in segments.

        Slides a window of config.window_length seconds over the waveform in
        steps of config.hop_length seconds. Each window is classified with a
        single YAMNet pass (plus the custom model when available).
        """
        segment_length_samples = int(sr * self.config.window_length)
        hop_length_samples = int(sr * self.config.hop_length)

        if segment_length_samples <= 0:
            raise ValueError(f"Invalid segment length: {self.config.window_length} seconds")
        # Fix: a non-positive hop would previously make range() step invalid.
        if hop_length_samples <= 0:
            raise ValueError(f"Invalid hop length: {self.config.hop_length} seconds")

        segments_results = []

        if len(waveform) == 0:
            return segments_results

        # Fix: clips shorter than one window used to yield zero segments and
        # therefore always an "Unknown" classification. Analyse the whole
        # clip as a single (shorter) segment instead.
        if len(waveform) < segment_length_samples:
            segment_length_samples = len(waveform)

        total_segments = max(1, (len(waveform) - segment_length_samples + hop_length_samples) // hop_length_samples)

        for i in range(0, len(waveform) - segment_length_samples + 1, hop_length_samples):
            segment_idx = i // hop_length_samples + 1
            logger.debug(f"Processing segment {segment_idx}/{total_segments}")

            end_idx = min(i + segment_length_samples, len(waveform))
            window = waveform[i:end_idx]

            # Perf fix: run YAMNet once per window. Scores drive the generic
            # predictions and embeddings feed the custom model; previously
            # YAMNet was invoked twice per window, doubling inference cost.
            try:
                scores, embeddings, _ = self.yamnet_model(window)
            except Exception as e:
                logger.error(f"Error in YAMNet prediction: {e}")
                scores, embeddings = None, None

            yamnet_predictions = self._get_yamnet_predictions(scores)

            custom_predictions = None
            if self.model is not None and embeddings is not None:
                custom_predictions = self._get_custom_predictions(embeddings)

            combined_results = self._combine_predictions(yamnet_predictions, custom_predictions)

            segments_results.append({
                'yamnet_predictions': yamnet_predictions,
                'custom_predictions': custom_predictions,
                'combined_predictions': combined_results
            })

        return segments_results

    def _get_yamnet_predictions(self, scores) -> Dict[str, float]:
        """Convert YAMNet per-frame scores into a top-k {label: score} dict.

        Args:
            scores: Frame-by-class score matrix produced by YAMNet, or None
                when inference failed (yields an empty dict).
        """
        if scores is None:
            return {}
        try:
            # Average frame scores over the whole window.
            prediction = np.mean(scores, axis=0)

            top_indices = np.argsort(prediction)[::-1][:self.config.top_k]
            top_labels = [self.yamnet_classes[i] for i in top_indices]
            top_scores = prediction[top_indices]

            return {label: float(score) for label, score in zip(top_labels, top_scores)}

        except Exception as e:
            logger.error(f"Error in YAMNet prediction: {e}")
            return {}

    def _get_custom_predictions(self, embeddings) -> Dict[str, float]:
        """Run the custom model on YAMNet embeddings and return top-k scores.

        Top-k scores are renormalized to sum to 1; the weighting factor is
        applied later by _combine_predictions. Returns {} on failure.
        """
        try:
            # Flatten per-frame embeddings to (frames, features) for the model.
            embeddings_reshaped = np.reshape(embeddings, (embeddings.shape[0], -1))

            predictions = self.model.predict(embeddings_reshaped, verbose=0)

            # Average frame-level predictions over the window.
            mean_predictions = np.mean(predictions, axis=0)

            top_indices = np.argsort(mean_predictions)[::-1][:self.config.top_k]

            if len(self.custom_classes) > 0:
                top_labels = [self.custom_classes[i] for i in top_indices]
            else:
                # No label file supplied: fall back to index-based names.
                top_labels = [f"Class_{i}" for i in top_indices]

            top_scores = mean_predictions[top_indices]

            # Renormalize so the retained scores sum to 1.
            total_score = np.sum(top_scores)
            if total_score > 0:
                top_scores = top_scores / total_score

            return {label: float(score) for label, score in zip(top_labels, top_scores)}

        except Exception as e:
            logger.error(f"Error in custom model prediction: {e}")
            return {}

    def _combine_predictions(
        self,
        yamnet_predictions: Dict[str, float],
        custom_predictions: Optional[Dict[str, float]]
    ) -> Dict[str, float]:
        """Combine predictions from different models.

        Custom-model scores are boosted by config.custom_weight_factor;
        YAMNet scores are merged in with a max() per label, skipping
        config.excluded_classes.
        """
        combined = {}

        if custom_predictions:
            for label, score in custom_predictions.items():
                combined[label] = score * self.config.custom_weight_factor

        for label, score in yamnet_predictions.items():
            if label not in self.config.excluded_classes:
                if label not in combined or score > combined[label]:
                    combined[label] = score

        return combined

    def _aggregate_results(self, segments_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate results across all segments.

        Takes the per-label maximum over segments, keeps the top-k labels,
        normalizes them to sum to 1, and reports the dominant label.
        """
        aggregated_predictions = {}

        # Per-label maximum over all segment windows.
        for segment in segments_results:
            for label, score in segment['combined_predictions'].items():
                if label in aggregated_predictions:
                    aggregated_predictions[label] = max(aggregated_predictions[label], score)
                else:
                    aggregated_predictions[label] = score

        if aggregated_predictions:
            sorted_predictions = sorted(
                aggregated_predictions.items(),
                key=lambda x: x[1],
                reverse=True
            )[:self.config.top_k]

            top_predictions = {label: score for label, score in sorted_predictions}

            # Normalize retained scores so they sum to 1.
            total_score = sum(top_predictions.values())
            if total_score > 0:
                normalized_predictions = {
                    label: score / total_score
                    for label, score in top_predictions.items()
                }
            else:
                # Degenerate case: all scores zero -> uniform distribution.
                normalized_predictions = {
                    label: 1.0 / len(top_predictions) if len(top_predictions) > 0 else 0.0
                    for label in top_predictions
                }

            dominant_label, dominant_score = max(normalized_predictions.items(), key=lambda x: x[1])
            dominant_score_percentage = round(dominant_score * 100)

            aggregated_predictions = normalized_predictions
        else:
            dominant_label = "Unknown"
            # Fix: use a float for type consistency with the non-empty branch.
            dominant_score = 0.0
            dominant_score_percentage = 0

        return {
            'aggregated_predictions': aggregated_predictions,
            'dominant_label': dominant_label,
            'dominant_score': dominant_score,
            'dominant_score_percentage': dominant_score_percentage
        }

    def _save_results(self, results: Dict[str, Any]) -> None:
        """Save classification results to the configured output file.

        Raises:
            Exception: Re-raised after logging when the file cannot be written.
        """
        try:
            with open(self.config.output_path, 'w') as file:
                file.write("Audio Classification Results\n")
                file.write("=========================\n\n")
                file.write(f"Primary Classification: {results['dominant_label']} ({results['dominant_score_percentage']}%)\n\n")

                file.write("Classification Details:\n")
                file.write("-----------------\n")
                sorted_predictions = sorted(
                    results['aggregated_predictions'].items(),
                    key=lambda x: x[1],
                    reverse=True
                )

                for label, score in sorted_predictions:
                    percentage = round(score * 100)
                    file.write(f"{label}: {percentage}%\n")

            logger.info(f"Results saved to {self.config.output_path}")

        except Exception as e:
            logger.error(f"Error saving results: {e}")
            raise
|
|
|
|
|
|
|
|
def main():
    """Main function to run audio classification."""
    parser = argparse.ArgumentParser(
        description='Audio classification using YAMNet and custom models',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required positional argument.
    parser.add_argument('audio_file', type=str, help='Path to audio file for classification')

    # Value-taking options, registered uniformly from a spec table.
    value_options = [
        ('--yamnet_model', str, 'yamnet/yamnet.h5', 'Path to YAMNet model weights'),
        ('--yamnet_classes', str, 'yamnet/yamnet_class_map.csv', 'Path to YAMNet class names'),
        ('--model', str, 'model/model.h5', 'Path to custom model (optional)'),
        ('--custom_classes', str, 'model/model.npy', 'Path to custom class names (optional)'),
        ('--window', int, 10, 'Window length in seconds'),
        ('--hop', int, 1, 'Hop length in seconds'),
        ('--weight', float, 5.0, 'Weighting factor for custom model predictions'),
        ('--output', str, 'results/classification.txt', 'Path to output file'),
    ]
    for flag, value_type, default, help_text in value_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    # Boolean switches.
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose output')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug logging')

    args = parser.parse_args()

    # Pick the log level from the most specific flag that was set.
    if args.debug:
        log_level = logging.DEBUG
    elif args.verbose:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    logger.setLevel(log_level)

    try:
        config = Config.from_args(args)
        classifier = AudioClassifier(config)
        results = classifier.classify_file(args.audio_file)

        print("\nAudio Classification Results")
        print("=========================")
        print(f"\nPrimary Classification: {results['dominant_label']} ({results['dominant_score_percentage']}%)")

        if args.verbose:
            print("\nTop 10 Predictions:")
            print("-----------------")
            ranked = sorted(
                results['aggregated_predictions'].items(),
                key=lambda item: item[1],
                reverse=True
            )
            for label, score in ranked:
                percentage = round(score * 100)
                print(f"{label}: {percentage}%")

        print(f"\nFull results saved to: {config.output_path}")

    except Exception as e:
        logger.error(f"Error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        return 1

    return 0
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # sys is only needed here, to propagate main()'s return code (0 on
    # success, 1 on error) as the process exit status.
    import sys

    sys.exit(main())