File size: 4,285 Bytes
7651694 6fa015b e0d79c7 600df41 e0d79c7 d8c58cd 600df41 e0d79c7 600df41 e0d79c7 8c00eb3 e0d79c7 7651694 8c00eb3 7651694 e0d79c7 904154d d8c58cd 6fa015b 7651694 8c00eb3 7651694 6fa015b 7651694 6fa015b 7651694 600df41 e0d79c7 7651694 8c00eb3 7651694 600df41 7651694 8c00eb3 904154d 7651694 600df41 7651694 6fa015b 7651694 600df41 7651694 904154d 6fa015b 7651694 e0d79c7 6fa015b 7651694 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import numpy as np
from PIL import Image
import random
import io
from collections import Counter, defaultdict
from datasets import load_dataset
from app.model import predict
from app.preprocess import preprocess_audio
# Hugging Face dataset repositories that back the "pick random sample" options.
_AUDIO_DATASET_REPO = "AIOmarRehan/General_Audio_Dataset"
_IMAGE_DATASET_REPO = "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification"

# Both "train" splits are loaded eagerly, once, when this module is executed.
audio_ds = load_dataset(_AUDIO_DATASET_REPO, split="train")
image_ds = load_dataset(_IMAGE_DATASET_REPO, split="train")
# Helper that normalises any supported image input to RGBA.
def safe_load_image(img):
    """Return ``img`` as an RGBA PIL image, or None when ``img`` is None.

    NumPy arrays are wrapped with ``Image.fromarray`` first; any other value
    is expected to provide a PIL-style ``.convert`` method.
    """
    if img is not None:
        pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
        return pil_img.convert("RGBA")
    return None
# Classify a single spectrogram image.
def process_image_input(img):
    """Run the classifier on one spectrogram image.

    ``img`` is normalised to RGBA by ``safe_load_image`` and then passed to
    ``predict``. Returns ``(label, confidence rounded to 3 decimals, probs)``;
    ``label`` and ``probs`` are passed through from ``predict`` unchanged.
    """
    rgba_img = safe_load_image(img)
    label, confidence, probs = predict(rgba_img)
    rounded_conf = round(confidence, 3)
    return label, rounded_conf, probs
# Process raw audio: classify each spectrogram chunk, then vote.
def _majority_vote(labels, confs):
    """Pick the winning label across per-chunk predictions.

    The label predicted for the most chunks wins. If several labels share the
    top vote count, the one whose chunk confidences sum highest wins; a
    further tie goes to the label that first appears in ``labels``.

    Args:
        labels: per-chunk predicted labels (non-empty).
        confs: per-chunk confidences, aligned with ``labels``.

    Returns:
        ``(label, mean confidence of the chunks that predicted that label)``.
    """
    votes = Counter(labels)
    top_votes = max(votes.values())
    candidates = [lbl for lbl, n in votes.items() if n == top_votes]
    if len(candidates) == 1:
        winner = candidates[0]
    else:
        # Tie on vote count: break it by total confidence. The dict keeps
        # first-appearance order, so max() favours the earliest label on a tie.
        conf_sums = defaultdict(float)
        for lbl, conf in zip(labels, confs):
            if lbl in candidates:
                conf_sums[lbl] += conf
        winner = max(conf_sums, key=conf_sums.get)
    winner_confs = [conf for lbl, conf in zip(labels, confs) if lbl == winner]
    return winner, float(np.mean(winner_confs))


def process_audio_input(audio_path):
    """Classify an audio file by majority vote over its spectrogram chunks.

    ``preprocess_audio`` turns the file into a sequence of images. Each image
    is classified with ``predict``, and the final label is chosen by
    ``_majority_vote``.

    Args:
        audio_path: path handed unchanged to ``preprocess_audio``.

    Returns:
        ``(final_label, final_conf rounded to 3 dp, per-chunk labels,
        per-chunk confidences rounded to 3 dp)``.

    Raises:
        ValueError: if ``preprocess_audio`` produced no chunks to classify.
    """
    all_preds, all_confs = [], []
    for img in preprocess_audio(audio_path):
        label, conf, _probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)
    if not all_preds:
        # Without this check the vote fails with "max() arg is an empty sequence".
        raise ValueError(f"No audio chunks were produced for {audio_path!r}.")
    final_label, final_conf = _majority_vote(all_preds, all_confs)
    # float() guards against predict returning NumPy scalars, which plain JSON
    # encoders reject.
    chunk_confs = [round(float(c), 3) for c in all_confs]
    return final_label, round(final_conf, 3), all_preds, chunk_confs
