File size: 3,677 Bytes
7651694
 
 
 
 
 
8c00eb3
6fa015b
 
 
 
 
 
8c00eb3
 
 
 
 
 
 
 
 
 
 
 
 
7651694
8c00eb3
7651694
 
 
 
8c00eb3
904154d
6fa015b
 
7651694
 
 
 
 
 
 
8c00eb3
7651694
 
 
 
 
 
 
 
6fa015b
 
 
7651694
 
6fa015b
7651694
 
 
 
8c00eb3
6fa015b
 
 
 
 
 
 
 
 
 
 
7651694
8c00eb3
7651694
 
 
 
 
 
6fa015b
7651694
8c00eb3
904154d
 
7651694
 
 
 
 
6fa015b
7651694
6fa015b
7651694
 
 
 
 
 
904154d
6fa015b
 
 
 
 
 
 
7651694
 
 
 
6fa015b
 
7651694
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import contextlib
import io
import os
import random
import tempfile
import wave
from collections import Counter, defaultdict

import gradio as gr
import librosa
import numpy as np
from datasets import load_dataset
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio

# Hugging Face datasets backing the "Pick Random ..." checkboxes. Both are
# loaded (non-streaming, "train" split) as soon as this module is imported.
AUDIO_DATASET_ID = "AIOmarRehan/General_Audio_Dataset"
IMAGE_DATASET_ID = "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification"

audio_ds = load_dataset(AUDIO_DATASET_ID, split="train")
image_ds = load_dataset(IMAGE_DATASET_ID, split="train")


# IMAGE HANDLING
def safe_load_image(img):
    """Normalise an image input to an RGBA PIL image.

    ``None`` passes straight through. A NumPy array is wrapped with
    ``Image.fromarray`` first. Any other value is assumed to already be a
    PIL-style image exposing ``convert``.

    Returns:
        The image converted to mode ``"RGBA"``, or ``None`` when ``img`` is ``None``.
    """
    if img is not None:
        as_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
        return as_pil.convert("RGBA")
    return None


# PROCESS SPECTROGRAM IMAGE
def process_image_input(img):
    """Classify a single spectrogram image.

    The input is normalised with ``safe_load_image`` and passed to ``predict``.

    Returns:
        ``(label, confidence, probs)`` exactly as ``predict`` produced them,
        except that ``confidence`` is rounded to 3 decimal places.
    """
    rgba_img = safe_load_image(img)
    prediction = predict(rgba_img)
    label, raw_confidence, probs = prediction
    return label, round(raw_confidence, 3), probs


# PROCESS RAW AUDIO
def _majority_vote(labels, confs):
    """Return the most frequent label, breaking ties by the highest summed confidence.

    Among tied labels with equal confidence sums, the one that appears first in
    ``labels`` wins.
    """
    counts = Counter(labels)
    top_count = max(counts.values())
    candidates = [lbl for lbl, n in counts.items() if n == top_count]
    if len(candidates) == 1:
        return candidates[0]

    conf_sums = defaultdict(float)
    for lbl, conf in zip(labels, confs):
        if lbl in candidates:
            conf_sums[lbl] += conf
    return max(conf_sums, key=conf_sums.get)


def process_audio_input(audio_path):
    """Classify a raw audio file by voting over per-chunk spectrogram predictions.

    ``preprocess_audio`` converts the file into a sequence of spectrogram images.
    Each image is classified with ``predict``, and the chunk labels are combined
    by majority vote (ties broken by summed confidence).

    Args:
        audio_path: Path to an audio file accepted by ``preprocess_audio``.

    Returns:
        ``(final_label, final_conf, all_preds, all_confs)``:
        ``final_conf`` is the mean confidence of the chunks that voted for
        ``final_label``, rounded to 3 decimals; ``all_preds`` lists every chunk's
        label; ``all_confs`` lists every chunk's confidence, rounded to 3 decimals.

    Raises:
        ValueError: if ``preprocess_audio`` yields no chunks for the file.
    """
    all_preds, all_confs = [], []
    for img in preprocess_audio(audio_path):
        label, conf, _probs = predict(img)
        all_preds.append(label)
        # Plain float so the value serialises cleanly in the gr.JSON output even
        # if predict returns a NumPy scalar.
        all_confs.append(float(conf))

    if not all_preds:
        # Without this, the vote below fails with "max() arg is an empty sequence".
        raise ValueError(f"No audio chunks could be extracted from {audio_path!r}.")

    final_label = _majority_vote(all_preds, all_confs)
    winning_confs = [c for lbl, c in zip(all_preds, all_confs) if lbl == final_label]
    final_conf = float(np.mean(winning_confs))

    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]


