import random
import tempfile
from collections import Counter, defaultdict

import gradio as gr
import numpy as np
import soundfile as sf
from datasets import load_dataset
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio

# Load Hugging Face datasets
audio_ds = load_dataset("AIOmarRehan/General_Audio_Dataset", split="train")
image_ds = load_dataset(
    "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification", split="train"
)


# Helper functions
def safe_load_image(img):
    """Ensure the input is a PIL RGBA image."""
    if img is None:
        return None
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    return img.convert("RGBA")


def process_image_input(img):
    """Classify a single spectrogram image."""
    img = safe_load_image(img)
    label, confidence, probs = predict(img)
    return label, round(confidence, 3), probs


def process_audio_input(audio_path):
    """Split audio into chunks, classify each chunk, and aggregate the results."""
    imgs = preprocess_audio(audio_path)
    all_preds, all_confs, all_probs = [], [], []
    for img in imgs:
        label, conf, probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)
        all_probs.append(probs)

    # Majority vote across chunk predictions
    counter = Counter(all_preds)
    max_count = max(counter.values())
    candidates = [k for k, v in counter.items() if v == max_count]

    if len(candidates) == 1:
        final_label = candidates[0]
    else:
        # Tie-breaker: among tied classes, pick the one with the highest
        # total confidence across its chunks
        conf_sums = defaultdict(float)
        for lbl, conf in zip(all_preds, all_confs):
            if lbl in candidates:
                conf_sums[lbl] += conf
        final_label = max(conf_sums, key=conf_sums.get)

    # Final confidence: mean confidence of the chunks predicted as the final label
    final_conf = float(
        np.mean([c for lbl, c in zip(all_preds, all_confs) if lbl == final_label])
    )
    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]


# Main classifier function
def classify(audio_path, image, random_audio=False, random_image=False):
    # Random audio selection
    if random_audio and len(audio_ds) > 0:
        try:
            sample = random.choice(audio_ds)
            # The dataset may store audio as a file path or a decoded array
            audio_obj = sample["audio"]
            if isinstance(audio_obj, dict) and audio_obj.get("path"):
                audio_path = audio_obj["path"]
            else:
                # Decoded audio: write the array to a temporary WAV file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
                    audio_path = tmpfile.name
                sf.write(audio_path, audio_obj["array"], audio_obj["sampling_rate"])
        except Exception as e:
            print("Error loading random audio:", e)
            audio_path = None

    # Random image selection
    if random_image and len(image_ds) > 0:
        try:
            sample = random.choice(image_ds)
            img_obj = sample["image"]
            if not isinstance(img_obj, Image.Image):
                img_obj = Image.fromarray(np.asarray(img_obj))  # convert ndarray to PIL
            image = img_obj.convert("RGBA")
        except Exception as e:
            print("Error loading random image:", e)
            image = None

    # A spectrogram image takes precedence over raw audio
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {"Final Label": label, "Confidence": conf, "Details": probs}, label

    # Process raw audio
    if audio_path is not None:
        label, conf, all_preds, all_confs = process_audio_input(audio_path)
        return {
            "Final Label": label,
            "Confidence": conf,
            "All Chunk Labels": all_preds,
            "All Chunk Confidences": all_confs,
        }, label

    return "Please upload an audio file OR a spectrogram image.", ""


# Gradio Interface
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Mel Spectrogram Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False),
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "\nUpload a raw audio file OR a spectrogram image.\n"
        "\nYou can also select random samples from the Hugging Face datasets.\n"
        "\nThe output shows a JSON with all details and a separate field for the final label.\n"
        "\nHow the model makes predictions: your audio is split into 5-second chunks, and "
        "each chunk is converted into a Mel-spectrogram and passed through a CNN trained to "
        "recognize patterns in frequency and time. The model predicts a label and confidence "
        "score for every chunk. The final result is based on:\n"
        "\n- Majority vote: the class that appears most often across chunks.\n"
        "- Tie-breaker: if two or more classes appear the same number of times, the one with "
        "the highest total confidence across its chunks is selected.\n"
        "- Final confidence: the average confidence of all chunks predicted as the final class.\n"
        "\nThe output shows the final label, its confidence, and the per-chunk predictions.\n"
    ),
)

interface.launch()