"""Gradio demo that predicts the genre of an uploaded music clip.

A fine-tuned Audio Spectrogram Transformer (AST) is loaded once at startup.
Each request is scored with test-time augmentation (TTA): one centre crop plus
``N_TTA - 1`` random crops of ``DURATION`` seconds. The crops' softmax
probabilities are averaged.
"""

import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification

# ── CONFIG ───────────────────────────────────────────────────────────────────
HF_REPO = "vectorverse/Messy_Mashup_Genre_Classifier"
SAMPLE_RATE = 16000                   # Hz; audio is resampled to this on load
DURATION = 20                         # seconds per crop
MAX_LENGTH = SAMPLE_RATE * DURATION   # samples per crop
N_TTA = 5                             # total crops per prediction (1 centre + N_TTA - 1 random)

# NOTE(review): this order is assumed to match the label ids used in training;
# consider cross-checking it against model.config.id2label.
GENRES = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
id2label = {i: g for i, g in enumerate(GENRES)}

GENRE_EMOJI = {
    "blues": "🎸", "classical": "🎻", "country": "🤠", "disco": "🪩",
    "hiphop": "🎤", "jazz": "🎺", "metal": "🤘", "pop": "🎵",
    "reggae": "🌴", "rock": "🔥",
}

# Paths of example audio files shown under the UI; leave empty to hide the
# Examples section entirely.
EXAMPLE_FILES = []

# ── LOAD MODEL (once at startup) ─────────────────────────────────────────────
print("Loading model...")
feature_extractor = ASTFeatureExtractor.from_pretrained(HF_REPO)
model = ASTForAudioClassification.from_pretrained(HF_REPO)
model.eval()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
print(f"Model ready on {DEVICE}!")


# ── AUDIO HELPERS ────────────────────────────────────────────────────────────
def load_audio(path):
    """Decode *path* as mono float32 audio resampled to ``SAMPLE_RATE``."""
    y, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
    return y.astype(np.float32)


def normalize(y):
    """Peak-normalise *y*; the epsilon avoids division by zero on silence."""
    return y / (np.max(np.abs(y)) + 1e-6)


def _pad(y):
    """Right-pad *y* with zeros up to ``MAX_LENGTH`` samples."""
    return np.pad(y, (0, MAX_LENGTH - len(y)))


def random_crop(y):
    """Return a uniformly random ``MAX_LENGTH``-sample window of *y*.

    Clips shorter than ``MAX_LENGTH`` are zero-padded on the right instead.
    """
    if len(y) >= MAX_LENGTH:
        # randint's upper bound is exclusive. The +1 makes the final window
        # reachable and avoids ValueError (randint(0, 0)) when
        # len(y) == MAX_LENGTH.
        start = np.random.randint(0, len(y) - MAX_LENGTH + 1)
        return y[start:start + MAX_LENGTH]
    return _pad(y)


def center_crop(y):
    """Return the centred ``MAX_LENGTH``-sample window of *y*.

    Clips shorter than ``MAX_LENGTH`` are zero-padded on the right instead.
    """
    if len(y) >= MAX_LENGTH:
        start = (len(y) - MAX_LENGTH) // 2
        return y[start:start + MAX_LENGTH]
    return _pad(y)


# ── PREDICTION WITH TTA ──────────────────────────────────────────────────────
def _tta_crops(audio):
    """Build the TTA crops: the centre window plus ``N_TTA - 1`` random windows.

    A clip no longer than ``MAX_LENGTH`` would produce identical crops, so a
    single crop is returned for it.
    """
    if len(audio) <= MAX_LENGTH:
        return [center_crop(audio)]
    return [center_crop(audio)] + [random_crop(audio) for _ in range(N_TTA - 1)]


def _average_probs(crops):
    """Run all *crops* through the model in one batch.

    Returns the softmax probabilities averaged over crops, as a 1-D array
    with one entry per model label.
    """
    batch = [normalize(crop) for crop in crops]
    # NOTE(review): ASTFeatureExtractor pads/truncates each clip to its
    # configured `max_length` frames. Confirm that this covers DURATION
    # seconds; otherwise only the start of each crop reaches the model.
    inputs = feature_extractor(batch, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    input_values = inputs["input_values"].to(DEVICE)
    with torch.inference_mode():
        logits = model(input_values=input_values).logits
    probs = torch.softmax(logits, dim=-1).cpu().numpy()  # (n_crops, n_labels)
    return probs.mean(axis=0)


def predict(audio_path):
    """Predict the genre of the audio file at *audio_path*.

    Returns a ``(markdown, label_probs)`` tuple:

    - ``markdown`` is a headline with the predicted genre and its confidence.
    - ``label_probs`` maps each display label to its averaged probability,
      for ``gr.Label``.

    When there is no file, or the file cannot be decoded or is empty, returns
    ``(message, None)`` instead.
    """
    if audio_path is None:
        return "Please upload an audio file.", None

    try:
        audio = load_audio(audio_path)
    except Exception as e:  # surface any decode failure in the UI, don't crash
        return f"Error loading audio: {e}", None

    if audio.size == 0:
        return "The uploaded audio is empty.", None

    avg_probs = _average_probs(_tta_crops(audio))
    pred_idx = int(np.argmax(avg_probs))
    pred_genre = id2label[pred_idx]
    confidence = float(avg_probs[pred_idx]) * 100

    # Label dict for the Gradio probability chart.
    label_probs = {
        f"{GENRE_EMOJI.get(genre, '')} {genre.capitalize()}": float(avg_probs[i])
        for i, genre in id2label.items()
    }
    result = (
        f"## {GENRE_EMOJI.get(pred_genre, '')} {pred_genre.capitalize()}\n"
        f"**Confidence: {confidence:.1f}%**"
    )
    return result, label_probs


# ── GRADIO UI ────────────────────────────────────────────────────────────────
with gr.Blocks(title="🎵 Music Genre Classifier") as demo:
    gr.Markdown(
        f"""
        # 🎵 Music Genre Classifier
        Upload any music file and the model will predict its genre.
        Supports: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae, rock.

        *Model: Fine-tuned Audio Spectrogram Transformer (AST) · TTA x{N_TTA}*
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            predict_btn = gr.Button("🎯 Predict Genre", variant="primary")
        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Prediction")
            prob_chart = gr.Label(label="Genre Probabilities", num_top_classes=len(GENRES))

    predict_btn.click(
        fn=predict,
        inputs=[audio_input],
        outputs=[result_md, prob_chart],
    )

    # Only render the Examples section when there is something to show.
    if EXAMPLE_FILES:
        gr.Examples(
            examples=[[path] for path in EXAMPLE_FILES],
            inputs=[audio_input],
            label="Examples",
        )


if __name__ == "__main__":
    demo.launch()