# AI-Belha-Classifier / inference.py
# Uploaded by NOS-Innovation via huggingface_hub (commit bcb3d72, verified)
"""
Audio Classification using YAMNet and Custom Models
A streamlined tool for classifying audio using pre-trained and custom models.
"""
import os
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
import librosa
import resampy
import soundfile as sf
import tensorflow as tf
from tensorflow.keras.models import load_model
# Module-wide logging: INFO by default; main() later adjusts this module's
# logger level based on the --verbose/--debug CLI flags.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Suppress TensorFlow warnings
tf.get_logger().setLevel(logging.ERROR)
# NOTE(review): this env var is read by TF's C++ runtime at import time, so
# setting it after `import tensorflow` has limited effect — consider moving
# it above the import to fully silence the native logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@dataclass(frozen=True)
class YAMNetParams:
    """Hyperparameters of the YAMNet audio front end and classifier head.

    The defaults mirror the reference YAMNet release. Instances are frozen
    so one parameter set can be shared safely between components.
    """
    # Waveform / log-mel spectrogram front end
    sample_rate: float = 16000.0
    stft_window_seconds: float = 0.025
    stft_hop_seconds: float = 0.010
    mel_bands: int = 64
    mel_min_hz: float = 125.0
    mel_max_hz: float = 7500.0
    log_offset: float = 0.001
    # Patch extraction over the spectrogram
    patch_window_seconds: float = 0.96
    patch_hop_seconds: float = 0.48
    # Network construction / classifier head
    num_classes: int = 521
    conv_padding: str = 'same'
    batchnorm_center: bool = True
    batchnorm_scale: bool = False
    batchnorm_epsilon: float = 1e-4
    classifier_activation: str = 'sigmoid'
    tflite_compatible: bool = True

    @property
    def patch_frames(self) -> int:
        """Number of STFT frames covered by one classifier patch."""
        frames_per_patch = self.patch_window_seconds / self.stft_hop_seconds
        return int(round(frames_per_patch))

    @property
    def patch_bands(self) -> int:
        """Number of mel bands per patch (alias of ``mel_bands``)."""
        return self.mel_bands
@dataclass
class Config:
    """Configuration for models and processing parameters."""
    yamnet_model_path: str
    yamnet_classes_path: str
    model_path: Optional[str] = None
    custom_classes_path: Optional[str] = None
    output_dir: str = "results"
    output_file: str = "classification.txt"
    # Processing parameters
    window_length: int = 10  # seconds per analysis window
    hop_length: int = 1  # seconds between window starts
    custom_weight_factor: float = 5.0
    top_k: int = 10  # Number of top predictions to keep
    # Labels dropped from the combined predictions
    excluded_classes: List[str] = field(default_factory=lambda: ["Vehicle"])

    def __post_init__(self):
        """Normalize all configured paths and ensure the output dir exists."""
        self.yamnet_model_path = os.path.abspath(self.yamnet_model_path)
        self.yamnet_classes_path = os.path.abspath(self.yamnet_classes_path)
        # Optional paths are only normalized when actually provided.
        for attr in ('model_path', 'custom_classes_path'):
            value = getattr(self, attr)
            if value:
                setattr(self, attr, os.path.abspath(value))
        os.makedirs(Path(self.output_dir), exist_ok=True)

    @property
    def output_path(self) -> str:
        """Full path of the results file (output_dir joined with output_file)."""
        return os.path.join(self.output_dir, self.output_file)

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> 'Config':
        """Build a Config from parsed command-line arguments.

        Optional model files that do not exist on disk are silently dropped
        (the classifier then runs in YAMNet-only mode).
        """
        out_dir, out_name = os.path.split(args.output)
        return cls(
            yamnet_model_path=args.yamnet_model,
            yamnet_classes_path=args.yamnet_classes,
            model_path=args.model if os.path.exists(args.model) else None,
            custom_classes_path=args.custom_classes if os.path.exists(args.custom_classes) else None,
            output_dir=out_dir or "results",
            output_file=out_name or "classification.txt",
            window_length=args.window,
            hop_length=args.hop,
            custom_weight_factor=args.weight,
        )