# MAIN CLASSIFIER
def _to_pil_image(value):
    """Open a value from a Hugging Face ``Image`` column as a PIL image.

    With decoding enabled, the column already yields ``PIL.Image.Image`` objects,
    which ``Image.open`` cannot accept. Undecoded rows are ``{"bytes", "path"}``
    dicts. Raw bytes, paths and file objects are also handled.
    """
    if isinstance(value, Image.Image):
        return value
    if isinstance(value, dict):
        if value.get("bytes") is not None:
            return Image.open(io.BytesIO(value["bytes"]))
        return Image.open(value["path"])
    if isinstance(value, (bytes, bytearray)):
        return Image.open(io.BytesIO(value))
    return Image.open(value)


def _audio_sample_to_path(audio):
    """Return ``(path, is_temporary)`` for a row of a Hugging Face ``Audio`` column.

    If the row's ``path`` is an existing local file, it is returned as-is.
    Otherwise the decoded samples are written to a temporary 16-bit PCM mono WAV
    file, which the caller must delete. This covers hub datasets whose audio is
    embedded as bytes, where ``path`` is just the original file name.
    """
    path = audio.get("path") if isinstance(audio, dict) else None
    if path and os.path.isfile(path):
        return path, False

    # NOTE(review): assumes decoded samples are floats in [-1, 1] (the usual
    # `datasets` decoding) and channels-first when not mono; confirm.
    samples = np.asarray(audio["array"], dtype=np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=0)
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype("<i2")

    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        with wave.open(tmp_path, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(int(audio["sampling_rate"]))
            wav_file.writeframes(pcm.tobytes())
    except BaseException:
        os.remove(tmp_path)
        raise
    return tmp_path, True


def classify(audio_path, image, random_audio, random_image):
    """Gradio callback: classify a spectrogram image or a raw audio file.

    Args:
        audio_path: Filepath from the ``gr.Audio`` input, or ``None``.
        image: PIL image from the ``gr.Image`` input, or ``None``.
        random_audio: If true, ``audio_path`` is replaced by a random row of ``audio_ds``.
        random_image: If true, ``image`` is replaced by a random row of ``image_ds``.

    An image, whether uploaded or random, takes precedence over audio.

    Returns:
        ``(details, label)`` for the JSON and Textbox outputs. If neither an
        image nor audio is available, returns an instruction string and ``""``.
    """
    tmp_audio_path = None
    try:
        if random_audio:
            rand_sample = random.choice(audio_ds)
            audio_path, is_temporary = _audio_sample_to_path(rand_sample["audio"])
            if is_temporary:
                tmp_audio_path = audio_path

        if random_image:
            rand_sample = random.choice(image_ds)
            image = _to_pil_image(rand_sample["image"]).convert("RGBA")

        # If spectrogram image
        if image is not None:
            label, conf, probs = process_image_input(image)
            return {
                "Final Label": label,
                "Confidence": conf,
                "Details": probs
            }, label

        # If raw audio
        if audio_path is not None:
            label, conf, all_preds, all_confs = process_audio_input(audio_path)
            return {
                "Final Label": label,
                "Confidence": conf,
                "All Chunk Labels": all_preds,
                "All Chunk Confidences": all_confs
            }, label

        return "Please upload an audio file OR a spectrogram image.", ""
    finally:
        # Delete the WAV written for an embedded-bytes dataset row, if any.
        if tmp_audio_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_audio_path)


# GRADIO UI
# Input order must match classify(audio_path, image, random_audio, random_image).
_input_components = [
    gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
    gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
    gr.Checkbox(label="Pick Random Audio from Dataset"),
    gr.Checkbox(label="Pick Random Image from Dataset"),
]

# classify returns (details, label): details go to the JSON view, label to the textbox.
_output_components = [
    gr.JSON(label="Prediction Results"),
    gr.Textbox(label="Final Label", interactive=False),
]

_description = "\n".join([
    "Upload a raw audio file OR a spectrogram image.",
    "You can also select random samples from the Hugging Face datasets.",
    "The output shows a JSON with all details and a separate field for the final label.",
])

interface = gr.Interface(
    fn=classify,
    inputs=_input_components,
    outputs=_output_components,
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=_description,
)

interface.launch()