dikro commited on
Commit
f275627
·
verified ·
1 Parent(s): a7482a4

full-code commited

Browse files
Files changed (3) hide show
  1. app.py +152 -0
  2. best_effnet.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ import torchaudio.transforms as T
7
+ import torchvision.models as models
8
+ import gradio as gr
9
+ import numpy as np
10
+ import os
11
+
12
# --- Audio front-end configuration -----------------------------------------
SAMPLE_RATE = 22050                      # Hz; uploads are resampled to this
CROP_SEC = 6.0                           # seconds of audio per model input
CROP_LEN = int(SAMPLE_RATE * CROP_SEC)   # same window expressed in samples
N_MELS = 128                             # mel bands in the spectrogram
N_FFT = 2048                             # FFT window size
HOP_LENGTH = 512                         # hop between STFT frames

# --- Label space -------------------------------------------------------------
# Kept sorted so that a genre's index is its position in alphabetical order.
GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
])
GENRE2ID = {name: idx for idx, name in enumerate(GENRES)}
ID2GENRE = dict(enumerate(GENRES))

# Inference is pinned to the CPU.
DEVICE = torch.device("cpu")
25
class PretrainedEfficientNet(nn.Module):
    """EfficientNet-B0 adapted to 1-channel input with a custom classifier.

    The backbone is created with ``weights=None``, so no torchvision weights
    are loaded here. The wrapped network is stored as ``self.net``.
    """

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes: Number of output logits produced by the final layer.
        """
        super().__init__()
        backbone = models.efficientnet_b0(weights=None)

        # Replace the stem convolution with a single-input-channel version,
        # keeping the original layer's geometry.
        stem = backbone.features[0][0]
        backbone.features[0][0] = nn.Conv2d(
            1,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )

        # Replace the final linear layer so it emits `num_classes` logits.
        head = backbone.classifier[1]
        backbone.classifier[1] = nn.Linear(head.in_features, num_classes)

        self.net = backbone

    def forward(self, x):
        """Run `x` through the wrapped EfficientNet and return its output."""
        return self.net(x)
38
+
39
# The checkpoint is resolved relative to this file rather than the working
# directory.
weights_path = os.path.join(os.path.dirname(__file__), "best_effnet.pth")

# weights_only=True restricts unpickling to tensors / plain containers.
state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)

model = PretrainedEfficientNet(num_classes=10)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode for dropout / batch-norm layers

# Waveform -> mel power spectrogram -> decibel scale.
mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
)
db_transform = T.AmplitudeToDB()
50
+
51
+
52
def preprocess_audio(audio_tuple):
    """Turn a ``(sample_rate, samples)`` pair into a mono float waveform.

    Args:
        audio_tuple: ``(sr, data)`` where ``data`` is a NumPy array of
            samples. A 2-D array is averaged over its last axis to get mono.

    Returns:
        A float32 tensor with a leading channel axis of size 1, resampled to
        ``SAMPLE_RATE`` when ``sr`` differs. A zero-length input is returned
        as an empty ``(1, 0)`` tensor without rescaling or resampling.
    """
    sr, waveform_np = audio_tuple
    samples = np.asarray(waveform_np)

    waveform = torch.tensor(samples, dtype=torch.float32)

    # Down-mix multi-channel audio, then add the channel axis back.
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=-1)
    waveform = waveform.unsqueeze(0)

    # An empty recording has no peak to inspect; `.max()` would raise.
    if waveform.numel() == 0:
        return waveform

    if samples.dtype.kind == "i":
        # Signed integer PCM: scale by the dtype's own full-scale value so
        # int16 and int32 both land in [-1, 1], however quiet the clip is.
        waveform = waveform / float(-np.iinfo(samples.dtype).min)
    elif waveform.abs().max() > 2.0:
        # Non-signed-integer data that is clearly not in [-1, 1] already:
        # keep the original fallback and assume a 16-bit PCM range.
        waveform = waveform / 32768.0

    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    return waveform
68
+
69
+
70
def crop_or_pad(waveform, length):
    """Force `waveform` to exactly `length` samples along axis 1.

    Longer inputs are cropped around their centre; shorter inputs are
    zero-padded on the right.

    Args:
        waveform: 2-D tensor whose second axis is time.
        length: Desired number of samples.

    Returns:
        A tensor with ``length`` samples on axis 1.
    """
    n_samples = waveform.shape[1]
    deficit = length - n_samples

    if deficit > 0:
        return F.pad(waveform, (0, deficit))

    # Enough samples: take the centred window.
    offset = (n_samples - length) // 2
    return waveform[:, offset:offset + length]
75
+
76
+
77
def get_tta_crops(waveform, crop_len):
    """Return the fixed-length windows used for test-time augmentation.

    Args:
        waveform: 2-D tensor whose second axis is time.
        crop_len: Window length in samples.

    Returns:
        A one-element list holding the right-zero-padded input when it is no
        longer than `crop_len`; otherwise three windows taken from the start,
        the centre and the end of the input, in that order.
    """
    n_samples = waveform.shape[1]

    if n_samples <= crop_len:
        return [F.pad(waveform, (0, crop_len - n_samples))]

    centre = (n_samples - crop_len) // 2
    return [
        waveform[:, :crop_len],
        waveform[:, centre:centre + crop_len],
        waveform[:, -crop_len:],
    ]
89
+
90
+
91
def wave_to_mel(wave):
    """Convert a waveform into a standardised dB-scaled mel spectrogram.

    Applies the module-level ``mel_transform`` and ``db_transform``, then
    normalises with the mean and standard deviation of the whole result.

    Args:
        wave: Waveform tensor accepted by ``mel_transform``.

    Returns:
        The normalised spectrogram tensor.
    """
    spec_db = db_transform(mel_transform(wave))
    centred = spec_db - spec_db.mean()
    # The epsilon avoids division by zero for a constant spectrogram.
    return centred / (spec_db.std() + 1e-6)
96
+
97
@torch.no_grad()
def predict_genre(audio):
    """Predict genre probabilities for one audio input.

    Args:
        audio: ``(sample_rate, samples)`` pair, or ``None`` when nothing was
            provided.

    Returns:
        A dict mapping every name in ``GENRES`` to a float score: the softmax
        probability averaged over the test-time-augmentation crops. All
        scores are 0.0 when `audio` is ``None`` or holds no samples.
    """
    # Nothing to classify: return an all-zero label set rather than failing
    # further down on an empty tensor.
    if audio is None or np.size(audio[1]) == 0:
        return dict.fromkeys(GENRES, 0.0)

    waveform = preprocess_audio(audio)
    crops = get_tta_crops(waveform, CROP_LEN)

    # Every crop has the same length, so their spectrograms stack into one
    # batch and need a single forward pass instead of one per crop.
    batch = torch.stack([wave_to_mel(crop) for crop in crops]).to(DEVICE)
    probs = torch.softmax(model(batch), dim=1)
    avg_probs = probs.mean(dim=0).cpu()

    # Pair scores with labels by position; the class count follows GENRES
    # instead of being hard-coded.
    return {genre: float(score) for genre, score in zip(GENRES, avg_probs)}
115
+
116
# Markdown shown under the title of the web UI.
DESCRIPTION = """
## Messy Mashup — Music Genre Classifier