class AudioClassifier:
    """Audio classification using YAMNet and custom models.

    The input file is scored over sliding windows: each window gets YAMNet
    predictions and, when a custom model is configured, weighted custom
    predictions computed from YAMNet's embeddings. Per-window results are
    combined and aggregated into one ranked prediction for the whole file.
    """

    def __init__(self, config: Config):
        """Initialize classifier with configuration.

        Args:
            config: Model paths and processing parameters.

        Raises:
            Exception: Propagated if the required YAMNet model fails to load.
        """
        self.config = config
        self.params = YAMNetParams()
        # Model handles and label lists; populated by _load_models().
        self.yamnet_model = None
        self.model = None  # optional custom classifier head
        self.yamnet_classes = []
        self.custom_classes = []
        self._load_models()

    def _load_models(self) -> None:
        """Load YAMNet (required) and the custom model (optional).

        A YAMNet load failure is fatal and re-raised; a custom-model failure
        only logs a warning and leaves the classifier in YAMNet-only mode.
        """
        try:
            # Project-local module (not a PyPI package) providing the Keras
            # graph builder and the class-map CSV parser.
            from yamnet import yamnet_frames_model, class_names
            logger.info(f"Loading YAMNet model from {self.config.yamnet_model_path}")
            self.yamnet_model = yamnet_frames_model(self.params)
            self.yamnet_model.load_weights(self.config.yamnet_model_path)
            logger.info(f"Loading YAMNet classes from {self.config.yamnet_classes_path}")
            self.yamnet_classes = class_names(self.config.yamnet_classes_path)
        except ImportError:
            logger.error("YAMNet module not found. Please install it or provide correct path.")
            raise
        except Exception as e:
            logger.error(f"Failed to load YAMNet model: {e}")
            raise
        # Load custom model if available
        if self.config.model_path:
            try:
                logger.info(f"Loading custom model from {self.config.model_path}")
                self.model = load_model(self.config.model_path)
                if self.config.custom_classes_path:
                    logger.info(f"Loading custom classes from {self.config.custom_classes_path}")
                    # allow_pickle: the class names are stored as an object array.
                    self.custom_classes = np.load(self.config.custom_classes_path, allow_pickle=True)
            except Exception as e:
                logger.warning(f"Failed to load custom model: {e}")
                logger.warning("Continuing with YAMNet model only.")
                self.model = None
                self.custom_classes = []

    def classify_file(self, audio_path: str) -> Dict[str, Any]:
        """Classify an audio file and return aggregated results.

        Args:
            audio_path: Path of the audio file to classify.

        Returns:
            Dict with 'aggregated_predictions' (label -> normalized score),
            'dominant_label', 'dominant_score' and 'dominant_score_percentage'.

        Raises:
            FileNotFoundError: If audio_path does not exist.
        """
        logger.info(f"Processing audio file: {audio_path}")
        waveform, sr = self._load_audio(audio_path)
        logger.info("Processing audio segments...")
        segments_results = self._process_audio_segments(waveform, sr)
        logger.info("Aggregating results...")
        final_results = self._aggregate_results(segments_results)
        # output_path is always a non-empty join; kept for safety.
        if self.config.output_path:
            self._save_results(final_results)
        return final_results

    def _load_audio(self, file_path: str) -> Tuple[np.ndarray, int]:
        """Load audio as mono float32 at the YAMNet sample rate.

        Returns:
            (waveform, sample_rate) with waveform scaled to [-1.0, 1.0].

        Raises:
            FileNotFoundError: If file_path does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        logger.info(f"Loading audio from {file_path}")
        wav_data, sr = sf.read(file_path, dtype=np.int16)
        # Scale int16 samples to float32 in [-1.0, 1.0].
        waveform = wav_data / 32768.0
        waveform = waveform.astype('float32')
        # Downmix multi-channel audio by averaging the channels.
        if len(waveform.shape) > 1:
            logger.info("Converting stereo audio to mono")
            waveform = np.mean(waveform, axis=1)
        # YAMNet expects 16 kHz input; resample anything else.
        if sr != self.params.sample_rate:
            logger.info(f"Resampling audio from {sr}Hz to {self.params.sample_rate}Hz")
            waveform = resampy.resample(waveform, sr, self.params.sample_rate)
            sr = int(self.params.sample_rate)
        return waveform, sr

    def _process_audio_segments(self, waveform: np.ndarray, sr: int) -> List[Dict[str, Any]]:
        """Run the models over sliding windows of the waveform.

        Args:
            waveform: Mono audio samples.
            sr: Sample rate of waveform in Hz.

        Returns:
            One dict per window with 'yamnet_predictions',
            'custom_predictions' (None without a custom model) and
            'combined_predictions'.

        Raises:
            ValueError: If the configured window or hop length is not positive.
        """
        segment_length_samples = int(sr * self.config.window_length)
        hop_length_samples = int(sr * self.config.hop_length)
        if segment_length_samples <= 0:
            raise ValueError(f"Invalid segment length: {self.config.window_length} seconds")
        # Explicit check: a zero hop would otherwise surface as an opaque
        # "range() arg 3 must not be zero" error below.
        if hop_length_samples <= 0:
            raise ValueError(f"Invalid hop length: {self.config.hop_length} seconds")
        # Bug fix: audio shorter than one window used to yield ZERO segments
        # (empty range), so short files silently classified as "Unknown".
        # Fall back to a single segment covering the whole waveform.
        start_offsets = list(range(0, len(waveform) - segment_length_samples + 1,
                                   hop_length_samples)) or [0]
        total_segments = len(start_offsets)
        segments_results = []
        for segment_idx, start in enumerate(start_offsets, start=1):
            logger.debug(f"Processing segment {segment_idx}/{total_segments}")
            end_idx = min(start + segment_length_samples, len(waveform))
            window = waveform[start:end_idx]
            # Get YAMNet predictions
            yamnet_predictions = self._get_yamnet_predictions(window)
            # Get custom model predictions if available
            custom_predictions = None
            if self.model is not None:
                custom_predictions = self._get_custom_predictions(window)
            # Combine and store per-window results
            combined_results = self._combine_predictions(yamnet_predictions, custom_predictions)
            segments_results.append({
                'yamnet_predictions': yamnet_predictions,
                'custom_predictions': custom_predictions,
                'combined_predictions': combined_results,
            })
        return segments_results

    def _get_yamnet_predictions(self, audio_segment: np.ndarray) -> Dict[str, float]:
        """Return the top-k YAMNet labels and scores for one audio segment.

        Scores are the per-class sigmoid outputs averaged over all patches
        in the segment. Returns an empty dict on failure.
        """
        try:
            scores, embeddings, spectrogram = self.yamnet_model(audio_segment)
            prediction = np.mean(scores, axis=0)
            # Keep only the top-k highest-scoring classes.
            top_indices = np.argsort(prediction)[::-1][:self.config.top_k]
            top_labels = [self.yamnet_classes[i] for i in top_indices]
            top_scores = prediction[top_indices]
            return {label: float(score) for label, score in zip(top_labels, top_scores)}
        except Exception as e:
            logger.error(f"Error in YAMNet prediction: {e}")
            return {}

    def _get_custom_predictions(self, audio_segment: np.ndarray) -> Dict[str, float]:
        """Return the custom model's top-k labels and normalized scores.

        The segment is run through YAMNet to obtain embeddings, which are
        flattened per frame and fed to the custom model. Returns an empty
        dict on failure.
        """
        try:
            # NOTE(review): this re-runs the YAMNet forward pass already done
            # in _get_yamnet_predictions for the same window; embeddings
            # could be shared to halve the cost.
            embeddings = self.yamnet_model(audio_segment)[1]
            # Flatten each frame's embedding for the dense custom model.
            embeddings_reshaped = np.reshape(embeddings, (embeddings.shape[0], -1))
            predictions = self.model.predict(embeddings_reshaped, verbose=0)
            # Average predictions over time frames.
            mean_predictions = np.mean(predictions, axis=0)
            top_indices = np.argsort(mean_predictions)[::-1][:self.config.top_k]
            # Fall back to numeric labels when no class-name file was loaded.
            if len(self.custom_classes) > 0:
                top_labels = [self.custom_classes[i] for i in top_indices]
            else:
                top_labels = [f"Class_{i}" for i in top_indices]
            top_scores = mean_predictions[top_indices]
            # Normalize the kept scores so they sum to 1.
            total_score = np.sum(top_scores)
            if total_score > 0:
                top_scores = top_scores / total_score
            return {label: float(score) for label, score in zip(top_labels, top_scores)}
        except Exception as e:
            logger.error(f"Error in custom model prediction: {e}")
            return {}

    def _combine_predictions(
        self,
        yamnet_predictions: Dict[str, float],
        custom_predictions: Optional[Dict[str, float]]
    ) -> Dict[str, float]:
        """Merge YAMNet and custom predictions into one label->score dict.

        Custom scores are boosted by config.custom_weight_factor; YAMNet
        scores fill in remaining labels (excluded classes are dropped) and
        override a custom score only when strictly higher.
        """
        combined = {}
        # Custom predictions are trusted more, hence the weighting factor.
        if custom_predictions:
            for label, score in custom_predictions.items():
                combined[label] = score * self.config.custom_weight_factor
        # Add YAMNet predictions if not already present or if higher score.
        for label, score in yamnet_predictions.items():
            if label not in self.config.excluded_classes:
                if label not in combined or score > combined[label]:
                    combined[label] = score
        return combined

    def _aggregate_results(self, segments_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate per-segment predictions into a single file-level result.

        Takes the maximum score per label across segments, keeps the top-k
        labels, normalizes their scores to sum to 1, and identifies the
        dominant label. With no predictions at all, returns "Unknown".
        """
        aggregated_predictions = {}
        # Max-pool each label's score over all segments.
        for segment in segments_results:
            for label, score in segment['combined_predictions'].items():
                if label in aggregated_predictions:
                    aggregated_predictions[label] = max(aggregated_predictions[label], score)
                else:
                    aggregated_predictions[label] = score
        if aggregated_predictions:
            # Keep only the top-k labels overall.
            sorted_predictions = sorted(
                aggregated_predictions.items(),
                key=lambda x: x[1],
                reverse=True
            )[:self.config.top_k]
            top_predictions = {label: score for label, score in sorted_predictions}
            # Normalize scores so they form a distribution.
            total_score = sum(top_predictions.values())
            if total_score > 0:
                normalized_predictions = {
                    label: score / total_score
                    for label, score in top_predictions.items()
                }
            else:
                # Default to equal probabilities if all scores are 0.
                normalized_predictions = {
                    label: 1.0 / len(top_predictions) if len(top_predictions) > 0 else 0.0
                    for label in top_predictions
                }
            dominant_label, dominant_score = max(normalized_predictions.items(), key=lambda x: x[1])
            dominant_score_percentage = round(dominant_score * 100)
            aggregated_predictions = normalized_predictions
        else:
            dominant_label = "Unknown"
            dominant_score = 0
            dominant_score_percentage = 0
        return {
            'aggregated_predictions': aggregated_predictions,
            'dominant_label': dominant_label,
            'dominant_score': dominant_score,
            'dominant_score_percentage': dominant_score_percentage
        }

    def _save_results(self, results: Dict[str, Any]) -> None:
        """Write a human-readable results report to config.output_path.

        Raises:
            Exception: Re-raised after logging if the file cannot be written.
        """
        try:
            with open(self.config.output_path, 'w') as file:
                file.write("Audio Classification Results\n")
                file.write("=========================\n\n")
                file.write(f"Primary Classification: {results['dominant_label']} ({results['dominant_score_percentage']}%)\n\n")
                # Detailed breakdown, highest score first.
                file.write("Classification Details:\n")
                file.write("-----------------\n")
                sorted_predictions = sorted(
                    results['aggregated_predictions'].items(),
                    key=lambda x: x[1],
                    reverse=True
                )
                for label, score in sorted_predictions:
                    percentage = round(score * 100)
                    file.write(f"{label}: {percentage}%\n")
            logger.info(f"Results saved to {self.config.output_path}")
        except Exception as e:
            logger.error(f"Error saving results: {e}")
            raise
