"""Streamlit app that predicts a music genre for an uploaded audio file.

The upload is decoded to mono audio at ``SAMPLE_RATE``, cut into up to
``MAX_SEGMENTS`` windows of ``SEGMENT_SEC`` seconds, each window is scored by
the audio-classification model loaded from ``MODEL_REPO``, and the per-window
class probabilities are averaged.
"""

import os
import tempfile
from typing import List, Tuple

import librosa
import numpy as np
import streamlit as st
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# =========================
# Config
# =========================
SAMPLE_RATE = 16000
SEGMENT_SEC = 10
SEGMENT_LEN = SAMPLE_RATE * SEGMENT_SEC
TOP_K = 5
# Upper bound on the number of windows scored per upload; also shown in the UI.
MAX_SEGMENTS = 3

# Replace this with your actual Hugging Face model repo, for example:
# MODEL_REPO = "your-username/ast-messy-mashup"
MODEL_REPO = os.getenv("MODEL_REPO", "22ds2000101/20260411_best_ast_model.pt")

# Class names, in the order of the model's output indices.
LABELS = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]


@st.cache_resource(show_spinner=True)
def load_model_and_extractor():
    """Load the feature extractor and classifier from ``MODEL_REPO``.

    Wrapped in ``st.cache_resource`` so the load is not repeated on every
    script rerun.

    Returns:
        A ``(feature_extractor, model)`` tuple; the model is in eval mode.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
    model = AutoModelForAudioClassification.from_pretrained(MODEL_REPO)
    model.eval()
    return feature_extractor, model


def load_audio(uploaded_file) -> np.ndarray:
    """Decode an uploaded file into a mono float32 waveform at ``SAMPLE_RATE``.

    The upload is spooled to a temporary file because ``librosa.load`` is
    given a filesystem path here; the file is removed afterwards whether or
    not decoding succeeds.

    Args:
        uploaded_file: Object exposing ``getbuffer()`` and, optionally, a
            ``name`` attribute (as Streamlit's uploader result does).

    Returns:
        1-D float32 array of samples.

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    # Keep the upload's own extension instead of always claiming ".wav", so a
    # decoder that consults the file name is not misled; fall back to ".wav"
    # when no name/extension is available.
    upload_name = getattr(uploaded_file, "name", "") or ""
    suffix = os.path.splitext(upload_name)[1] or ".wav"

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
        tmp_path = tmp.name

    try:
        audio, _ = librosa.load(tmp_path, sr=SAMPLE_RATE, mono=True)
    finally:
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass

    # Without this check an empty decode would be padded to pure silence and
    # still receive a confident-looking prediction.
    if audio.size == 0:
        raise ValueError("The uploaded file contains no audio samples.")
    return audio.astype(np.float32)


def get_segments(audio: np.ndarray, max_segments: int = MAX_SEGMENTS) -> List[np.ndarray]:
    """Create deterministic ``SEGMENT_SEC``-second segments.

    - If audio is no longer than one segment, zero-pad it once at the end.
    - If ``max_segments <= 1``, return only the first segment.
    - Otherwise take evenly spaced windows between the start of the audio and
      the last position where a full segment still fits.

    Args:
        audio: 1-D waveform.
        max_segments: Maximum number of windows to return.

    Returns:
        List of float32 arrays, each exactly ``SEGMENT_LEN`` samples long.
        Fewer than ``max_segments`` windows are returned when several would
        start at the same sample.
    """
    if len(audio) <= SEGMENT_LEN:
        padded = np.pad(audio, (0, SEGMENT_LEN - len(audio)))
        return [padded.astype(np.float32)]

    if max_segments <= 1:
        return [audio[:SEGMENT_LEN].astype(np.float32)]

    max_start = len(audio) - SEGMENT_LEN
    starts = np.linspace(0, max_start, num=max_segments, dtype=int)
    # Integer truncation can repeat a start offset when the audio is only
    # slightly longer than one segment; drop repeats so the same window is not
    # scored (and weighted) more than once. linspace is non-decreasing, so the
    # sorted output of np.unique keeps the original order.
    starts = np.unique(starts)
    return [audio[s : s + SEGMENT_LEN].astype(np.float32) for s in starts]


def predict_audio(
    audio: np.ndarray,
    feature_extractor,
    model,
    max_segments: int = MAX_SEGMENTS,
) -> Tuple[str, List[Tuple[str, float]]]:
    """Predict the genre of ``audio`` by averaging per-segment probabilities.

    Args:
        audio: 1-D waveform sampled at ``SAMPLE_RATE``.
        feature_extractor: Callable turning a waveform into model inputs.
        model: Classifier whose output exposes ``.logits``.
        max_segments: Maximum number of windows to score.

    Returns:
        ``(predicted_label, top)`` where ``top`` holds the ``TOP_K`` highest
        ``(label, probability)`` pairs in descending order of probability.

    Raises:
        ValueError: If the model's class count does not match ``LABELS``.
    """
    segments = get_segments(audio, max_segments=max_segments)
    probs_per_segment = []

    with torch.no_grad():
        for segment in segments:
            inputs = feature_extractor(
                segment,
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
            )
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
            probs_per_segment.append(probs)

    mean_probs = np.mean(np.stack(probs_per_segment), axis=0)

    # Fail loudly on a label/model mismatch rather than raising a bare
    # IndexError or silently ranking only a prefix of the classes.
    if mean_probs.shape[0] != len(LABELS):
        raise ValueError(
            f"Model returned {mean_probs.shape[0]} classes but LABELS has "
            f"{len(LABELS)} entries."
        )

    pred_idx = int(np.argmax(mean_probs))
    predicted_label = LABELS[pred_idx]

    ranked = sorted(
        ((label, float(prob)) for label, prob in zip(LABELS, mean_probs)),
        key=lambda x: x[1],
        reverse=True,
    )
    return predicted_label, ranked[:TOP_K]


# =========================
# UI
# =========================
st.set_page_config(page_title="Messy Mashup Genre Classifier", page_icon="🎵", layout="centered")

st.title("🎵 Messy Mashup Genre Classifier")
st.markdown(
    "Upload an audio file and the app will predict its genre using your fine-tuned "
    "Audio Spectrogram Transformer model."
)

with st.expander("Model settings", expanded=False):
    st.write(f"Model repo: `{MODEL_REPO}`")
    st.write(f"Sample rate: `{SAMPLE_RATE}` Hz")
    st.write(f"Segment length: `{SEGMENT_SEC}` seconds")
    st.write(
        f"Inference uses up to {MAX_SEGMENTS} evenly spaced segments and averages class probabilities."
    )

uploaded_file = st.file_uploader(
    "Upload audio",
    type=["wav", "mp3", "flac", "ogg", "m4a"],
)

# Only fires when MODEL_REPO is explicitly set to this placeholder value.
if MODEL_REPO == "your-username/your-model-repo":
    st.warning(
        "Set your model repo first. In the Space Settings, add an environment variable named "
        "`MODEL_REPO`, or replace the default value inside `app.py`."
    )

if uploaded_file is not None:
    st.audio(uploaded_file)

    # Top-level boundary: surface any failure in the page instead of crashing
    # the script run.
    try:
        feature_extractor, model = load_model_and_extractor()

        with st.spinner("Running inference..."):
            audio = load_audio(uploaded_file)
            predicted_label, top_predictions = predict_audio(audio, feature_extractor, model)

        st.success(f"Predicted genre: **{predicted_label}**")

        st.subheader("Top predictions")
        for label, score in top_predictions:
            # Clamp so float rounding can never push the bar outside [0, 1].
            st.progress(min(max(score, 0.0), 1.0), text=f"{label}: {score:.4f}")

    except Exception as e:
        st.error("The app could not complete inference.")
        st.exception(e)

st.markdown("---")
st.caption(
    "Tip: for deployment, upload the fine-tuned model to a Hugging Face model repository and point "
    "this app to it with `MODEL_REPO`."
)