File size: 6,097 Bytes
7651694
 
 
e0d79c7
afb665f
 
600df41
e0d79c7
d8c58cd
 
afb665f
600df41
afb665f
 
 
600df41
afb665f
8c00eb3
afb665f
8c00eb3
 
 
 
 
 
 
7651694
8c00eb3
7651694
 
 
904154d
d8c58cd
6fa015b
7651694
 
 
 
 
 
 
8c00eb3
7651694
 
 
 
 
 
 
 
6fa015b
 
 
7651694
 
6fa015b
7651694
 
afb665f
600df41
afb665f
e0d79c7
afb665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0d79c7
afb665f
e0d79c7
afb665f
 
 
 
 
 
 
 
 
7651694
afb665f
7651694
 
 
 
 
 
600df41
7651694
afb665f
904154d
 
7651694
 
 
 
 
600df41
7651694
6fa015b
7651694
600df41
7651694
 
 
904154d
6fa015b
 
2fc5594
6fa015b
 
 
 
7651694
 
 
2fc5594
 
 
 
 
 
 
10f27a0
a40c093
 
10f27a0
a40c093
10f27a0
a40c093
10f27a0
a40c093
10f27a0
a40c093
7651694
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import io
import os
import random
import tempfile
from collections import Counter, defaultdict

import gradio as gr
import numpy as np
import soundfile as sf
from datasets import load_dataset
from PIL import Image

from app.model import predict
from app.preprocess import preprocess_audio

# Hugging Face datasets that back the "Pick Random ..." checkboxes; both are
# loaded once at import time (train split).
_AUDIO_DATASET_ID = "AIOmarRehan/General_Audio_Dataset"
_IMAGE_DATASET_ID = "AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification"

audio_ds = load_dataset(_AUDIO_DATASET_ID, split="train")
image_ds = load_dataset(_IMAGE_DATASET_ID, split="train")

# Helper functions
def safe_load_image(img):
    """Return *img* converted to RGBA mode.

    ``None`` is passed straight through. NumPy arrays are first wrapped with
    ``Image.fromarray``; anything else is assumed to already be a PIL image.
    """
    if img is None:
        return None
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    return pil_img.convert("RGBA")

def process_image_input(img):
    """Classify a single spectrogram image.

    The image is normalized to RGBA via ``safe_load_image`` and passed to
    ``predict``. Returns ``(label, confidence rounded to 3 decimals, probs)``,
    with ``label`` and ``probs`` exactly as ``predict`` produced them.
    """
    rgba_img = safe_load_image(img)
    predicted_label, score, class_probs = predict(rgba_img)
    return predicted_label, round(score, 3), class_probs

def _majority_vote(preds, confs):
    """Return the most frequent label in *preds*, breaking ties by summed confidence.

    When several labels share the top count, the one whose chunks have the
    largest total confidence wins; equal sums resolve to the label that first
    appears in *preds*.
    """
    counts = Counter(preds)
    top_count = max(counts.values())
    candidates = {lbl for lbl, n in counts.items() if n == top_count}
    if len(candidates) == 1:
        return next(iter(candidates))

    # Insertion order follows first appearance in preds, which fixes the
    # winner when summed confidences are exactly equal.
    conf_sums = defaultdict(float)
    for lbl, conf in zip(preds, confs):
        if lbl in candidates:
            conf_sums[lbl] += conf
    return max(conf_sums, key=conf_sums.get)


def process_audio_input(audio_path):
    """Classify an audio file by voting over per-chunk predictions.

    ``preprocess_audio`` turns the file into a sequence of model inputs (one
    per chunk), each of which is classified with ``predict``. The final label
    is chosen by majority vote with a summed-confidence tie-breaker.

    Returns:
        ``(final_label, final_conf, all_preds, all_confs)``:
        ``final_conf`` is the mean confidence of the chunks that predicted
        ``final_label``, rounded to 3 decimals. ``all_preds`` and ``all_confs``
        are the per-chunk labels and confidences, the latter rounded to 3
        decimals.

    Raises:
        ValueError: if ``preprocess_audio`` yields no chunks.
    """
    all_preds, all_confs = [], []
    for img in preprocess_audio(audio_path):
        label, conf, _probs = predict(img)
        all_preds.append(label)
        # Plain Python floats keep the JSON output consistent with final_conf.
        all_confs.append(float(conf))

    if not all_preds:
        # Without this guard max() on an empty Counter raises an opaque error.
        raise ValueError(f"No audio chunks could be extracted from {audio_path!r}.")

    final_label = _majority_vote(all_preds, all_confs)
    final_conf = float(
        np.mean([c for lbl, c in zip(all_preds, all_confs) if lbl == final_label])
    )
    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]

