"""Streamlit app that classifies the genre of an uploaded music clip.

The clip is decoded to a fixed-length mono waveform, split into equal
segments, and each segment is scored by a fine-tuned Audio Spectrogram
Transformer (AST); the per-segment logits are averaged into one prediction.
"""

import streamlit as st
import torch
import torchaudio
import numpy as np
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import librosa

# ================= CONFIG =================
SR = 16000                               # sample rate the audio is resampled to
CLIP_DURATION = 10                       # seconds of audio kept per upload
CLIP_LEN = SR * CLIP_DURATION            # samples kept per upload
SEGMENTS_PER_CLIP = 3
# Integer division: the trailing CLIP_LEN % SEGMENTS_PER_CLIP samples of the
# clip are not covered by any segment.
SEG_LEN = CLIP_LEN // SEGMENTS_PER_CLIP

GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]
# NOTE(review): assumes the checkpoint's class indices follow this order --
# confirm against the model's own id2label mapping.
IDX2GENRE = {i: g for i, g in enumerate(GENRES)}
DEVICE = "cpu"


# ================= LOAD MODEL =================
@st.cache_resource
def load_model():
    """Load the AST classifier and its feature extractor.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects instead of re-loading them on every script rerun.

    Returns:
        A ``(model, extractor)`` tuple; the model is moved to ``DEVICE`` and
        switched to eval mode.
    """
    model = ASTForAudioClassification.from_pretrained(
        "22ds2000101/MessyMashUp_AST_Final"
    )
    model.to(DEVICE)
    model.eval()
    extractor = ASTFeatureExtractor.from_pretrained(
        "22ds2000101/MessyMashUp_AST_Final"
    )
    return model, extractor


model, feature_extractor = load_model()


def load_audio(file):
    """Decode an uploaded audio file into a fixed-length float waveform.

    The audio is resampled to ``SR`` and downmixed to mono, then truncated or
    zero-padded at the end to exactly ``CLIP_LEN`` samples.

    Args:
        file: A path or file-like object accepted by ``librosa.load``.

    Returns:
        A 1-D float32 ``torch.Tensor`` of length ``CLIP_LEN``, or ``None`` if
        decoding failed (the error is reported through ``st.error``).
    """
    try:
        # The same upload buffer is also handed to st.audio and is re-read on
        # every rerun; rewind so decoding always starts at the first byte
        # rather than wherever a previous reader left the cursor.
        if hasattr(file, "seek"):
            file.seek(0)
        y, _ = librosa.load(file, sr=SR, mono=True)
    except Exception as e:  # UI boundary: surface any decode failure to the user
        st.error(f"Error loading audio: {e}")
        return None

    if len(y) >= CLIP_LEN:
        y = y[:CLIP_LEN]
    else:
        y = np.pad(y, (0, CLIP_LEN - len(y)))
    return torch.tensor(y).float()


# ================= INFERENCE =================
def predict(waveform):
    """Predict the genre of a waveform produced by ``load_audio``.

    The waveform is cut into ``SEGMENTS_PER_CLIP`` consecutive slices of
    ``SEG_LEN`` samples, all slices are scored in one batch, and the logits
    are averaged over the batch before taking the argmax and the softmax.

    Args:
        waveform: 1-D ``torch.Tensor`` with at least
            ``SEGMENTS_PER_CLIP * SEG_LEN`` samples.

    Returns:
        ``(genre, probs)`` where ``genre`` is the predicted label from
        ``IDX2GENRE`` and ``probs`` is a NumPy array holding the softmax of
        the averaged logits.
    """
    segments = [
        waveform[i * SEG_LEN:(i + 1) * SEG_LEN].numpy()
        for i in range(SEGMENTS_PER_CLIP)
    ]
    inputs = feature_extractor(
        segments,
        sampling_rate=SR,
        return_tensors="pt",
    )
    xb = inputs.input_values.to(DEVICE)

    with torch.no_grad():
        logits = model(input_values=xb).logits

    # Average over the segment (batch) dimension to get one score per class.
    logits = logits.mean(0)
    pred = logits.argmax().item()
    probs = torch.softmax(logits, dim=0).cpu().numpy()
    return IDX2GENRE[pred], probs


# ================= UI =================
st.title("AST Model for classification of Music")
st.write("Upload a 10-second audio clip")

uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

if uploaded_file is not None:
    st.audio(uploaded_file)
    waveform = load_audio(uploaded_file)

    if waveform is not None:
        if st.button("Predict"):
            with st.spinner("Predicting..."):
                genre, probs = predict(waveform)

            st.success(f"Predicted Genre: {genre.upper()}")
            st.subheader("Confidence Scores")
            for i, g in IDX2GENRE.items():
                st.write(f"{g}: {probs[i]:.4f}")