import io
import os
import random
from collections import Counter, defaultdict

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio

# Dataset Paths (download manually from Hugging Face)
AUDIO_DATASET_DIR = "General_Audio_Dataset"
IMAGE_DATASET_DIR = "Mel_Spectrogram_Images_for_Audio_Classification"

# Shown to the user when neither an audio file nor an image is available.
NO_INPUT_MESSAGE = "Please upload an audio file OR a spectrogram image."


def _list_dataset_files(directory, extensions):
    """Return the paths of files in *directory* whose names end with *extensions*.

    Args:
        directory: Folder to scan (non-recursive).
        extensions: A suffix string or tuple of suffixes, compared against the
            lower-cased file name.

    Returns:
        A list of ``directory``-joined paths in ``os.listdir`` order. The list
        is empty when the folder does not exist, so the app can still start
        before the datasets have been downloaded; the "pick random" options
        simply have nothing to pick from in that case.
    """
    try:
        names = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    return [
        os.path.join(directory, name)
        for name in names
        if name.lower().endswith(extensions)
    ]


# Get file lists
audio_files = _list_dataset_files(AUDIO_DATASET_DIR, (".wav", ".mp3"))
image_files = _list_dataset_files(IMAGE_DATASET_DIR, ".png")


# Helper functions
def safe_load_image(img):
    """Ensure input is a PIL RGBA image.

    Args:
        img: ``None``, a NumPy array, or a PIL image.

    Returns:
        ``None`` when *img* is ``None``; otherwise the image converted to
        RGBA mode (NumPy arrays are first wrapped with ``Image.fromarray``).
    """
    if img is None:
        return None
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    return img.convert("RGBA")


def process_image_input(img):
    """Classify a single spectrogram image.

    Args:
        img: A NumPy array or PIL image; it is normalised to RGBA before being
            handed to ``predict``.

    Returns:
        A ``(label, confidence, probs)`` tuple where *confidence* is rounded
        to three decimals and *probs* is passed through from ``predict``
        unchanged.
    """
    rgba = safe_load_image(img)
    label, confidence, probs = predict(rgba)
    return label, round(confidence, 3), probs


def process_audio_input(audio_path):
    """Classify an audio file by voting over its per-chunk predictions.

    ``preprocess_audio`` turns the file into a sequence of images (PIL RGBA
    images, per the original author's note) and each one is classified with
    ``predict``. The final label is the most frequent chunk label; ties are
    broken by the larger summed confidence of the tied labels.

    Args:
        audio_path: Path of the audio file to classify.

    Returns:
        A ``(final_label, final_conf, all_preds, all_confs)`` tuple:
        the winning label, the mean confidence of the chunks that voted for
        it (rounded to three decimals), every chunk label in order, and every
        chunk confidence rounded to three decimals.

    Raises:
        ValueError: If preprocessing produced no chunks, so there is nothing
            to vote on.
    """
    all_preds, all_confs = [], []
    for chunk_img in preprocess_audio(audio_path):
        label, conf, _probs = predict(chunk_img)
        all_preds.append(label)
        all_confs.append(conf)

    if not all_preds:
        raise ValueError(
            f"No audio chunks could be extracted from {audio_path!r}; "
            "nothing to classify."
        )

    # Majority vote
    counter = Counter(all_preds)
    max_count = max(counter.values())
    candidates = [lbl for lbl, count in counter.items() if count == max_count]

    if len(candidates) == 1:
        final_label = candidates[0]
    else:
        # Tie: prefer the label whose chunks carry the most total confidence.
        conf_sums = defaultdict(float)
        for lbl, conf in zip(all_preds, all_confs):
            if lbl in candidates:
                conf_sums[lbl] += conf
        final_label = max(conf_sums, key=conf_sums.get)

    final_conf = float(
        np.mean(
            [conf for lbl, conf in zip(all_preds, all_confs) if lbl == final_label]
        )
    )
    return (
        final_label,
        round(final_conf, 3),
        all_preds,
        [round(conf, 3) for conf in all_confs],
    )


# Main classifier
def classify(audio_path, image, random_audio=False, random_image=False):
    """Classify a spectrogram image or a raw audio file.

    An image, when present, takes precedence over audio.

    Args:
        audio_path: Path of an uploaded audio file, or ``None``.
        image: An uploaded spectrogram image, or ``None``.
        random_audio: When true and local audio samples exist, replace
            *audio_path* with a randomly chosen sample.
        random_image: When true and local image samples exist, replace
            *image* with a randomly chosen sample.

    Returns:
        A ``(details, final_label)`` pair. *details* is a dict so that it is
        always valid for the JSON output component; when no input is
        available it holds the user-facing message under ``"Error"`` and
        *final_label* is the empty string.
    """
    # Pick random audio
    if random_audio and audio_files:
        audio_path = random.choice(audio_files)

    # Pick random image
    if random_image and image_files:
        img_path = random.choice(image_files)
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(img_path) as source:
            image = source.convert("RGBA")

    # If spectrogram image
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {
            "Final Label": label,
            "Confidence": conf,
            "Details": probs,
        }, label

    # If raw audio
    if audio_path is not None:
        label, conf, all_preds, all_confs = process_audio_input(audio_path)
        return {
            "Final Label": label,
            "Confidence": conf,
            "All Chunk Labels": all_preds,
            "All Chunk Confidences": all_confs,
        }, label

    # A bare string sent to gr.JSON is parsed as JSON text, which this
    # sentence is not; wrap it in a dict so it is displayed instead.
    return {"Error": NO_INPUT_MESSAGE}, ""


# Gradio Interface
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False),
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "Upload a raw audio file OR a spectrogram image.\n"
        "You can also select random samples from the local datasets.\n"
        "The output shows a JSON with all details and a separate field for the final label."
    ),
)

interface.launch()