File size: 3,082 Bytes
7651694
 
 
 
 
 
8c00eb3
 
 
 
 
 
 
 
 
 
 
7651694
8c00eb3
 
 
7651694
8c00eb3
 
 
 
 
 
7651694
8c00eb3
7651694
 
 
 
8c00eb3
904154d
7651694
8c00eb3
7651694
 
 
 
 
 
 
 
 
 
 
8c00eb3
7651694
 
 
 
 
 
 
 
 
 
 
 
 
8c00eb3
 
 
7651694
 
 
 
8c00eb3
904154d
7651694
8c00eb3
7651694
 
 
 
 
 
 
 
8c00eb3
904154d
 
7651694
 
 
 
 
 
 
 
 
 
 
 
 
 
904154d
8c00eb3
7651694
 
 
 
 
904154d
 
7651694
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import numpy as np
from PIL import Image
from app.preprocess import preprocess_audio
from app.model import predict
from collections import Counter, defaultdict
import librosa


# IMAGE HANDLING
def safe_load_image(img):
    """Normalize a Gradio image input to a PIL RGBA image.

    Gradio may deliver either a PIL image or a raw numpy array; both are
    coerced to RGBA so any alpha channel is preserved for the model.

    Args:
        img: PIL image, numpy array, or None.

    Returns:
        A PIL ``Image`` in RGBA mode, or None when no image was provided.
    """
    if img is None:
        return None

    # Numpy arrays (Gradio's default) are wrapped into a PIL image first.
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img

    # Force RGBA so downstream code can rely on the alpha channel existing.
    return pil_img.convert("RGBA")


# PROCESS SPECTROGRAM IMAGE
def process_image_input(img):
    """Classify a single spectrogram image.

    Returns:
        Tuple of (label, confidence rounded to 3 decimals, per-class probs).
    """
    rgba = safe_load_image(img)
    label, confidence, probs = predict(rgba)
    return label, round(confidence, 3), probs


# PROCESS RAW AUDIO
def process_audio_input(audio_path):
    """Classify a raw audio file by voting over its spectrogram chunks.

    The audio is preprocessed into mel-spectrogram chunks (one PIL RGBA
    image per chunk), each chunk is classified independently, and the final
    label is the majority vote. Ties are broken by the largest summed
    confidence among the tied labels.

    Args:
        audio_path: Filesystem path to the uploaded audio file.

    Returns:
        Tuple of (final_label, final_confidence rounded to 3 decimals,
        list of per-chunk labels, list of per-chunk confidences rounded
        to 3 decimals). Returns ("unknown", 0.0, [], []) when
        preprocessing yields no chunks (e.g. silent or too-short audio).
    """
    imgs = preprocess_audio(audio_path)  # list of PIL RGBA images

    # Guard: without chunks, the vote below would crash on empty inputs.
    if not imgs:
        return "unknown", 0.0, [], []

    all_preds = []
    all_confs = []
    for img in imgs:
        label, conf, _probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)

    # Majority vote across chunks.
    counter = Counter(all_preds)
    max_count = max(counter.values())
    candidates = [lbl for lbl, cnt in counter.items() if cnt == max_count]

    if len(candidates) == 1:
        final_label = candidates[0]
    else:
        # Tie-break: pick the tied label with the highest total confidence.
        conf_sums = defaultdict(float)
        for label, conf in zip(all_preds, all_confs):
            if label in candidates:
                conf_sums[label] += conf
        final_label = max(conf_sums, key=conf_sums.get)

    # Final confidence: mean confidence over the chunks that voted for the winner.
    final_conf = float(
        np.mean([c for lbl, c in zip(all_preds, all_confs) if lbl == final_label])
    )

    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]


# MAIN CLASSIFIER
def classify(audio_path, image):
    """Dispatch a request to the image or audio pipeline.

    A spectrogram image takes priority over raw audio when both are given.

    Args:
        audio_path: Path to an uploaded audio file, or None.
        image: Uploaded spectrogram image, or None.

    Returns:
        A dict of prediction results, or a prompt string when neither
        input was supplied.
    """
    # Guard clause: nothing uploaded at all.
    if image is None and audio_path is None:
        return "Please upload an audio file OR a spectrogram image."

    # Spectrogram image path.
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {
            "Final Label": label,
            "Confidence": conf,
            "Details": probs,
        }

    # Raw audio path.
    label, conf, chunk_labels, chunk_confs = process_audio_input(audio_path)
    return {
        "Final Label": label,
        "Confidence": conf,
        "All Chunk Labels": chunk_labels,
        "All Chunk Confidences": chunk_confs,
    }


# GRADIO UI
# Two optional inputs: raw audio (by filepath) or a pre-rendered spectrogram.
# Fixed the mojibake "β†’" in the user-facing description (was a mis-encoded "→").
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
    ],
    outputs=gr.JSON(label="Prediction Results"),
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "Upload a raw audio file OR a spectrogram image.\n"
        "If audio → model preprocesses into mel-spectrogram chunks.\n"
        "If image → model classifies the spectrogram directly.\n"
    ),
)

interface.launch()