import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification
# CONFIG ───────
# Hugging Face Hub repo id passed to both `from_pretrained` calls.
HF_REPO = "vectorverse/Messy_Mashup_Genre_Classifier"
# Sampling rate (Hz) the audio is resampled to and the feature extractor is told.
SAMPLE_RATE = 16000
# Length of one analysis window, in seconds.
DURATION = 20
# Length of one analysis window, in samples.
MAX_LENGTH = SAMPLE_RATE * DURATION
# Number of crops averaged per prediction (test-time augmentation).
N_TTA = 5
# NOTE(review): order presumably matches the model's class indices — verify
# against the checkpoint's own id2label mapping.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]
id2label = dict(enumerate(GENRES))
# Emoji shown next to each genre in the UI.
# NOTE(review): these literals were restored from mis-decoded (mojibake) text;
# confirm the exact glyphs (notably "rock") against the upstream file.
GENRE_EMOJI = {
    "blues": "🎸", "classical": "🎻", "country": "🤠", "disco": "🪩",
    "hiphop": "🎤", "jazz": "🎺", "metal": "🤘", "pop": "🎵",
    "reggae": "🌴", "rock": "🔥",
}
# LOAD MODEL (once at startup)
print("Loading model...")
# Pick the inference device first; it does not depend on the checkpoint.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor = ASTFeatureExtractor.from_pretrained(HF_REPO)
model = ASTForAudioClassification.from_pretrained(HF_REPO)
model.to(DEVICE)
# Inference only: switch the module to evaluation mode.
model.eval()
print(f"Model ready on {DEVICE}!")
# AUDIO HELPERS
def load_audio(path):
    """Load the audio at ``path`` as a mono float32 waveform.

    The file is decoded by ``librosa.load`` with ``sr=SAMPLE_RATE`` and
    ``mono=True``; the sample rate librosa reports back is discarded.

    Args:
        path: Location of the audio file, as accepted by ``librosa.load``.

    Returns:
        The waveform as a ``numpy.float32`` array.
    """
    waveform, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
    waveform = waveform.astype(np.float32)
    return waveform
def normalize(y):
    """Peak-normalize ``y`` by its largest absolute value.

    Args:
        y: Numeric NumPy array (the waveform).

    Returns:
        ``y`` divided by ``max(|y|) + 1e-6``.
    """
    peak = np.max(np.abs(y))
    # The small epsilon keeps the division finite for an all-zero signal.
    return y / (peak + 1e-6)
def random_crop(y, max_length=None):
    """Return a randomly positioned window of ``max_length`` samples from ``y``.

    Inputs shorter than the window are zero-padded on the right instead.

    Args:
        y: 1-D NumPy array of samples.
        max_length: Window length in samples. Defaults to the module-level
            ``MAX_LENGTH`` when omitted.

    Returns:
        An array of exactly ``max_length`` samples.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    n_samples = len(y)
    if n_samples < max_length:
        return np.pad(y, (0, max_length - n_samples))
    # randint's upper bound is exclusive. The +1 makes the last valid offset
    # reachable and keeps the range non-empty when the input is exactly one
    # window long (randint(0, 0) raises ValueError).
    start = np.random.randint(0, n_samples - max_length + 1)
    return y[start:start + max_length]
def center_crop(y, max_length=None):
    """Return the centered window of ``max_length`` samples from ``y``.

    Inputs shorter than the window are zero-padded on the right instead.

    Args:
        y: 1-D NumPy array of samples.
        max_length: Window length in samples. Defaults to the module-level
            ``MAX_LENGTH`` when omitted.

    Returns:
        An array of exactly ``max_length`` samples.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    surplus = len(y) - max_length
    if surplus < 0:
        return np.pad(y, (0, -surplus))
    # Split the surplus evenly; an odd surplus leaves the extra sample at the end.
    start = surplus // 2
    return y[start:start + max_length]
# PREDICTION WITH TTA
def predict(audio_path):
    """Predict the genre of an audio file, averaging over several crops (TTA).

    Args:
        audio_path: Path of the audio file to classify, or ``None`` when no
            file was provided.

    Returns:
        A ``(message, label_probs)`` tuple. ``message`` is a Markdown string
        with the predicted genre and confidence, or a prompt/error message.
        ``label_probs`` maps each display label to its averaged probability,
        or is ``None`` when no prediction was made.
    """
    if audio_path is None:
        return "Please upload an audio file.", None
    try:
        audio = load_audio(audio_path)
    except Exception as e:  # UI boundary: surface any decode failure as text
        return f"Error loading audio: {e}", None
    if len(audio) == 0:
        # Nothing to classify; padding an empty clip would score pure silence.
        return "Error loading audio: the file contains no audio samples.", None

    # TTA: center crop + N_TTA-1 random crops
    crops = [center_crop(audio)]
    if len(audio) > MAX_LENGTH:
        crops.extend(random_crop(audio) for _ in range(N_TTA - 1))
    # Otherwise every crop would be the same (padded) clip: one forward pass
    # yields the same average, and no random offset has to be drawn from an
    # empty range when the clip is exactly one window long.

    all_probs = []
    with torch.no_grad():
        for crop in crops:
            inputs = feature_extractor(
                normalize(crop), sampling_rate=SAMPLE_RATE, return_tensors="pt"
            )
            input_values = inputs["input_values"].to(DEVICE)
            logits = model(input_values=input_values).logits
            all_probs.append(torch.softmax(logits, dim=1).cpu().numpy())

    # Average over crops, then drop the batch axis of size one.
    avg_probs = np.mean(all_probs, axis=0)[0]
    pred_idx = int(np.argmax(avg_probs))
    pred_genre = id2label[pred_idx]
    confidence = float(avg_probs[pred_idx]) * 100

    # Build label dict for Gradio bar chart
    label_probs = {
        f"{GENRE_EMOJI.get(genre, '')} {genre.capitalize()}": float(avg_probs[i])
        for i, genre in enumerate(GENRES)
    }
    result = f"## {GENRE_EMOJI.get(pred_genre, '')} {pred_genre.capitalize()}\n**Confidence: {confidence:.1f}%**"
    return result, label_probs
# GRADIO UI
with gr.Blocks(title="🎵 Music Genre Classifier") as demo:
    # Page header. The text is kept flush-left inside the literal so the
    # Markdown does not depend on any dedenting of the string.
    gr.Markdown(
        """
# 🎵 Music Genre Classifier
Upload any music file and the model will predict its genre.
Supports: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae, rock.
*Model: Fine-tuned Audio Spectrogram Transformer (AST) · TTA x5*
"""
    )
    with gr.Row():
        # Left column: audio input and the trigger button.
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",  # predict() receives a path, not raw samples
                sources=["upload", "microphone"],
            )
            predict_btn = gr.Button("🎯 Predict Genre", variant="primary")
        # Right column: headline result and per-genre probabilities.
        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Prediction")
            prob_chart = gr.Label(label="Genre Probabilities", num_top_classes=10)
    # predict() returns (markdown, label_probs), matching the two outputs.
    predict_btn.click(
        fn=predict,
        inputs=[audio_input],
        outputs=[result_md, prob_chart],
    )
    gr.Examples(
        examples=[],  # add example audio paths here if you have them
        inputs=[audio_input],
        label="Examples",
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()