Upload a music clip or record from your microphone and the AI will
identify the genre from 10 categories: **Blues, Classical, Country, Disco,
HipHop, Jazz, Metal, Pop, Reggae, Rock**.

### How it works
- **Model:** EfficientNet-B0 fine-tuned on 10,000+ synthetic mashups
- **Test-Time Augmentation:** 3 crops (start, middle, end) averaged for robustness
- **Training Score:** 0.90 Macro F1

*Built for BSDA2001P: Introduction to DL and GenAI — IIT Madras*
"""

# Input widget: type="numpy" hands the callback a (sample_rate, array) pair.
audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")

# Output widget: shows up to 10 label scores.
genre_output = gr.Label(num_top_classes=10, label="Genre Prediction")

app_theme = gr.themes.Soft(primary_hue="violet", secondary_hue="blue")

# NOTE(review): newer Gradio releases appear to rename `allow_flagging` to
# `flagging_mode` — confirm against the Gradio version actually installed.
demo = gr.Interface(
    fn=predict_genre,
    inputs=audio_input,
    outputs=genre_output,
    title="Messy Mashup Genre Classifier",
    description=DESCRIPTION,
    theme=app_theme,
    allow_flagging="never",
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch()
best_effnet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc8bc0ba0e496a4d9a0100954ed71b43f5935bbfccd84b825a40d7d98d9ca305
3
+ size 16388071
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ torchvision
4
+ gradio>=4.0
5
+ numpy