|
|
""" |
|
|
Audio Classification System |
|
|
|
|
|
This module trains a neural network model on audio data using YAMNet embeddings. |
|
|
It extracts features from audio files and trains a classifier to recognize audio classes. |
|
|
|
|
|
Usage: |
|
|
python main.py --data_path <path_to_data> --model_name <model_name> |
|
|
|
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Tuple, List, Dict, Optional, Any |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import tensorflow as tf |
|
|
import librosa |
|
|
from tqdm import tqdm |
|
|
from sklearn.preprocessing import LabelBinarizer |
|
|
from sklearn.utils import shuffle |
|
|
|
|
|
|
|
|
# Configure root logging once at import time: INFO level, timestamped
# record format, emitted via a stderr StreamHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Default settings; Configuration overlays a user-supplied dict on top of these.
DEFAULT_CONFIG = {
    'yamnet_path': 'yamnet/yamnet.h5',              # pretrained YAMNet weights file
    'classes_path': 'yamnet/yamnet_class_map.csv',  # class-map CSV (display_name/index/mid)
    'sample_rate': 16000,        # Hz; audio is resampled to this rate before YAMNet
    'epochs': 100,               # maximum training epochs
    'batch_size': 32,            # training batch size
    'learning_rate': 0.001,      # initial Adam learning rate
    'num_hidden': 1024,          # units in the first dense layer
    'hidden_layer_size': 512,    # base size of the extra hidden layers
    'num_extra_layers': 1,       # additional Dense/BN/Dropout blocks after the first
    'dropout_rate': 0.3,         # dropout probability after each dense block
    'regularization': 0.01,      # L2 kernel-regularization factor
    'patience': 10,              # early-stopping patience (epochs without improvement)
    'validation_split': 0.2,     # fraction of data held out for validation
    'model_folder': 'model'      # directory where trained models are written
}
|
|
|
|
|
|
|
|
class Configuration:
    """Key/value settings store for the audio classification system."""

    def __init__(self, custom_config: Optional[Dict[str, Any]] = None):
        """
        Build the effective configuration from the defaults, overlaid with
        any user-supplied overrides.

        Args:
            custom_config: Mapping of settings that override the defaults.
        """
        merged = DEFAULT_CONFIG.copy()
        if custom_config:
            merged.update(custom_config)
        self.config = merged

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for *key*, or *default* when the key is absent."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Assign *value* to *key* in the configuration."""
        self.config[key] = value

    def __getitem__(self, key: str) -> Any:
        """Dict-style access; raises KeyError for unknown keys."""
        return self.config[key]
|
|
|
|
|
|
|
|
class ClassMap:
    """Handles audio class mapping and persistence."""

    def __init__(self, config: Configuration):
        """
        Initialize class map.

        Args:
            config: Configuration handler supplying 'classes_path'.
        """
        self.config = config
        self.classes_path = config['classes_path']
        self._ensure_classes_file_exists()

    def _ensure_classes_file_exists(self) -> None:
        """Create an empty class-map CSV if one does not exist yet."""
        if not os.path.exists(self.classes_path):
            logger.info(f"Class map file not found: {self.classes_path}. Creating a new one.")
            pd.DataFrame({"display_name": [], "index": [], "mid": []}).to_csv(
                self.classes_path, index=False
            )

    def load_yamnet_classes(self) -> np.ndarray:
        """
        Load class display names from the class-map CSV.

        Returns:
            Array of display names; empty array when the file cannot be read.
        """
        try:
            df = pd.read_csv(self.classes_path)
            return df["display_name"].values
        except Exception as e:
            logger.error(f"Error loading classes: {str(e)}")
            return np.array([])

    def update_classes(self, data_path: str) -> List[str]:
        """
        Append any new class directories found under data_path to the
        class-map CSV ('index'/'mid' left blank for new rows).

        Args:
            data_path: Path to data directory (one subdirectory per class)

        Returns:
            Sorted list of all class directory names under data_path.

        Raises:
            Exception: Re-raised after logging on any CSV or filesystem error.
        """
        try:
            existing_classes_df = pd.read_csv(self.classes_path)
            existing_classes_set = set(existing_classes_df['display_name'])

            # Scan the data directory once (the original listed it twice:
            # once for new-class detection and again for the return value).
            all_classes = [
                cls for cls in sorted(os.listdir(data_path))
                if os.path.isdir(os.path.join(data_path, cls))
            ]
            new_classes = [cls for cls in all_classes if cls not in existing_classes_set]

            if new_classes:
                logger.info(f"Adding {len(new_classes)} new classes: {', '.join(new_classes)}")
                new_classes_df = pd.DataFrame({
                    'display_name': new_classes,
                    'index': [''] * len(new_classes),
                    'mid': [''] * len(new_classes)
                })
                updated_classes_df = pd.concat([existing_classes_df, new_classes_df], ignore_index=True)
                updated_classes_df.to_csv(self.classes_path, index=False)

            return all_classes

        except Exception as e:
            logger.error(f"Error updating classes: {str(e)}")
            raise