# Main classifier and its random-sample helpers.
def _random_dataset_image():
    """Return a random row of ``image_ds`` as an RGBA PIL image.

    The ``image`` cell may be a decoded PIL image, encoded file bytes, a file
    path, or an undecoded ``{"bytes": ..., "path": ...}`` dict; all are
    accepted. Returns None when the dataset is empty.

    Raises:
        TypeError: if the cell is none of the supported forms.
    """
    if len(image_ds) == 0:
        return None
    cell = image_ds[random.randrange(len(image_ds))]["image"]
    if isinstance(cell, dict):
        # Undecoded Image feature: prefer the embedded bytes, else the path.
        cell = cell.get("bytes") or cell.get("path")
    if isinstance(cell, Image.Image):
        # Already decoded. Note that Image.tobytes() would give raw pixels,
        # not an encoded file that Image.open could parse.
        return cell.convert("RGBA")
    if isinstance(cell, (bytes, bytearray)):
        return Image.open(io.BytesIO(cell)).convert("RGBA")
    if isinstance(cell, str):
        return Image.open(cell).convert("RGBA")
    raise TypeError(f"Unsupported dataset image cell: {type(cell).__name__}")


def _random_dataset_audio():
    """Pick a random row of ``audio_ds`` and return ``(path, is_temp)``.

    A decoded ``array`` is preferred and written to a fresh temporary WAV file
    (``is_temp`` is True and the caller must delete it). Otherwise an existing
    on-disk ``path`` is used. Returns ``(None, False)`` when the dataset is
    empty or the row offers neither.
    """
    import os

    if len(audio_ds) == 0:
        return None, False
    audio = audio_ds[random.randrange(len(audio_ds))]["audio"]
    if not isinstance(audio, dict):
        return None, False
    if audio.get("array") is not None:
        import tempfile

        import soundfile as sf

        # Unique file per request so concurrent users don't clobber each other.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(tmp_path, audio["array"], audio["sampling_rate"])
        except Exception:
            os.remove(tmp_path)
            raise
        return tmp_path, True
    # A dataset "path" can be a bare filename from the source archive, so
    # only use it if it actually exists locally.
    path = audio.get("path")
    if path and os.path.isfile(path):
        return path, False
    return None, False


def classify(audio_path, image, random_audio=False, random_image=False):
    """Gradio callback: classify a spectrogram image or a raw audio file.

    Args:
        audio_path: file path from the audio input, or None.
        image: PIL image from the image input, or None.
        random_audio: replace ``audio_path`` with a random ``audio_ds`` sample.
        random_image: replace ``image`` with a random ``image_ds`` sample.

    An image (uploaded or random) takes precedence over audio. The random
    audio sample is only drawn when no image is being classified.

    Returns:
        ``(result dict for the JSON output, final label for the textbox)``.
    """
    import contextlib
    import os

    if random_image:
        picked_image = _random_dataset_image()
        if picked_image is not None:
            image = picked_image

    # Spectrogram image input.
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {
            "Final Label": label,
            "Confidence": conf,
            "Details": probs,
        }, label

    temp_path = None
    if random_audio:
        picked_path, is_temp = _random_dataset_audio()
        if picked_path is not None:
            audio_path = picked_path
            temp_path = picked_path if is_temp else None

    # Raw audio input.
    if audio_path is not None:
        try:
            label, conf, all_preds, all_confs = process_audio_input(audio_path)
        finally:
            if temp_path is not None:
                # Best-effort cleanup of the per-request temporary WAV.
                with contextlib.suppress(OSError):
                    os.remove(temp_path)
        return {
            "Final Label": label,
            "Confidence": conf,
            "All Chunk Labels": all_preds,
            "All Chunk Confidences": all_confs,
        }, label

    # gr.JSON json-parses plain strings, so the message must be wrapped.
    return {"Error": "Please upload an audio file OR a spectrogram image."}, ""
# Gradio interface: an audio upload, an image upload and two "random sample"
# checkboxes, mapped positionally onto classify()'s four parameters.
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False),
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "Upload a raw audio file OR a spectrogram image.\n"
        "You can also select random samples from your Hugging Face datasets.\n"
        "The output shows a JSON with all details and a separate field for the final label."
    ),
)

if __name__ == "__main__":
    # Start the server only when run as a script (e.g. `python app.py`), so
    # importing this module (tests, reuse of classify) has no side effect.
    interface.launch()