import gradio as gr
import numpy as np
from PIL import Image
from app.preprocess import preprocess_audio
from app.model import predict
from collections import Counter, defaultdict
import librosa
# IMAGE HANDLING
def safe_load_image(img):
    """
    Normalize arbitrary image input into a PIL RGBA image.

    Gradio may deliver either a PIL image or a raw numpy array;
    both are accepted. Returns None when no image was supplied.
    """
    if img is None:
        return None
    # numpy array -> wrap into a PIL image first
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Force RGBA so any alpha channel is preserved downstream.
    return img.convert("RGBA")
# PROCESS SPECTROGRAM IMAGE
def process_image_input(img):
    """
    Classify a single spectrogram image.

    Returns a (label, confidence rounded to 3 decimals, per-class probs) tuple.
    """
    normalized = safe_load_image(img)
    label, confidence, probs = predict(normalized)
    return label, round(confidence, 3), probs
# PROCESS RAW AUDIO
def process_audio_input(audio_path):
    """
    Classify a raw audio file by splitting it into spectrogram chunks.

    Each chunk is classified independently; the final label is chosen by
    majority vote over chunk labels, with ties broken by the highest total
    confidence among the tied labels.

    Returns:
        (final_label, final_confidence, all_chunk_labels, all_chunk_confidences)
        where final_confidence is the mean confidence of the chunks that
        voted for the winning label, rounded to 3 decimals.

    Raises:
        ValueError: if preprocessing yields no chunks (e.g. empty audio).
    """
    imgs = preprocess_audio(audio_path)  # returns list of PIL RGBA images
    if not imgs:
        # Guard: without this, max() over an empty vote raises an opaque ValueError.
        raise ValueError(f"No spectrogram chunks produced for audio: {audio_path}")
    all_preds = []
    all_confs = []
    all_probs = []
    for img in imgs:
        label, conf, probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)
        all_probs.append(probs)
    # Majority vote over chunk labels.
    counter = Counter(all_preds)
    max_count = max(counter.values())
    candidates = [k for k, v in counter.items() if v == max_count]
    if len(candidates) == 1:
        final_label = candidates[0]
    else:
        # Tie-break: pick the tied label with the largest summed confidence.
        conf_sums = defaultdict(float)
        for label, conf in zip(all_preds, all_confs):
            if label in candidates:
                conf_sums[label] += conf
        final_label = max(conf_sums, key=conf_sums.get)
    # Mean confidence restricted to the chunks that agree with the winner.
    final_conf = float(
        np.mean([c for lbl, c in zip(all_preds, all_confs) if lbl == final_label])
    )
    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]
# MAIN CLASSIFIER
def classify(audio_path, image):
    """
    Route the user's upload to the right pipeline.

    A spectrogram image takes precedence over audio; when neither is
    supplied, a plain instruction string is returned instead of a dict.
    """
    # Spectrogram image path: single prediction with per-class details.
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {"Final Label": label, "Confidence": conf, "Details": probs}

    # Raw audio path: chunked prediction with per-chunk breakdown.
    if audio_path is not None:
        label, conf, chunk_labels, chunk_confs = process_audio_input(audio_path)
        return {
            "Final Label": label,
            "Confidence": conf,
            "All Chunk Labels": chunk_labels,
            "All Chunk Confidences": chunk_confs,
        }

    return "Please upload an audio file OR a spectrogram image."
# GRADIO UI
# Two optional inputs feed one classifier; results render as JSON.
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)")
    ],
    outputs=gr.JSON(label="Prediction Results"),
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "Upload a raw audio file OR a spectrogram image.\n"
        "If audio → model preprocesses into mel-spectrogram chunks.\n"
        "If image → model classifies the spectrogram directly.\n"
    ),
)
interface.launch()