|
|
|
|
|
|
|
|
class FeatureExtractor:
    """Extracts YAMNet embeddings from audio files."""

    def __init__(self, config: Configuration):
        """
        Initialize the extractor and load the YAMNet model.

        Args:
            config: Configuration handler supplying 'yamnet_path' and
                'sample_rate'.
        """
        self.config = config
        self.yamnet_model = self._load_yamnet_model()

    def _load_yamnet_model(self):
        """Build the YAMNet frames model and load its pretrained weights."""
        try:
            logger.info("Loading YAMNet model...")
            from yamnet import yamnet_frames_model
            from params import Params

            yamnet = yamnet_frames_model(Params())
            yamnet.load_weights(self.config['yamnet_path'])
            return yamnet
        except Exception as e:
            logger.error(f"Error loading YAMNet model: {str(e)}")
            raise

    def extract_features(self, audio_path: str) -> np.ndarray:
        """
        Extract YAMNet embeddings from a single audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            Numpy array of per-frame embeddings; empty array when the file
            is empty or cannot be processed.
        """
        try:
            # Decode and resample to the configured rate, downmixed to mono.
            waveform, _ = librosa.load(
                audio_path,
                sr=self.config['sample_rate'],
                mono=True
            )
            waveform = waveform.astype(np.float32)

            if len(waveform) == 0:
                logger.warning(f"Warning: Empty audio file: {audio_path}")
                return np.array([])

            # The model's second output is used as the embedding tensor
            # (presumably scores/embeddings/spectrogram — per yamnet module).
            _, embeddings, _ = self.yamnet_model(waveform)
            return embeddings.numpy()

        except Exception as e:
            logger.error(f"Error extracting features from {audio_path}: {str(e)}")
            return np.array([])
