| import gradio as gr |
| import torch |
| import librosa |
| import numpy as np |
| from transformers import ASTFeatureExtractor, ASTForAudioClassification |
|
|
| |
# Fine-tuned AST checkpoint (weights + feature-extractor config) on the Hugging Face Hub.
HF_REPO = "vectorverse/Messy_Mashup_Genre_Classifier"

# Audio is resampled to SAMPLE_RATE (Hz) on load and classified in windows of
# DURATION seconds, i.e. MAX_LENGTH samples each.
SAMPLE_RATE = 16_000
DURATION = 20
MAX_LENGTH = DURATION * SAMPLE_RATE
# Windows averaged per prediction: one centre crop plus (N_TTA - 1) random crops.
N_TTA = 5
|
|
# Genre names in class-index order: entry i is the label for logit i.
# NOTE(review): assumed to match the checkpoint's label order — confirm against
# model.config.id2label.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]
id2label = dict(enumerate(GENRES))
|
|
# Display emoji for each genre, shown next to the genre name in the UI.
# Written as \U escapes so the source stays ASCII-only. The previous literal
# emoji had been garbled (UTF-8 bytes decoded as Greek ISO-8859-7, e.g. "πΈ")
# and rendered as junk; the values below were reconstructed from those bytes.
GENRE_EMOJI = {
    "blues": "\U0001F3B8",      # guitar
    "classical": "\U0001F3BB",  # violin
    "country": "\U0001F920",    # face with cowboy hat
    "disco": "\U0001FAA9",      # mirror ball
    "hiphop": "\U0001F3A4",     # microphone
    "jazz": "\U0001F3BA",       # trumpet
    "metal": "\U0001F918",      # sign of the horns
    "pop": "\U0001F3B5",        # musical note
    "reggae": "\U0001F334",     # palm tree
    # NOTE(review): the garbled bytes fit both drum (U+1F941) and fire (U+1F525).
    "rock": "\U0001F941",       # drum with drumsticks
}
|
|
| |
# Load the feature extractor and classifier once at import time; every request
# handled by the app reuses these module-level objects.
print("Loading model...")
feature_extractor = ASTFeatureExtractor.from_pretrained(HF_REPO)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = ASTForAudioClassification.from_pretrained(HF_REPO).to(DEVICE)
model.eval()  # evaluation mode (e.g. disables dropout) for inference
print(f"Model ready on {DEVICE}!")
|
|
| |
def load_audio(path):
    """Decode the file at ``path`` into a mono float32 waveform.

    The audio is resampled to ``SAMPLE_RATE`` Hz by librosa; the sample rate
    librosa reports back is discarded.
    """
    samples = librosa.load(path, sr=SAMPLE_RATE, mono=True)[0]
    return np.asarray(samples, dtype=np.float32)
|
|
def normalize(y):
    """Peak-normalise ``y`` so its largest absolute sample is just under 1.

    The 1e-6 epsilon keeps an all-zero (silent) input from dividing by zero.
    """
    peak = np.abs(y).max()
    return y / (peak + 1e-6)
|
|
def random_crop(y, length=None):
    """Return a uniformly random contiguous window of ``length`` samples.

    Clips shorter than ``length`` are instead right-padded with zeros, so the
    result is deterministic in that case.

    Args:
        y: 1-D waveform.
        length: Window size in samples; defaults to ``MAX_LENGTH``.

    Returns:
        A 1-D array of exactly ``length`` samples.
    """
    if length is None:
        length = MAX_LENGTH
    if len(y) >= length:
        # randint's upper bound is exclusive: the +1 makes the last valid start
        # (len(y) - length) reachable and avoids randint(0, 0) raising
        # ValueError when the clip is exactly `length` samples long.
        start = np.random.randint(0, len(y) - length + 1)
        return y[start:start + length]
    return np.pad(y, (0, length - len(y)))
