File size: 4,173 Bytes
bd43bae
 
 
 
 
 
11c8e8e
 
bd43bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c8e8e
bd43bae
 
 
 
 
 
 
 
11c8e8e
bd43bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c8e8e
bd43bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c8e8e
bd43bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c8e8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification

# CONFIG
HF_REPO = "vectorverse/Messy_Mashup_Genre_Classifier"
SAMPLE_RATE = 16000                  # Hz; load_audio resamples every clip to this rate
DURATION = 20                        # seconds per analysis window
MAX_LENGTH = SAMPLE_RATE * DURATION  # samples per analysis window (crop size)
N_TTA = 5                            # number of test-time-augmentation crops per prediction

# Index -> genre name. This order is assumed to match the fine-tuned model's
# output logits (NOTE(review): confirm against model.config.id2label).
GENRES = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
id2label = dict(enumerate(GENRES))

# Display emoji per genre (previously stored as mis-decoded byte sequences).
GENRE_EMOJI = {
    "blues": "🎸", "classical": "🎻", "country": "🤠", "disco": "🪩",
    "hiphop": "🎤", "jazz": "🎺", "metal": "🤘", "pop": "🎵",
    "reggae": "🌴", "rock": "🔥",
}

# LOAD MODEL (once at startup)
# Loaded at import time so every prediction request reuses the same weights.
print("Loading model...")
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
feature_extractor = ASTFeatureExtractor.from_pretrained(HF_REPO)
model = ASTForAudioClassification.from_pretrained(HF_REPO)
model.to(DEVICE)
model.eval()  # evaluation mode for inference
print(f"Model ready on {DEVICE}!")

# AUDIO HELPERS
def load_audio(path):
    """Decode *path* into a mono float32 waveform resampled to SAMPLE_RATE.

    Any exception raised by ``librosa.load`` propagates to the caller.
    """
    waveform = librosa.load(path, sr=SAMPLE_RATE, mono=True)[0]
    return waveform.astype(np.float32)

def normalize(y):
    """Scale *y* so its largest absolute sample is just under 1.

    The 1e-6 added to the peak keeps an all-zero clip from dividing by zero.
    """
    peak = np.abs(y).max()
    return y / (peak + 1e-6)

def random_crop(y, max_length=None):
    """Return a randomly positioned window of ``max_length`` samples from *y*.

    Clips shorter than the window are right-padded with zeros instead.

    Args:
        y: 1-D waveform array.
        max_length: window size in samples; defaults to the module-level
            ``MAX_LENGTH``.

    Returns:
        A 1-D array of exactly ``max_length`` samples.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if len(y) >= max_length:
        # randint's upper bound is exclusive: +1 makes the last valid start
        # reachable and avoids ValueError when len(y) == max_length.
        start = np.random.randint(0, len(y) - max_length + 1)
        return y[start:start + max_length]
    return np.pad(y, (0, max_length - len(y)))

def center_crop(y, max_length=None):
    """Return the centred window of ``max_length`` samples from *y*.

    Clips shorter than the window are right-padded with zeros (not centred).

    Args:
        y: 1-D waveform array.
        max_length: window size in samples; defaults to the module-level
            ``MAX_LENGTH``.

    Returns:
        A 1-D array of exactly ``max_length`` samples.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if len(y) >= max_length:
        start = (len(y) - max_length) // 2
        return y[start:start + max_length]
    return np.pad(y, (0, max_length - len(y)))

# PREDICTION WITH TTA
def _tta_crops(audio):
    """Build the test-time-augmentation windows for one clip.

    The centre crop always comes first, followed by ``N_TTA - 1`` random crops.
    A clip no longer than ``MAX_LENGTH`` has only one possible window (every
    crop would be the same zero-padded copy), so only the centre crop is
    returned then. The average over identical crops is the same anyway.
    """
    crops = [center_crop(audio)]
    if len(audio) > MAX_LENGTH:
        crops.extend(random_crop(audio) for _ in range(N_TTA - 1))
    return crops


def _average_probs(crops):
    """Score all crops in one batched forward pass; return mean class probabilities.

    Each crop is peak-normalised first. The result is a 1-D numpy array with
    one probability per model output class.
    """
    waveforms = [normalize(crop) for crop in crops]
    inputs = feature_extractor(
        waveforms, sampling_rate=SAMPLE_RATE, return_tensors="pt"
    )
    input_values = inputs["input_values"].to(DEVICE)
    with torch.no_grad():
        logits = model(input_values=input_values).logits
        probs = torch.softmax(logits, dim=-1)  # (n_crops, n_classes)
    return probs.mean(dim=0).cpu().numpy()


def predict(audio_path):
    """Classify the genre of an uploaded audio file.

    Args:
        audio_path: filesystem path from the ``gr.Audio`` component
            (``type="filepath"``), or ``None`` when nothing was provided.

    Returns:
        ``(markdown, label_probs)``: a Markdown headline with the top genre and
        its confidence, and a ``{display label: probability}`` dict for
        ``gr.Label``. When input is missing, unreadable, or empty, the first
        element is a message and the second is ``None``.
    """
    if audio_path is None:
        return "Please upload an audio file.", None

    try:
        audio = load_audio(audio_path)
    except Exception as e:  # UI boundary: surface any decode failure to the user
        return f"Error loading audio: {e}", None

    if audio.size == 0:
        # Otherwise the model would classify a window of pure zero padding.
        return "Error loading audio: the file contains no audio samples.", None

    avg_probs = _average_probs(_tta_crops(audio))
    pred_idx = int(np.argmax(avg_probs))
    pred_genre = id2label[pred_idx]
    confidence = float(avg_probs[pred_idx]) * 100

    # Build label dict for the Gradio bar chart.
    label_probs = {
        f"{GENRE_EMOJI.get(genre, '')} {genre.capitalize()}": float(avg_probs[i])
        for i, genre in enumerate(GENRES)
    }

    result = f"## {GENRE_EMOJI.get(pred_genre, '')} {pred_genre.capitalize()}\n**Confidence: {confidence:.1f}%**"
    return result, label_probs

# GRADIO UI
# Optional example clip paths shown under the inputs. While this list is
# empty, no "Examples" panel is rendered.
EXAMPLE_AUDIO = []

with gr.Blocks(title="🎵 Music Genre Classifier") as demo:
    gr.Markdown(
        f"""
        # 🎵 Music Genre Classifier
        Upload any music file and the model will predict its genre.
        Supports: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae, rock.

        *Model: Fine-tuned Audio Spectrogram Transformer (AST) · TTA x{N_TTA}*
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            predict_btn = gr.Button("🎯 Predict Genre", variant="primary")

        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Prediction")
            # Show every genre's probability, not just the top few.
            prob_chart = gr.Label(label="Genre Probabilities", num_top_classes=len(GENRES))

    predict_btn.click(
        fn=predict,
        inputs=[audio_input],
        outputs=[result_md, prob_chart],
    )

    if EXAMPLE_AUDIO:
        gr.Examples(
            examples=EXAMPLE_AUDIO,
            inputs=[audio_input],
            label="Examples",
        )

if __name__ == "__main__":
    demo.launch()