# NOTE(review): the lines originally here were non-Python residue scraped from
# the hosting page (Spaces run status, file size, git blame hashes, a line-number
# gutter). They were not part of the program and broke parsing; commented away.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision.models as models
import gradio as gr
import numpy as np
import os
# Audio / feature-extraction hyperparameters. These must match the values the
# checkpoint was trained with, or the spectrogram statistics will be off.
SAMPLE_RATE = 22050  # Hz; all incoming audio is resampled to this rate
CROP_SEC = 6.0  # duration of each analysis crop, in seconds
CROP_LEN = int(SAMPLE_RATE * CROP_SEC)  # crop length in samples (132300)
N_MELS = 128  # number of mel frequency bins
N_FFT = 2048  # FFT window size for the spectrogram
HOP_LENGTH = 512  # hop between successive FFT frames, in samples
# The 10 GTZAN-style genre labels; sorted so label order is deterministic.
GENRES = sorted(["blues", "classical", "country", "disco", "hiphop",
"jazz", "metal", "pop", "reggae", "rock"])
GENRE2ID = {g: i for i, g in enumerate(GENRES)}  # label -> class index
ID2GENRE = {i: g for i, g in enumerate(GENRES)}  # class index -> label
DEVICE = torch.device("cpu")  # inference runs on CPU (e.g. free Spaces tier)
class PretrainedEfficientNet(nn.Module):
    """EfficientNet-B0 adapted for single-channel spectrogram input.

    The stock stem convolution expects 3-channel RGB; it is replaced with a
    1-channel equivalent (same out_channels/kernel/stride/padding), and the
    classifier head is resized to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        backbone = models.efficientnet_b0(weights=None)
        stem = backbone.features[0][0]
        # Swap the RGB stem conv for a mono (1-channel) one.
        backbone.features[0][0] = nn.Conv2d(
            in_channels=1,
            out_channels=stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )
        head = backbone.classifier[1]
        backbone.classifier[1] = nn.Linear(head.in_features, num_classes)
        # Attribute name kept as `efficientnet` so saved state_dict keys match.
        self.efficientnet = backbone

    def forward(self, x):
        return self.efficientnet(x)
# Build the classifier and restore the fine-tuned weights for CPU inference.
# (`weights_only=True` keeps torch.load from unpickling arbitrary objects.)
weights_path = os.path.join(os.path.dirname(__file__), "best_effnet.pth")
model = PretrainedEfficientNet(num_classes=10)
state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
# Log-mel feature pipeline: waveform -> mel power spectrogram -> dB scale.
mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
)
db_transform = T.AmplitudeToDB()
def preprocess_audio(audio_tuple):
    """Convert a Gradio ``(sample_rate, ndarray)`` tuple into a mono float32
    waveform of shape ``(1, samples)`` at ``SAMPLE_RATE``.

    Integer PCM is normalized to [-1, 1] by its dtype's full scale. The old
    code used a magnitude heuristic (``abs().max() > 2.0`` -> divide by
    32768), which mis-scaled any integer audio wider than 16 bits; deciding
    by dtype handles int16/int32/uint8 alike. The heuristic is kept only as
    a fallback for float arrays that carry raw PCM magnitudes.
    """
    sr, waveform_np = audio_tuple
    if np.issubdtype(waveform_np.dtype, np.integer):
        # Full-scale normalization, e.g. int16 -> /32768, int32 -> /2**31.
        scale = float(max(abs(np.iinfo(waveform_np.dtype).min),
                          np.iinfo(waveform_np.dtype).max))
        waveform = torch.tensor(waveform_np, dtype=torch.float32) / scale
    else:
        waveform = torch.tensor(waveform_np, dtype=torch.float32)
        # Guard numel() so an empty recording doesn't crash on .max().
        if waveform.numel() and waveform.abs().max() > 2.0:
            waveform = waveform / 32768.0
    # Gradio stereo arrives as (samples, channels); down-mix to mono.
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=-1)
    waveform = waveform.unsqueeze(0)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
    return waveform
def crop_or_pad(waveform, length):
    """Return exactly ``length`` samples from a ``(1, n)`` waveform.

    Longer input is center-cropped; shorter input is right-padded with zeros.
    """
    n = waveform.shape[1]
    if n < length:
        return F.pad(waveform, (0, length - n))
    offset = (n - length) // 2
    return waveform[:, offset:offset + length]
def get_tta_crops(waveform, crop_len):
    """Produce crops of a ``(1, n)`` waveform for test-time augmentation.

    Audio no longer than ``crop_len`` yields a single zero-padded crop;
    otherwise three crops are taken: head, centre, and tail.
    """
    n = waveform.shape[1]
    if n <= crop_len:
        return [F.pad(waveform, (0, crop_len - n))]
    offsets = (0, (n - crop_len) // 2, n - crop_len)
    return [waveform[:, o:o + crop_len] for o in offsets]
def wave_to_mel(wave):
    """Turn a waveform into a per-clip standardized log-mel spectrogram.

    Applies the module-level mel + dB transforms, then z-scores the result
    (epsilon in the denominator guards against a zero std on silent input).
    """
    spec = db_transform(mel_transform(wave))
    return (spec - spec.mean()) / (spec.std() + 1e-6)
@torch.no_grad()
def predict_genre(audio):
    """Classify a Gradio audio input, returning ``{genre: probability}``.

    Softmax probabilities are averaged over up to three TTA crops
    (head / centre / tail). ``None`` input (nothing recorded/uploaded)
    yields an all-zero distribution so the Label widget stays valid.

    Fix: the class count was hard-coded as ``10`` in three places; it is
    now derived from ``GENRES`` so the function stays consistent with the
    label list it reports.
    """
    if audio is None:
        return {g: 0.0 for g in GENRES}
    waveform = preprocess_audio(audio)
    crops = get_tta_crops(waveform, CROP_LEN)
    crop_probs = []
    for crop in crops:
        mel = wave_to_mel(crop).unsqueeze(0).to(DEVICE)
        logits = model(mel)
        crop_probs.append(torch.softmax(logits, dim=1).squeeze(0).cpu())
    avg_probs = torch.stack(crop_probs).mean(dim=0)
    return {g: float(p) for g, p in zip(GENRES, avg_probs)}
# Markdown blurb rendered above the Gradio interface (passed as `description`).
DESCRIPTION = """
## Messy Mashup — Music Genre Classifier
Upload a music clip or record from your microphone and the AI will
identify the genre from 10 categories: **Blues, Classical, Country, Disco,
HipHop, Jazz, Metal, Pop, Reggae, Rock**.
### How it works
- **Model:** EfficientNet-B0 fine-tuned on 10,000+ synthetic mashups
- **Test-Time Augmentation:** 3 crops (start, middle, end) averaged for robustness
- **Training Score:** 0.90 Macro F1
*Built for BSDA2001P: Introduction to DL and GenAI - IIT Madras*
"""
# Gradio UI: one audio input (file upload or mic) -> one Label output showing
# the full probability distribution over all 10 genres.
demo = gr.Interface(
    fn=predict_genre,
    inputs=gr.Audio(
        label="Upload or Record Audio",
        type="numpy"  # delivers (sample_rate, ndarray), as preprocess_audio expects
    ),
    outputs=gr.Label(
        num_top_classes=10,  # show every genre, not just the top few
        label="Genre Prediction"
    ),
    title="Messy Mashup Genre Classifier",
    description=DESCRIPTION,
    # Example clips; paths are relative to the app's working directory —
    # presumably bundled alongside this file (TODO confirm they exist).
    examples=[
        ["song0002.wav"],
        ["song0003.wav"],
        ["song0009.wav"]
    ]
)
if __name__ == "__main__":
    demo.launch()
# NOTE(review): a stray "|" (scraped page gutter) was removed here.