|
|
def center_crop(y, length=None):
    """Return the centred window of ``length`` samples from ``y``.

    Clips shorter than ``length`` are right-padded with zeros instead. When
    the excess is odd, the window sits one sample closer to the start.

    Args:
        y: 1-D waveform.
        length: Window size in samples; defaults to ``MAX_LENGTH`` (matching
            ``random_crop``'s signature).

    Returns:
        A 1-D array of exactly ``length`` samples.
    """
    if length is None:
        length = MAX_LENGTH
    if len(y) >= length:
        start = (len(y) - length) // 2
        return y[start:start + length]
    return np.pad(y, (0, length - len(y)))
|
|
| |
def predict(audio_path):
    """Predict the genre of an audio clip with test-time augmentation (TTA).

    The clip is decoded with ``load_audio`` and cut into ``N_TTA`` windows:
    one ``center_crop`` plus ``N_TTA - 1`` ``random_crop`` windows. Each
    window is peak-normalised, all windows are classified in one batched
    forward pass, and their softmax probabilities are averaged.

    Args:
        audio_path: Path of the uploaded or recorded clip (``gr.Audio`` with
            ``type="filepath"``), or ``None`` when nothing was provided.

    Returns:
        ``(markdown, label_probs)`` for the two output components.
        ``markdown`` is the formatted prediction or an error message.
        ``label_probs`` maps ``"<emoji> <Genre>"`` to that genre's averaged
        probability, or is ``None`` on error.
    """
    if audio_path is None:
        return "Please upload an audio file.", None

    try:
        audio = load_audio(audio_path)
    except Exception as e:  # UI boundary: report decode failures instead of crashing
        return f"Error loading audio: {e}", None

    if audio.size == 0:
        # Otherwise the clip would be zero-padded and "classified" as silence.
        return "Error loading audio: the file contains no audio samples.", None

    crops = [center_crop(audio)] + [random_crop(audio) for _ in range(N_TTA - 1)]

    # Featurise every crop together and run a single batched forward pass
    # instead of N_TTA separate ones.
    # NOTE(review): the extractor's max_length (from HF_REPO's preprocessor
    # config) may truncate each crop's spectrogram — confirm the model really
    # sees all DURATION seconds of every crop.
    inputs = feature_extractor(
        [normalize(crop) for crop in crops],
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(input_values=inputs["input_values"].to(DEVICE)).logits
    # Average the per-crop probabilities (not the logits) across the batch.
    avg_probs = torch.softmax(logits, dim=1).mean(dim=0).cpu().numpy()

    pred_idx = int(np.argmax(avg_probs))
    pred_genre = id2label[pred_idx]
    confidence = float(avg_probs[pred_idx]) * 100

    label_probs = {
        f"{GENRE_EMOJI.get(genre, '')} {genre.capitalize()}": float(avg_probs[i])
        for i, genre in id2label.items()
    }

    result = f"## {GENRE_EMOJI.get(pred_genre, '')} {pred_genre.capitalize()}\n**Confidence: {confidence:.1f}%**"
    return result, label_probs
|
|
| |
# Example clips for the gr.Examples panel: a list of [audio_path] rows.
# The panel is only rendered when this list is non-empty.
EXAMPLES = []

# Gradio UI: audio upload/recording on the left, prediction and per-genre
# probabilities on the right. Emoji are written as \U escapes so the source
# stays ASCII-only (the previous literals had been garbled into mojibake).
with gr.Blocks(title="\U0001F3B5 Music Genre Classifier") as demo:
    gr.Markdown(
        f"""
        # \U0001F3B5 Music Genre Classifier
        Upload any music file and the model will predict its genre.
        Supports: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae, rock.

        *Model: Fine-tuned Audio Spectrogram Transformer (AST) \u00b7 TTA x{N_TTA}*
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            predict_btn = gr.Button("\U0001F3AF Predict Genre", variant="primary")

        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Prediction")
            prob_chart = gr.Label(label="Genre Probabilities", num_top_classes=10)

    predict_btn.click(
        fn=predict,
        inputs=[audio_input],
        outputs=[result_md, prob_chart],
    )

    # Avoid rendering an empty "Examples" panel when no example clips are configured.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[audio_input],
            label="Examples",
        )


if __name__ == "__main__":
    demo.launch()