|
|
|
|
|
|
|
|
class DatasetLoader:
    """Creates a dataset from audio files."""

    def __init__(self, config: Configuration, feature_extractor: FeatureExtractor):
        """
        Initialize dataset creator.

        Args:
            config: Configuration handler
            feature_extractor: Feature extractor
        """
        self.config = config
        self.feature_extractor = feature_extractor

    def create_dataset(self, data_path: str, classes: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build (features, labels) arrays from class-organized audio folders.

        Args:
            data_path: Directory containing one subfolder of audio files per class
            classes: Class names to process

        Returns:
            samples: Numpy array of per-frame embeddings
            labels: Numpy array with the class name for each embedding row

        Raises:
            ValueError: If no audio file yields any embeddings.
        """
        samples: List[np.ndarray] = []
        labels: List[str] = []

        for cls in classes:
            class_path = os.path.join(data_path, cls)
            if not os.path.isdir(class_path):
                continue

            logger.info(f"Processing class: {cls}")

            for sound in tqdm(os.listdir(class_path), desc=f"Processing {cls}"):
                embeddings = self.feature_extractor.extract_features(
                    os.path.join(class_path, sound)
                )
                if len(embeddings) == 0:
                    continue

                # One training sample per embedding frame, all sharing
                # the same class label.
                samples.extend(embeddings)
                labels.extend([cls] * len(embeddings))

        if not samples:
            error_msg = "No valid audio samples were processed!"
            logger.error(error_msg)
            raise ValueError(error_msg)

        sample_array = np.asarray(samples)
        label_array = np.asarray(labels)

        logger.info(f"Created dataset with {len(sample_array)} samples across {len(set(label_array))} classes")
        return sample_array, label_array
|
|
|
|
|
|
|
|
class ModelBuilder:
    """Builds and trains neural network models for audio classification."""

    def __init__(self, config: Configuration):
        """
        Initialize model builder.

        Args:
            config: Configuration handler with model/training hyperparameters.
        """
        self.config = config

    def build_model(self, num_classes: int) -> tf.keras.Model:
        """
        Build a dense classifier over 1024-dim embedding vectors.

        Architecture: Input(1024) -> Dense(num_hidden)/BN/Dropout, then
        `num_extra_layers` shrinking Dense/BN/Dropout blocks, then a
        softmax output layer.

        Args:
            num_classes: Number of output classes

        Returns:
            Uncompiled Keras Model.
        """
        # Input width matches the embedding dimension produced upstream.
        inputs = tf.keras.layers.Input(shape=(1024,))

        x = tf.keras.layers.Dense(
            self.config['num_hidden'],
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(self.config['regularization'])
        )(inputs)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Dropout(self.config['dropout_rate'])(x)

        # Extra blocks shrink: hidden_layer_size, then /2, /3, ...
        for i in range(self.config['num_extra_layers']):
            layer_size = self.config['hidden_layer_size'] // (i + 1)
            x = tf.keras.layers.Dense(
                layer_size,
                activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(self.config['regularization'])
            )(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(self.config['dropout_rate'])(x)

        outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

        return tf.keras.Model(inputs=inputs, outputs=outputs)

    def _create_callbacks(self, model_path: str) -> List[tf.keras.callbacks.Callback]:
        """
        Create training callbacks: early stopping, LR reduction, TensorBoard.

        Args:
            model_path: Path to save the model (its basename names the log dir)

        Returns:
            List of callbacks
        """
        log_dir = Path(f"logs/{os.path.basename(model_path)}")
        log_dir.mkdir(parents=True, exist_ok=True)

        tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=1
        )

        # Stop when validation accuracy stalls; roll back to the best weights.
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=self.config['patience'],
            restore_best_weights=True,
            verbose=1
        )

        # Halve the learning rate when validation loss plateaus.
        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=0.00001,
            verbose=1
        )

        return [early_stopping, reduce_lr, tensorboard]

    def train_model(self, X: np.ndarray, y: np.ndarray, model_name: str) -> Tuple[tf.keras.Model, LabelBinarizer]:
        """
        Train a classifier on the provided data and persist the model,
        class names, and training history under the configured model folder.

        Args:
            X: Input features
            y: Target labels (class-name strings)
            model_name: Name of the model (used for output file names)

        Returns:
            Tuple of (trained model, fitted label encoder)
        """
        # One-hot encode the string labels.
        encoder = LabelBinarizer()
        encoded_labels = encoder.fit_transform(y)
        num_classes = len(encoder.classes_)
        if num_classes == 2:
            # Bug fix: for binary problems LabelBinarizer returns a single
            # 0/1 column, which mismatches the 2-unit softmax output under
            # CategoricalCrossentropy. Expand to a proper two-column one-hot
            # (column order follows encoder.classes_).
            encoded_labels = np.hstack([1 - encoded_labels, encoded_labels])

        logger.info(f"Training model with {num_classes} classes: {', '.join(encoder.classes_)}")

        model = self.build_model(num_classes=num_classes)
        model.summary()

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.config['learning_rate'])
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=['accuracy']
        )

        # Fix: the original wrapped the folder name in a pointless
        # single-argument os.path.join.
        model_folder = self.config['model_folder']
        os.makedirs(model_folder, exist_ok=True)
        model_path = os.path.join(model_folder, model_name)

        callbacks = self._create_callbacks(model_path)

        history = model.fit(
            X, encoded_labels,
            epochs=self.config['epochs'],
            batch_size=self.config['batch_size'],
            validation_split=self.config['validation_split'],
            callbacks=callbacks,
            verbose=1
        )

        # Persist weights, class order, and per-epoch metrics.
        model.save(f"{model_path}.h5")
        np.save(f"{model_path}_classes.npy", encoder.classes_)

        pd.DataFrame(history.history).to_csv(f"{model_path}_history.csv", index=False)

        logger.info(f"Model saved as {model_path}.h5")
        logger.info(f"Class names saved as {model_path}_classes.npy")

        return model, encoder
|
|
|
|
|
|
|
|
def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Train an audio classification model")

    # Required arguments.
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to the directory containing audio files')
    parser.add_argument('--model_name', type=str, required=True,
                        help='Name for the saved model')
    parser.add_argument('--config', type=str,
                        help='Path to config JSON file (optional)')

    # Optional overrides; defaults mirror DEFAULT_CONFIG.
    overrides = [
        ('--epochs', int, 'epochs', 'Number of training epochs'),
        ('--batch_size', int, 'batch_size', 'Batch size for training'),
        ('--learning_rate', float, 'learning_rate', 'Initial learning rate'),
        ('--model_folder', str, 'model_folder', 'Folder to save the model'),
    ]
    for flag, flag_type, config_key, help_text in overrides:
        parser.add_argument(flag, type=flag_type,
                            default=DEFAULT_CONFIG[config_key],
                            help=help_text)

    return parser.parse_args()
|
|
|
|
|
|
|
|
def load_custom_config(config_path: Optional[str]) -> Dict[str, Any]:
    """
    Load custom configuration from a JSON file.

    Args:
        config_path: Path to a JSON file, or None/empty to skip.

    Returns:
        Parsed configuration dict; {} when no path is given or the file
        cannot be read/parsed (the error is logged).
    """
    if not config_path:
        return {}

    import json
    try:
        # Fix: only catch I/O and parse errors (json.JSONDecodeError is a
        # ValueError); the original's bare `except Exception` also hid
        # programming errors.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"Error loading config file: {str(e)}")
        return {}
|
|
|
|
|
|
|
|
def main():
    """Main function to run the script."""
    try:
        args = parse_arguments()

        # File-based config first; CLI flags take precedence over it.
        custom_config = load_custom_config(args.config)
        custom_config.update({
            'epochs': args.epochs,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'model_folder': args.model_folder
        })
        config = Configuration(custom_config)

        logger.info(f"Data path: {args.data_path}")
        logger.info(f"Model name: {args.model_name}")
        logger.info(f"Model folder: {config['model_folder']}")

        # Wire up the pipeline components.
        class_map = ClassMap(config)
        extractor = FeatureExtractor(config)
        loader = DatasetLoader(config, extractor)
        builder = ModelBuilder(config)

        # Discover classes, build the dataset, shuffle, and train.
        classes = class_map.update_classes(args.data_path)
        samples, labels = loader.create_dataset(args.data_path, classes)
        samples, labels = shuffle(samples, labels, random_state=42)
        model, encoder = builder.train_model(samples, labels, args.model_name)

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}", exc_info=True)
        sys.exit(1)
|
|
|
|
|
|
|
|
# Script entry point: run the training pipeline when executed directly.
if __name__ == "__main__":
    main()