def main():
    """CLI entry point: parse arguments, classify, print a summary.

    Returns:
        0 on success, 1 on any error (details logged; traceback printed
        when --debug is set).
    """
    arg_parser = argparse.ArgumentParser(
        description='Audio classification using YAMNet and custom models',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required positional argument
    arg_parser.add_argument('audio_file', type=str, help='Path to audio file for classification')
    # Model paths (declared data-driven to avoid repetition)
    for flag, default, text in (
        ('--yamnet_model', 'yamnet/yamnet.h5', 'Path to YAMNet model weights'),
        ('--yamnet_classes', 'yamnet/yamnet_class_map.csv', 'Path to YAMNet class names'),
        ('--model', 'model/model.h5', 'Path to custom model (optional)'),
        ('--custom_classes', 'model/model.npy', 'Path to custom class names (optional)'),
    ):
        arg_parser.add_argument(flag, type=str, default=default, help=text)
    # Processing parameters
    arg_parser.add_argument('--window', type=int, default=10,
                            help='Window length in seconds')
    arg_parser.add_argument('--hop', type=int, default=1,
                            help='Hop length in seconds')
    arg_parser.add_argument('--weight', type=float, default=5.0,
                            help='Weighting factor for custom model predictions')
    # Output and logging options
    arg_parser.add_argument('--output', type=str, default='results/classification.txt',
                            help='Path to output file')
    arg_parser.add_argument('--verbose', action='store_true',
                            help='Enable verbose output')
    arg_parser.add_argument('--debug', action='store_true',
                            help='Enable debug logging')
    cli = arg_parser.parse_args()

    # Pick the logger level from the verbosity flags (debug wins).
    if cli.debug:
        logger.setLevel(logging.DEBUG)
    elif cli.verbose:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)

    try:
        run_config = Config.from_args(cli)
        outcome = AudioClassifier(run_config).classify_file(cli.audio_file)
        # Console summary
        print("\nAudio Classification Results")
        print("=========================")
        print(f"\nPrimary Classification: {outcome['dominant_label']} ({outcome['dominant_score_percentage']}%)")
        if cli.verbose:
            print("\nTop 10 Predictions:")
            print("-----------------")
            ranked = sorted(
                outcome['aggregated_predictions'].items(),
                key=lambda item: item[1],
                reverse=True,
            )
            for name, value in ranked:
                print(f"{name}: {round(value * 100)}%")
        print(f"\nFull results saved to: {run_config.output_path}")
    except Exception as e:
        logger.error(f"Error: {e}")
        if cli.debug:
            import traceback
            traceback.print_exc()
        return 1
    return 0
# Script entry point: propagate main()'s 0/1 status code to the shell.
if __name__ == '__main__':
    import sys
    sys.exit(main())