import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision.models as models
import gradio as gr
import numpy as np
import os
 
# Audio settings; these must match the configuration the checkpoint was trained with.
SAMPLE_RATE = 22050
CROP_SEC = 6.0
CROP_LEN = int(SAMPLE_RATE * CROP_SEC)  # 132,300 samples per crop
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512

GENRES = sorted(["blues", "classical", "country", "disco", "hiphop",
                 "jazz", "metal", "pop", "reggae", "rock"])
GENRE2ID = {g: i for i, g in enumerate(GENRES)}
ID2GENRE = {i: g for i, g in enumerate(GENRES)}

DEVICE = torch.device("cpu")
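
# EfficientNet-B0 adapted to spectrogram input: the stem convolution is
# rebuilt for a single (mel) channel and the classifier head is replaced
# with a 10-way genre output.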
class PretrainedEfficientNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.efficientnet = models.efficientnet_b0(weights=None)  # weights are loaded from the checkpoint below
        old = self.efficientnet.features[0][0]
        self.efficientnet.features[0][0] = nn.Conv2d(
            1, old.out_channels, kernel_size=old.kernel_size,
            stride=old.stride, padding=old.padding, bias=False)
        self.efficientnet.classifier[1] = nn.Linear(
            self.efficientnet.classifier[1].in_features, num_classes)

    def forward(self, x):
        return self.efficientnet(x)
 
 
# Build the model and load the fine-tuned checkpoint bundled with the app.
model = PretrainedEfficientNet(num_classes=10)
weights_path = os.path.join(os.path.dirname(__file__), "best_effnet.pth")
state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
model.to(DEVICE)
 
# Spectrogram front end: mel power spectrogram converted to a dB scale.
mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=N_FFT,
    hop_length=HOP_LENGTH, n_mels=N_MELS)
db_transform = T.AmplitudeToDB()


def preprocess_audio(audio_tuple): 
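    """Convert Gradio's (sample_rate, numpy waveform) tuple into a mono
    float32 tensor of shape (1, num_samples) at SAMPLE_RATE."""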
    sr, waveform_np = audio_tuple
 
    waveform = torch.tensor(waveform_np, dtype=torch.float32)
 
    # Stereo arrives as (num_samples, num_channels); average channels to mono.
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=-1)
    waveform = waveform.unsqueeze(0)  # (1, num_samples)

    # Integer PCM (e.g. int16) arrives at full scale; rescale to [-1, 1].
    if waveform.abs().max() > 2.0:
        waveform = waveform / 32768.0
 
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    return waveform


def crop_or_pad(waveform, length): 
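    """Center-crop or right-pad `waveform` to exactly `length` samples
    (single-crop alternative to the TTA path)."""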
    if waveform.shape[1] >= length:
        start = (waveform.shape[1] - length) // 2  
        return waveform[:, start:start + length]
    return F.pad(waveform, (0, length - waveform.shape[1]))


def get_tta_crops(waveform, crop_len): 
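    """Return up to three crops (start, middle, end) for test-time augmentation."""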
    crops = []
    total = waveform.shape[1]
    if total <= crop_len:
        padded = F.pad(waveform, (0, crop_len - total))
        return [padded]

    crops.append(waveform[:, :crop_len])           # start
    mid = (total - crop_len) // 2
    crops.append(waveform[:, mid:mid + crop_len])  # middle
    crops.append(waveform[:, -crop_len:])          # end
    return crops


def wave_to_mel(wave): 
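    """Mel spectrogram in dB, standardized per example (zero mean, unit std)."""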
    mel = mel_transform(wave)
    mel_db = db_transform(mel)
    mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)
    return mel_db
 
@torch.no_grad()
def predict_genre(audio): 
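    """Average softmax probabilities over the TTA crops and return a
    {genre: probability} dict for gr.Label."""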
    if audio is None:
        return {g: 0.0 for g in GENRES}

    waveform = preprocess_audio(audio)
    crops = get_tta_crops(waveform, CROP_LEN)
 
    avg_probs = torch.zeros(len(GENRES))
    for crop in crops:
        mel = wave_to_mel(crop).unsqueeze(0).to(DEVICE)  # (1, 1, n_mels, time)
        logits = model(mel)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
        avg_probs += probs
    avg_probs /= len(crops)

    result = {GENRES[i]: float(avg_probs[i]) for i in range(len(GENRES))}
    return result
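
# Hypothetical local smoke test (not part of the Gradio app; assumes the
# soundfile package and one of the bundled example clips are available):
#   import soundfile as sf
#   data, sr = sf.read("song0002.wav")
#   print(predict_genre((sr, data)))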
 
DESCRIPTION = """
## Messy Mashup — Music Genre Classifier

Upload a music clip or record from your microphone, and the model will
predict its genre across 10 categories: **Blues, Classical, Country, Disco,
HipHop, Jazz, Metal, Pop, Reggae, Rock**.

### How it works
- **Model:** EfficientNet-B0 fine-tuned on 10,000+ synthetic mashups
- **Test-Time Augmentation:** 3 crops (start, middle, end) averaged for robustness
- **Training Score:** 0.90 Macro F1

*Built for BSDA2001P: Introduction to DL and GenAI - IIT Madras*
"""

demo = gr.Interface(
    fn=predict_genre,
    inputs=gr.Audio(
        label="Upload or Record Audio",
        type="numpy"
    ),
    outputs=gr.Label(
        num_top_classes=10,
        label="Genre Prediction"
    ),
    title="Messy Mashup Genre Classifier",
    description=DESCRIPTION,
    examples=[
        ["song0002.wav"],
        ["song0003.wav"],
        ["song0009.wav"]
    ]
)

if __name__ == "__main__":
    demo.launch()