File size: 5,590 Bytes
7651694 e0d79c7 afb665f 600df41 e0d79c7 d8c58cd afb665f 600df41 afb665f 600df41 afb665f 8c00eb3 31122fd 8c00eb3 7651694 8c00eb3 7651694 904154d d8c58cd 6fa015b 7651694 8c00eb3 7651694 6fa015b 7651694 6fa015b 7651694 afb665f 600df41 afb665f e0d79c7 afb665f e0d79c7 afb665f e0d79c7 afb665f 7651694 afb665f 7651694 600df41 7651694 afb665f 904154d 7651694 600df41 7651694 6fa015b 7651694 e14351d 600df41 7651694 904154d e14351d 6fa015b 2fc5594 6fa015b 7651694 e14351d 7651694 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import io
import os
import random
import tempfile
from collections import Counter, defaultdict

import gradio as gr
import numpy as np
import soundfile as sf
from datasets import load_dataset
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio
# Load the "train" split of both Hugging Face datasets once, at module level.
# The first repository supplies audio samples, the second supplies
# Mel-spectrogram images; classify() draws its random samples from them.
audio_ds, image_ds = (
    load_dataset(repo_id, split="train")
    for repo_id in (
        "AIOmarRehan/General_Audio_Dataset",
        "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification",
    )
)
# Helper functions
def safe_load_image(img):
    """Normalise an optional image input to RGBA.

    Args:
        img: ``None``, a NumPy array, or an object exposing ``convert``
            (such as a PIL image).

    Returns:
        ``None`` when *img* is ``None``. Otherwise the result of calling
        ``convert("RGBA")`` on the image; NumPy arrays are first wrapped with
        ``Image.fromarray``.
    """
    if img is None:
        return None
    pil_image = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    return pil_image.convert("RGBA")
def process_image_input(img):
    """Classify a single spectrogram image.

    The input is normalised with ``safe_load_image`` and then handed to
    ``predict``.

    Args:
        img: Image input accepted by ``safe_load_image``.

    Returns:
        A ``(label, confidence, probs)`` tuple. The confidence is rounded to
        three decimals; ``label`` and ``probs`` are passed through from
        ``predict`` unchanged.
    """
    prepared = safe_load_image(img)
    prediction = predict(prepared)
    label, confidence, probs = prediction
    return label, round(confidence, 3), probs
def process_audio_input(audio_path):
    """Classify an audio file by voting over its per-chunk predictions.

    ``preprocess_audio`` turns the file into a sequence of images and
    ``predict`` is run on each of them. The final label is the one predicted
    for the most chunks; when several labels share the highest count, the one
    with the largest summed confidence wins (first-seen label on an exact
    tie).

    Args:
        audio_path: Path passed straight to ``preprocess_audio``.

    Returns:
        A ``(final_label, final_conf, all_preds, all_confs)`` tuple:
        the winning label, the mean confidence of the chunks that voted for
        it (rounded to three decimals), the per-chunk labels in order, and
        the per-chunk confidences rounded to three decimals.

    Raises:
        ValueError: If ``preprocess_audio`` produces no chunks, so there is
            nothing to vote on.
    """
    all_preds, all_confs = [], []
    for img in preprocess_audio(audio_path):
        label, conf, _probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)

    # Without this guard the vote below fails with an opaque
    # "max() arg is an empty sequence" error.
    if not all_preds:
        raise ValueError(
            f"No audio chunks could be extracted from {audio_path!r}."
        )

    # Group confidences by label. Dict insertion order is the order of first
    # appearance, which keeps exact ties deterministic.
    confs_by_label = defaultdict(list)
    for label, conf in zip(all_preds, all_confs):
        confs_by_label[label].append(conf)

    # Rank by vote count first and summed confidence second: this is the
    # majority vote with the confidence tie-breaker in a single pass.
    final_label = max(
        confs_by_label,
        key=lambda lbl: (len(confs_by_label[lbl]), sum(confs_by_label[lbl], 0.0)),
    )

    final_conf = float(np.mean(confs_by_label[final_label]))
    return (
        final_label,
        round(final_conf, 3),
        all_preds,
        [round(c, 3) for c in all_confs],
    )
# Main classifier function
def _remove_quietly(path):
    """Delete *path*; report, but do not raise, if the file cannot be removed."""
    try:
        os.remove(path)
    except OSError as e:
        # Cleanup is best-effort: a leftover temp file must not fail a request.
        print("Could not remove temporary audio file:", e)


