"""Command-line entry point for ESC-50 audio classification.

Subcommands: download, augment, preprocess, train, resume, predict.
"""

import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from src.config.config import (
    DatasetConfig,
    DownloadConfig,
    ProcessingConfig,
    TrainConfig,
)
from src.data.augment import AudioAugment
from src.data.download import ESC50Downloader
from src.models.cnn import CNN
from src.models.predict import AudioPredictor
from src.models.traincnn import CNNTrainer


def _load_or_preprocess(args) -> tuple[np.ndarray, np.ndarray]:
    """Return (X, y) arrays, running the preprocessing pipeline first if needed.

    Uses ``args.X_path`` / ``args.y_path`` when given, otherwise the default
    locations under ``data/preprocessed/``. If either file is missing, the
    full augment+preprocess pipeline is run to (re)create them.
    """
    X_path = args.X_path or "data/preprocessed/X.npy"
    y_path = args.y_path or "data/preprocessed/y.npy"
    if os.path.exists(X_path) and os.path.exists(y_path):
        print("Loading existing processed data...")
        return np.load(X_path, allow_pickle=True), np.load(y_path)
    print("Processing audio data...")
    # NOTE(review): AudioAugment.run takes no path arguments, so it must write
    # to the default locations above for the subsequent np.load to succeed —
    # confirm if custom --X-path/--y-path are combined with preprocessing.
    augmenter = AudioAugment()
    augmenter.run(augment=True, preprocess=True)
    return np.load(X_path, allow_pickle=True), np.load(y_path)


def cmd_download(args) -> None:
    """Download and clean the ESC-50 dataset."""
    print("Downloading ESC50 audio data...")
    downloader = ESC50Downloader()
    downloader.download_clean()
    print("Downloaded and cleaned data.")


def cmd_augment(args) -> None:
    """Create augmented audio datasets (no preprocessing)."""
    print("Augmenting audio data...")
    # NOTE(review): args.input_dir / args.output_dir are accepted on the CLI
    # but not forwarded to AudioAugment — the printed path may not reflect
    # where data actually lands. TODO: wire the dirs through or drop the flags.
    augmenter = AudioAugment()
    augmenter.run(augment=True, preprocess=False)
    print(f"Augmented data and saved to {args.output_dir}")


def cmd_preprocess(args) -> None:
    """Preprocess the audio dataset into feature arrays (no augmentation)."""
    print("Processing audio data...")
    # NOTE(review): same as cmd_augment — the CLI dirs are not passed to
    # AudioAugment; verify the printed output path matches reality.
    augmenter = AudioAugment()
    augmenter.run(augment=False, preprocess=True)
    print(f"Preprocessed data and saved to {args.output_dir}")


def cmd_train_cnn(args) -> None:
    """Train a CNN on the (possibly freshly preprocessed) dataset."""
    X, y = _load_or_preprocess(args)
    # Only forward options the user actually set: unset argparse flags are
    # None and would otherwise clobber TrainConfig's own defaults (the ones
    # advertised in the --help text, e.g. "default: 100").
    overrides = {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "samples_per_epoch_fraction": args.sample_fraction,
        "checkpoint_dir": args.checkpoint_dir,
        "save_every_n_epoch": args.save_every,
    }
    config = TrainConfig(**{k: v for k, v in overrides.items() if v is not None})
    trainer = CNNTrainer(config)

    # Stratified 80/20 split with a fixed seed for reproducibility.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    model = CNN(n_classes=len(np.unique(y)))
    best_val_acc = trainer.train_cnn(model, X_train, y_train, X_val, y_val, fold_num=0)
    print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")


def cmd_resume(args) -> None:
    """Resume training from a checkpoint.

    NOTE(review): the original file registered ``func=cmd_resume`` but never
    defined it, so *every* invocation of main() crashed with NameError before
    argument dispatch. Failing cleanly here until the trainer's resume API is
    hooked up. TODO: load args.resume_from and continue training.
    """
    print(f"Error: 'resume' is not implemented yet (checkpoint: {args.resume_from})")
    sys.exit(1)


