AIOmarRehan's picture
Update app.py
afb665f verified
raw
history blame
4.96 kB
import io
import os
import random
import tempfile
from collections import Counter, defaultdict

import gradio as gr
import numpy as np
import soundfile as sf
from datasets import load_dataset
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio
# Load Hugging Face datasets.
# Both "train" splits are loaded once, when this module is imported; the
# "Pick Random ... from Dataset" options in classify() sample from them.
_AUDIO_DATASET_ID = "AIOmarRehan/General_Audio_Dataset"
_IMAGE_DATASET_ID = "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification"
audio_ds = load_dataset(_AUDIO_DATASET_ID, split="train")
image_ds = load_dataset(_IMAGE_DATASET_ID, split="train")
# Helper functions
def safe_load_image(img):
    """Normalize an image input to RGBA.

    ``None`` is passed straight through. A NumPy array is first wrapped with
    ``Image.fromarray``; anything else (normally a PIL image) is used as-is.
    The result of ``.convert("RGBA")`` is returned.
    """
    if img is None:
        return None
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    return pil_img.convert("RGBA")
def process_image_input(img):
    """Classify a single spectrogram image.

    The input is normalized with ``safe_load_image`` and passed to ``predict``.

    Returns:
        ``(label, confidence, probs)``: ``confidence`` is rounded to 3
        decimals, and ``probs`` is the third value returned by ``predict``,
        unchanged.
    """
    label, confidence, probs = predict(safe_load_image(img))
    return label, round(confidence, 3), probs
def _majority_vote(labels, confs):
    """Return ``(label, mean_conf)`` for the most frequently predicted label.

    Ties on vote count are broken by the summed confidence of each tied label.
    Any remaining tie goes to the tied label that appears first in ``labels``.
    ``mean_conf`` is the mean confidence over the chunks that voted for the
    winning label.
    """
    counts = Counter(labels)
    top_count = max(counts.values())
    candidates = [lbl for lbl, cnt in counts.items() if cnt == top_count]
    if len(candidates) == 1:
        winner = candidates[0]
    else:
        conf_sums = defaultdict(float)
        for lbl, conf in zip(labels, confs):
            if lbl in candidates:
                conf_sums[lbl] += conf
        winner = max(conf_sums, key=conf_sums.get)
    mean_conf = float(np.mean([conf for lbl, conf in zip(labels, confs) if lbl == winner]))
    return winner, mean_conf


def process_audio_input(audio_path):
    """Classify a raw audio file by majority vote over its spectrogram chunks.

    ``preprocess_audio`` turns the file into a sequence of images. Each image
    is classified with ``predict``, and the per-chunk labels are combined with
    ``_majority_vote``.

    Returns:
        ``(final_label, final_conf, chunk_labels, chunk_confs)``:
        ``final_conf`` is rounded to 3 decimals, ``chunk_labels`` lists the
        label predicted for each chunk in order, and ``chunk_confs`` lists the
        matching confidences, each rounded to 3 decimals.

    Raises:
        ValueError: if ``preprocess_audio`` produced no chunks, e.g. for an
            empty or very short recording.
    """
    chunk_labels, chunk_confs = [], []
    for img in preprocess_audio(audio_path):
        label, conf, _probs = predict(img)
        chunk_labels.append(label)
        chunk_confs.append(conf)
    if not chunk_labels:
        # Without this guard the vote would fail inside max() with an opaque
        # "empty sequence" error.
        raise ValueError(
            f"No spectrogram chunks could be extracted from {audio_path!r}; "
            "the audio may be empty or too short."
        )
    final_label, final_conf = _majority_vote(chunk_labels, chunk_confs)
    return final_label, round(final_conf, 3), chunk_labels, [round(c, 3) for c in chunk_confs]