def _write_temp_wav(audio_array, sampling_rate):
    """Write audio samples to a new temporary ``.wav`` file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        temp_path = tmpfile.name
    # Write only after the handle is closed so the path can be reopened on
    # platforms that lock open temporary files.
    try:
        sf.write(temp_path, audio_array, sampling_rate)
    except Exception:
        _remove_quietly(temp_path)
        raise
    return temp_path


def _resolve_random_audio(audio_obj):
    """Return ``(audio_path, is_temporary)`` for a dataset ``audio`` entry."""
    stored_path = audio_obj.get("path") if isinstance(audio_obj, dict) else None
    if stored_path and os.path.isfile(stored_path):
        return stored_path, False
    if isinstance(audio_obj, dict) and "array" not in audio_obj:
        # No decoded samples to fall back on: pass the stored path through
        # unchanged, exactly as before.
        return stored_path, False
    # The path is missing or not on disk, or the entry is not a plain dict:
    # materialise the decoded samples as a WAV file.
    temp_path = _write_temp_wav(audio_obj["array"], audio_obj["sampling_rate"])
    return temp_path, True


def classify(audio_path, image, random_audio=False, random_image=False):
    """Classify an uploaded or randomly drawn audio file or spectrogram image.

    Args:
        audio_path: Path of the uploaded audio file, or ``None``.
        image: Uploaded spectrogram image, or ``None``.
        random_audio: When true and ``audio_ds`` is non-empty, replace
            *audio_path* with a random sample from ``audio_ds``.
        random_image: When true and ``image_ds`` is non-empty, replace
            *image* with a random sample from ``image_ds``.

    Returns:
        A ``(details, label)`` pair. An image, when available, takes priority
        over audio. For an image, ``details`` holds the final label, its
        confidence and the value returned by the model as ``"Details"``; for
        audio it holds the final label, its confidence and the per-chunk
        labels and confidences. With neither input, ``details`` is a prompt
        message and ``label`` is an empty string.
    """
    temp_audio_path = None
    try:
        # Random audio selection
        if random_audio and len(audio_ds) > 0:
            try:
                sample = random.choice(audio_ds)
                audio_path, is_temporary = _resolve_random_audio(sample["audio"])
                if is_temporary:
                    temp_audio_path = audio_path
            except Exception as e:
                # Best-effort: a broken sample falls through to the prompt
                # message instead of failing the request.
                print("Error loading random audio:", e)
                audio_path = None

        # Random image selection
        if random_image and len(image_ds) > 0:
            try:
                sample = random.choice(image_ds)
                img_obj = sample["image"]
                if not isinstance(img_obj, Image.Image):
                    img_obj = Image.fromarray(img_obj)  # convert ndarray to PIL
                image = img_obj.convert("RGBA")
            except Exception as e:
                print("Error loading random image:", e)
                image = None

        # Process spectrogram image
        if image is not None:
            label, conf, probs = process_image_input(image)
            return {
                "Final Label": label,
                "Confidence": conf,
                "Details": probs,
            }, label

        # Process raw audio
        if audio_path is not None:
            label, conf, all_preds, all_confs = process_audio_input(audio_path)
            return {
                "Final Label": label,
                "Confidence": conf,
                "All Chunk Labels": all_preds,
                "All Chunk Confidences": all_confs,
            }, label

        return "Please upload an audio file OR a spectrogram image.", ""
    finally:
        # The temporary WAV is only needed while this call runs; without this
        # every random pick would leave a file behind.
        if temp_audio_path is not None:
            _remove_quietly(temp_audio_path)
# Markdown text passed to gr.Interface as its description: it explains the
# accepted inputs and how per-chunk predictions are combined.
description = """
Upload a raw audio file or a spectrogram image.
You may also pick random samples from the provided Hugging Face datasets.
The output includes a JSON structure with detailed predictions and a separate final label.
### How the Model Makes Predictions
Your audio is split into 5-second chunks, and each chunk is converted into a Mel-spectrogram.
A CNN predicts a label and confidence score for each chunk.
The final prediction is determined by:
1. **Majority vote** — the class predicted most frequently across chunks.
2. **Confidence tie-breaker** — if classes tie, the model selects the one with the **highest total confidence** across its chunks.
3. **Final confidence** — the average confidence of all chunks belonging to the final class.
The JSON output shows the final label, its confidence, and all per-chunk predictions.
"""
# Gradio Interface
# The inputs are listed in the same order as classify()'s parameters
# (audio_path, image, random_audio, random_image), and the two outputs receive
# the two values it returns (details, final label).
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Mel Spectrogram Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False),
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=description,
)

# Start the app when the module is executed.
interface.launch()