import os
import tempfile
from typing import List, Optional, Tuple

import librosa
import numpy as np
import streamlit as st
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification


# ---------------------------------------------------------------------------
# Audio / inference settings
# ---------------------------------------------------------------------------
SAMPLE_RATE = 16000  # Hz; load_audio() resamples every upload to this rate
SEGMENT_SEC = 10  # length of one inference window, in seconds
SEGMENT_LEN = SAMPLE_RATE * SEGMENT_SEC  # the same window, in samples
TOP_K = 5  # number of ranked predictions handed back to the UI

# ---------------------------------------------------------------------------
# Model location
# ---------------------------------------------------------------------------
# Hugging Face model repository to load; override with the MODEL_REPO
# environment variable.
MODEL_REPO = os.getenv("MODEL_REPO", "22ds2000101/20260411_best_ast_model.pt")

# Genre names, looked up by the model's output class index in predict_audio().
# NOTE(review): the order presumably mirrors the label order used during
# fine-tuning -- confirm against the training code before editing.
LABELS = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]


@st.cache_resource(show_spinner=True)
def load_model_and_extractor():
    """Load the feature extractor and classifier from ``MODEL_REPO``.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded objects instead of reloading them on each script rerun.

    Returns:
        A ``(feature_extractor, model)`` tuple; the model has been switched
        to evaluation mode.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
    model = AutoModelForAudioClassification.from_pretrained(MODEL_REPO)
    model.eval()  # inference only
    return feature_extractor, model


def load_audio(uploaded_file) -> np.ndarray:
    """Decode an uploaded audio file into a mono float32 waveform.

    The upload is spooled to a temporary file so ``librosa`` can open it by
    path, then the temporary file is always removed.

    Args:
        uploaded_file: Object exposing ``getbuffer()`` and, optionally, a
            ``name`` attribute holding the original file name (as Streamlit
            uploads do).

    Returns:
        1-D ``float32`` array resampled to ``SAMPLE_RATE``.
    """
    # Keep the upload's real extension so the temp file is not mislabelled as
    # WAV when the user uploads mp3/flac/ogg/m4a; fall back to ".wav" when no
    # name or extension is available.
    upload_name = getattr(uploaded_file, "name", "") or ""
    suffix = os.path.splitext(upload_name)[1] or ".wav"

    tmp_path = None
    try:
        # delete=False: the file has to outlive the ``with`` block so it can
        # be reopened by path below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before writing so a failed write is still
            # cleaned up by the ``finally`` clause.
            tmp_path = tmp.name
            tmp.write(uploaded_file.getbuffer())

        audio, _ = librosa.load(tmp_path, sr=SAMPLE_RATE, mono=True)
        return audio.astype(np.float32)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass


def get_segments(
    audio: np.ndarray,
    max_segments: int = 3,
    segment_len: Optional[int] = None,
) -> List[np.ndarray]:
    """Cut ``audio`` into deterministic fixed-length segments.

    - Audio no longer than one segment is zero-padded at the end and returned
      as a single segment.
    - Longer audio yields up to ``max_segments`` windows whose start offsets
      are evenly spaced between the beginning and the last full window.
      Windows that would start at the same sample are returned only once.

    Args:
        audio: 1-D waveform.
        max_segments: Upper bound on the number of windows; values ``<= 1``
            return only the first window.
        segment_len: Window length in samples. Defaults to the module-level
            ``SEGMENT_LEN``.

    Returns:
        List of ``float32`` arrays, each ``segment_len`` samples long.

    Raises:
        ValueError: If ``segment_len`` is not positive.
    """
    if segment_len is None:
        segment_len = SEGMENT_LEN
    if segment_len <= 0:
        raise ValueError("segment_len must be a positive number of samples")

    n_samples = len(audio)

    if n_samples <= segment_len:
        padded = np.pad(audio, (0, segment_len - n_samples))
        return [padded.astype(np.float32)]

    if max_segments <= 1:
        return [audio[:segment_len].astype(np.float32)]

    max_start = n_samples - segment_len
    # Integer truncation can collapse neighbouring starts when the audio is
    # only slightly longer than one window; np.unique drops those duplicates
    # (and keeps ascending order) so identical windows are not processed twice.
    starts = np.unique(np.linspace(0, max_start, num=max_segments, dtype=int))
    return [audio[s : s + segment_len].astype(np.float32) for s in starts]


def predict_audio(
    audio: np.ndarray,
    feature_extractor,
    model,
    max_segments: int = 3,
) -> Tuple[str, List[Tuple[str, float]]]:
    """Classify ``audio`` by averaging class probabilities over segments.

    The waveform is split with ``get_segments``; each segment is run through
    ``feature_extractor`` and ``model`` separately, softmax is applied to the
    logits, and the per-segment probabilities are averaged.

    Args:
        audio: 1-D waveform sampled at ``SAMPLE_RATE``.
        feature_extractor: Callable accepting ``(segment, sampling_rate=...,
            return_tensors="pt")`` and returning keyword inputs for ``model``.
        model: Callable whose result exposes ``.logits``.
        max_segments: Maximum number of segments to evaluate.

    Returns:
        ``(predicted_label, ranked)`` where ``ranked`` holds the ``TOP_K``
        highest ``(label, probability)`` pairs in descending order.

    Raises:
        ValueError: If the number of classes produced by the model differs
            from ``len(LABELS)``.
    """
    segments = get_segments(audio, max_segments=max_segments)

    probs_per_segment = []
    with torch.no_grad():
        for segment in segments:
            inputs = feature_extractor(
                segment,
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
            )
            logits = model(**inputs).logits
            # One segment per call, so take row 0 of the batch dimension.
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
            probs_per_segment.append(probs)

    mean_probs = np.mean(np.stack(probs_per_segment), axis=0)

    # Fail loudly on a label/model mismatch instead of raising a bare
    # IndexError or silently ignoring classes beyond len(LABELS).
    if mean_probs.shape[0] != len(LABELS):
        raise ValueError(
            f"Model returned {mean_probs.shape[0]} classes but "
            f"{len(LABELS)} labels are configured."
        )

    predicted_label = LABELS[int(np.argmax(mean_probs))]

    ranked = sorted(
        ((label, float(prob)) for label, prob in zip(LABELS, mean_probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )

    return predicted_label, ranked[:TOP_K]


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Messy Mashup Genre Classifier", page_icon="🎵", layout="centered")

st.title("🎵 Messy Mashup Genre Classifier")
st.markdown(
    "Upload an audio file and the app will predict its genre using your fine-tuned "
    "Audio Spectrogram Transformer model."
)

# Read-only summary of the settings used for inference.
with st.expander("Model settings", expanded=False):
    st.write(f"Model repo: `{MODEL_REPO}`")
    st.write(f"Sample rate: `{SAMPLE_RATE}` Hz")
    st.write(f"Segment length: `{SEGMENT_SEC}` seconds")
    st.write("Inference uses up to 3 evenly spaced segments and averages class probabilities.")

uploaded_file = st.file_uploader(
    "Upload audio",
    type=["wav", "mp3", "flac", "ogg", "m4a"],
)

# Only fires when MODEL_REPO is explicitly set to the template placeholder.
if MODEL_REPO == "your-username/your-model-repo":
    st.warning(
        "Set your model repo first. In the Space Settings, add an environment variable named "
        "`MODEL_REPO`, or replace the default value inside `app.py`."
    )

if uploaded_file is not None:
    st.audio(uploaded_file)

    # Top-level UI boundary: any failure while loading the model, decoding the
    # audio, or running inference is shown on the page.
    try:
        feature_extractor, model = load_model_and_extractor()

        with st.spinner("Running inference..."):
            audio = load_audio(uploaded_file)
            predicted_label, top_predictions = predict_audio(audio, feature_extractor, model)

        st.success(f"Predicted genre: **{predicted_label}**")

        st.subheader("Top predictions")
        for label, score in top_predictions:
            # Clamp to [0, 1] before passing the score to the progress bar.
            st.progress(min(max(score, 0.0), 1.0), text=f"{label}: {score:.4f}")

    except Exception as e:
        st.error("The app could not complete inference.")
        st.exception(e)

st.markdown("---")
st.caption(
    "Tip: for deployment, upload the fine-tuned model to a Hugging Face model repository and point "
    "this app to it with `MODEL_REPO`."
)