def cmd_predict(args) -> None:
    """Classify a single .wav file and print the top-k predictions."""
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file not found: {args.audio_file}")
        sys.exit(1)
    if not os.path.exists(args.model):
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)
    try:
        predictor = AudioPredictor(model_path=args.model, device=args.device)
        predicted_class, top_probs, top_indices = predictor.predict_file(
            args.audio_file, top_k=args.top_k
        )
        labels = DatasetConfig().esc50_labels
        print("\n" + "=" * 60)
        print(f"Top {args.top_k} Predictions:")
        print("=" * 60)
        for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
            # Star marks the overall predicted class within the top-k list.
            marker = "★" if idx == predicted_class else " "
            print(f"{marker} {i+1}. {labels[idx]:20s} - {prob*100:6.2f}%")
    except Exception as e:
        # Top-level CLI boundary: report with traceback and exit non-zero.
        import traceback
        print(f"\nError during prediction: {e}")
        traceback.print_exc()
        sys.exit(1)


def main() -> None:
    """Parse the command line and dispatch to the selected subcommand."""
    parser = argparse.ArgumentParser(
        description="ESC50 Audio Classification",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Command to run")
    subparsers.required = True

    download_parser = subparsers.add_parser('download', help='Download ESC50 dataset')
    download_parser.set_defaults(func=cmd_download)

    augment_parser = subparsers.add_parser('augment', help='Create augmented datasets')
    augment_parser.add_argument('--input-dir', type=str, default="data/audio/0",
                                help="Input audio directory")
    augment_parser.add_argument('--output-dir', type=str, default="data/audio",
                                help='Output directory for augmented datasets')
    augment_parser.set_defaults(func=cmd_augment)

    preprocess_parser = subparsers.add_parser('preprocess', help='Preprocess audio dataset')
    preprocess_parser.add_argument('--input-dir', type=str, default="data/audio",
                                   help='Input audio directory')
    preprocess_parser.add_argument('--output-dir', type=str, default="data/preprocessed",
                                   help='Output directory for preprocessed data')
    preprocess_parser.set_defaults(func=cmd_preprocess)

    train_parser = subparsers.add_parser('train', help='Train model')
    train_parser.add_argument('--audio-dir', type=str, help='Path to training audio directory')
    train_parser.add_argument('--output-dir', type=str, help='Path to save preprocessed data')
    train_parser.add_argument('--X-path', type=str, help='Path to preprocessed X.npy')
    train_parser.add_argument('--y-path', type=str, help='Path to preprocessed y.npy')
    train_parser.add_argument('--epochs', type=int, help='Number of epochs (default: 100)')
    train_parser.add_argument('--batch-size', type=int, help='Batch size (default: 100)')
    train_parser.add_argument('--lr', type=float, help='Learning rate (default: 0.01)')
    train_parser.add_argument('--sample-fraction', type=float,
                              help='Fraction of samples per epoch (default: 1/8)')
    train_parser.add_argument('--checkpoint-dir', type=str, help='Checkpoint directory')
    train_parser.add_argument('--save-every', type=int, help='Save checkpoint every N epochs')
    train_parser.set_defaults(func=cmd_train_cnn)

    resume_parser = subparsers.add_parser('resume', help='Resume training from checkpoint')
    resume_parser.add_argument('--resume-from', type=str, required=True,
                               help='Path to checkpoint file')
    resume_parser.add_argument('--X-path', type=str, default="data/preprocessed/X.npy")
    resume_parser.add_argument('--y-path', type=str, default="data/preprocessed/y.npy")
    resume_parser.add_argument('--epochs', type=int, help='Number of epochs (default: 100)')
    resume_parser.add_argument('--batch-size', type=int, help='Batch size (default: 100)')
    resume_parser.add_argument('--lr', type=float, help='Learning rate (default: 0.01)')
    resume_parser.add_argument('--sample-fraction', type=float,
                               help='Fraction of samples per epoch')
    resume_parser.add_argument('--checkpoint-dir', type=str, help='Checkpoint directory')
    resume_parser.add_argument('--save-every', type=int, help='Save checkpoint every N epochs')
    resume_parser.set_defaults(func=cmd_resume)

    predict_parser = subparsers.add_parser('predict', help='Predict audio file class')
    predict_parser.add_argument('audio_file', type=str, help='Path to .wav file to classify')
    predict_parser.add_argument('--model', type=str, default='final_model.pt',
                                help='Path to model checkpoint (default: final_model.pt)')
    predict_parser.add_argument('--top-k', type=int, default=5,
                                help='Number of top predictions (default: 5)')
    predict_parser.add_argument('--device', type=str,
                                default='cuda' if torch.cuda.is_available() else 'cpu',
                                help='Device (default: auto)')
    predict_parser.set_defaults(func=cmd_predict)

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()