# Random-sample helpers
def _remove_quietly(path):
    """Best-effort delete of a temporary file; failures (e.g. already gone) are ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def _load_random_audio():
    """Pick a random clip from ``audio_ds`` and return ``(audio_path, temp_path)``.

    ``temp_path`` is a temporary WAV file written by this helper, which the
    caller must delete. It is ``None`` when an existing local file was reused.
    Returns ``(None, None)`` if the sample could not be loaded.
    """
    try:
        audio_obj = random.choice(audio_ds)["audio"]
        path = audio_obj.get("path") if isinstance(audio_obj, dict) else None
        # A decoded audio entry can carry a "path" key that is None or names a
        # file that only exists inside the dataset's storage, so reuse it only
        # when it is really on disk; otherwise fall back to the decoded array.
        if path and os.path.isfile(path):
            return path, None

        # Create and close the temp file first, then let soundfile open it by
        # name (writing while the handle is still open fails on Windows).
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(tmp_path, audio_obj["array"], audio_obj["sampling_rate"])
        except Exception:
            _remove_quietly(tmp_path)
            raise
        return tmp_path, tmp_path
    except Exception as e:
        print("Error loading random audio:", e)
        return None, None


def _load_random_image():
    """Pick a random spectrogram from ``image_ds`` as an RGBA PIL image, or ``None`` on failure."""
    try:
        img_obj = random.choice(image_ds)["image"]
        if not isinstance(img_obj, Image.Image):
            img_obj = Image.fromarray(img_obj)  # convert ndarray to PIL
        return img_obj.convert("RGBA")
    except Exception as e:
        print("Error loading random image:", e)
        return None


# Main classifier function
def classify(audio_path, image, random_audio=False, random_image=False):
    """Gradio callback: classify a spectrogram image or an audio file.

    If ``random_audio``/``random_image`` is set and the matching dataset is
    non-empty, that input is replaced by a random dataset sample. If loading
    the sample fails, the input becomes ``None``. An image takes precedence
    over audio when both are present.

    Returns:
        ``(result, label)``: ``result`` is a dict for the ``gr.JSON`` output.
        ``label`` is the final label for the text box, or ``""`` when there was
        nothing to classify.
    """
    temp_audio_path = None
    try:
        if random_audio and len(audio_ds) > 0:
            audio_path, temp_audio_path = _load_random_audio()

        if random_image and len(image_ds) > 0:
            image = _load_random_image()

        # Process spectrogram image
        if image is not None:
            label, conf, probs = process_image_input(image)
            return {
                "Final Label": label,
                "Confidence": conf,
                "Details": probs
            }, label

        # Process raw audio
        if audio_path is not None:
            label, conf, all_preds, all_confs = process_audio_input(audio_path)
            return {
                "Final Label": label,
                "Confidence": conf,
                "All Chunk Labels": all_preds,
                "All Chunk Confidences": all_confs
            }, label

        # gr.JSON parses a bare str as JSON text, so a plain sentence would
        # fail to render; wrap the prompt in a dict instead.
        return {"Message": "Please upload an audio file OR a spectrogram image."}, ""
    finally:
        # Don't leak the temporary WAV written for a random dataset sample.
        if temp_audio_path is not None:
            _remove_quietly(temp_audio_path)

# Gradio Interface
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Mel Spectrogram Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False)
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    # Adjacent string literals are concatenated; each literal must close on its
    # own line. The text is rendered as Markdown by Gradio.
    description=(
        "\nUpload a raw audio file OR a spectrogram image.\n"
        "\nYou can also select random samples from your Hugging Face datasets.\n"
        "\nThe output shows a JSON with all details and a separate field for the final label.\n"
        "\nYour audio is split into 5-second chunks. Each chunk is converted into a "
        "Mel-spectrogram and passed through a CNN trained to recognize patterns in "
        "frequency and time. The model predicts a label for every chunk. "
        "The final result is chosen by majority vote, using confidence scores to break ties. "
        "The output shows the final label, its confidence, and the predictions for each chunk.\n"
        "\n### How the Model Makes Predictions\n"
        "\nThe audio is split into 5-second chunks and each chunk is turned into a "
        "Mel-spectrogram. A CNN predicts a label and confidence score for every chunk. "
        "The final result is based on:\n"
        "\n- **Majority vote** — the class that appears most often across chunks.\n"
        "- **Tie-breaker** — if two or more classes appear the same number of times, "
        "the model selects the one with the highest total confidence across its chunks.\n"
        "- **Final confidence** — the average confidence of all chunks predicted as "
        "the final class.\n"
        "\nThe output shows the final label, its confidence, and the per-chunk predictions.\n"
    ),
)

interface.launch()