def _random_audio_path():
    """Pick a random clip from ``audio_ds`` and return ``(path, is_temp)``.

    A stored ``"path"`` is used only when it names an existing local file.
    Decoded ``datasets`` audio dicts can carry a ``"path"`` key that is
    ``None`` or just the original file name of a clip embedded in the
    dataset, so its presence alone does not mean a readable file exists.
    Otherwise the decoded waveform (``"array"`` / ``"sampling_rate"``) is
    written to a temporary WAV file. In that case ``is_temp`` is True and the
    caller is responsible for deleting the file.
    """
    sample = random.choice(audio_ds)
    audio_obj = sample["audio"]
    stored_path = audio_obj.get("path") if isinstance(audio_obj, dict) else None
    if stored_path and os.path.isfile(stored_path):
        return stored_path, False
    # Non-dict audio objects are assumed to support the same item access
    # (the previous fallback relied on this as well) -- TODO confirm against
    # the installed datasets version.
    audio_array = audio_obj["array"]
    sample_rate = audio_obj["sampling_rate"]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        tmp_path = tmpfile.name
    # Write only after the handle is closed: reopening a still-open
    # NamedTemporaryFile by name fails on Windows.
    try:
        sf.write(tmp_path, audio_array, sample_rate)
    except Exception:
        os.remove(tmp_path)
        raise
    return tmp_path, True


def _random_image():
    """Pick a random spectrogram from ``image_ds`` and return it as RGBA."""
    img_obj = random.choice(image_ds)["image"]
    if not isinstance(img_obj, Image.Image):
        img_obj = Image.fromarray(img_obj)  # convert ndarray to PIL
    return img_obj.convert("RGBA")


# Main classifier function
def classify(audio_path, image, random_audio=False, random_image=False):
    """Classify either a spectrogram image or a raw audio file.

    Args:
        audio_path: Path to an uploaded audio file, or None.
        image: Uploaded spectrogram image, or None.
        random_audio: If True, replace ``audio_path`` with a random clip from
            ``audio_ds``.
        random_image: If True, replace ``image`` with a random image from
            ``image_ds``.

    Returns:
        ``(details, label)``, where ``details`` is the dict shown in the JSON
        output. An available image takes precedence over audio. If neither
        input is available, returns a prompt string and an empty label.
    """
    temp_audio_path = None  # temporary WAV created here; deleted before returning
    if random_audio and len(audio_ds) > 0:
        try:
            audio_path, is_temp = _random_audio_path()
            if is_temp:
                temp_audio_path = audio_path
        except Exception as e:
            # Best-effort: report it and continue with whatever other input exists.
            print("Error loading random audio:", e)
            audio_path = None
    if random_image and len(image_ds) > 0:
        try:
            image = _random_image()
        except Exception as e:
            print("Error loading random image:", e)
            image = None
    try:
        # Process spectrogram image
        if image is not None:
            label, conf, probs = process_image_input(image)
            return {
                "Final Label": label,
                "Confidence": conf,
                "Details": probs
            }, label
        # Process raw audio
        if audio_path is not None:
            label, conf, all_preds, all_confs = process_audio_input(audio_path)
            return {
                "Final Label": label,
                "Confidence": conf,
                "All Chunk Labels": all_preds,
                "All Chunk Confidences": all_confs
            }, label
        return "Please upload an audio file OR a spectrogram image.", ""
    finally:
        # Previously every random pick leaked a delete=False temp file.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError as e:
                print("Could not remove temporary audio file:", e)
# Gradio Interface
# Input components, in the same order as classify()'s positional parameters:
# (audio_path, image, random_audio, random_image).
_input_components = [
    gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
    gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
    gr.Checkbox(label="Pick Random Audio from Dataset"),
    gr.Checkbox(label="Pick Random Image from Dataset"),
]
# Output components for classify()'s (details, label) return pair.
_output_components = [
    gr.JSON(label="Prediction Results"),
    gr.Textbox(label="Final Label", interactive=False),
]
_description = (
    "Upload a raw audio file OR a spectrogram image.\n"
    "You can also select random samples from your Hugging Face datasets.\n"
    "The output shows a JSON with all details and a separate field for the final label."
)
interface = gr.Interface(
    fn=classify,
    inputs=_input_components,
    outputs=_output_components,
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=_description,
)
interface.launch()