# Source: Hugging Face file view, uploaded by 22ds2000101.
# Last commit: "Update app.py" (3dd49c1, verified).
import streamlit as st
import torch
import torchaudio
import numpy as np
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import librosa
# ================= CONFIG =================
# Audio is decoded at a fixed sample rate and forced to a fixed clip length;
# each clip is then cut into equal-length segments for inference.
SR = 16000                # sample rate in Hz used for decoding and features
CLIP_DURATION = 10        # clip length in seconds
SEGMENTS_PER_CLIP = 3     # number of segments each clip is split into

CLIP_LEN = CLIP_DURATION * SR
# Floor division: with the values above, 160000 // 3 = 53333, so the last
# sample of the clip is not covered by any segment.
SEG_LEN = CLIP_LEN // SEGMENTS_PER_CLIP

GENRES = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
# Class index -> genre name.
# NOTE(review): assumes this order matches the model's output classes — confirm.
IDX2GENRE = dict(enumerate(GENRES))

DEVICE = "cpu"
# ================= LOAD MODEL =================
@st.cache_resource
def load_model():
    """Load the AST classifier and its feature extractor.

    Both objects come from the ``22ds2000101/MessyMashUp_AST_Final``
    checkpoint. The classifier is moved to ``DEVICE`` and switched to
    evaluation mode before being returned. The function is wrapped in
    ``st.cache_resource`` so Streamlit can reuse the loaded objects.

    Returns:
        A ``(model, extractor)`` tuple.
    """
    checkpoint = "22ds2000101/MessyMashUp_AST_Final"

    classifier = ASTForAudioClassification.from_pretrained(checkpoint)
    classifier.to(DEVICE)
    classifier.eval()

    preprocessor = ASTFeatureExtractor.from_pretrained(checkpoint)
    return classifier, preprocessor


# Module-level handles used by predict().
model, feature_extractor = load_model()
def load_audio(file):
    """Decode an audio file into a fixed-length waveform tensor.

    The audio is decoded by ``librosa.load`` at ``SR`` Hz (mono, librosa's
    default), then truncated or zero-padded at the end so the result always
    holds exactly ``CLIP_LEN`` samples.

    Args:
        file: Anything ``librosa.load`` accepts as its first argument, e.g.
            a path or a file-like object such as a Streamlit upload.

    Returns:
        A ``torch.float32`` tensor of length ``CLIP_LEN``, or ``None`` when
        the audio could not be decoded or contained no samples. In both
        failure cases an error message is shown with ``st.error``.
    """
    try:
        # Only the first CLIP_DURATION seconds are ever kept, so stop decoding
        # there instead of decoding and resampling an arbitrarily long upload.
        y, _ = librosa.load(file, sr=SR, duration=CLIP_DURATION)
    except Exception as e:
        # UI boundary: any decoder failure is reported to the user rather
        # than crashing the app.
        st.error(f"Error loading audio: {e}")
        return None

    # An upload that decodes to nothing would otherwise be padded to pure
    # silence and classified as if it were real audio.
    if y.size == 0:
        st.error("Error loading audio: file contains no audio samples")
        return None

    # Force the exact clip length: cut the excess, then zero-pad the tail.
    y = y[:CLIP_LEN]
    missing = CLIP_LEN - len(y)
    if missing > 0:
        y = np.pad(y, (0, missing))

    return torch.tensor(y).float()
# ================= INFERENCE =================
def predict(waveform):
    """Classify a clip by averaging the model's logits over its segments.

    The waveform is cut into ``SEGMENTS_PER_CLIP`` consecutive windows of
    ``SEG_LEN`` samples, the windows are run through the feature extractor
    and the model together, and the resulting logits are averaged over
    dimension 0 before taking the argmax and a softmax.

    Args:
        waveform: Tensor supporting slicing and ``.numpy()``.
            NOTE(review): presumably the 1-D, ``CLIP_LEN``-sample output of
            ``load_audio`` — confirm against callers.

    Returns:
        A ``(genre, probs)`` tuple: the ``IDX2GENRE`` name of the highest
        averaged logit, and a NumPy array with the softmax of the averaged
        logits.
    """
    # Back-to-back, non-overlapping windows starting at sample 0.
    chunks = []
    for idx in range(SEGMENTS_PER_CLIP):
        start = idx * SEG_LEN
        chunks.append(waveform[start:start + SEG_LEN].numpy())

    encoded = feature_extractor(chunks, sampling_rate=SR, return_tensors="pt")
    batch = encoded.input_values.to(DEVICE)

    # Inference only: no gradients needed.
    with torch.no_grad():
        segment_logits = model(input_values=batch).logits

    # Collapse dimension 0 (presumably one row per segment) into one vector.
    clip_logits = segment_logits.mean(0)
    probs = torch.softmax(clip_logits, dim=0).cpu().numpy()
    best = clip_logits.argmax().item()
    return IDX2GENRE[best], probs
# ================= UI =================
st.title("AST Model for classification of Music")
st.write("Upload a 10-second audio clip")

uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

if uploaded_file is not None:
    # Let the user play back what they uploaded, then decode it.
    st.audio(uploaded_file)
    waveform = load_audio(uploaded_file)

    # Short-circuit `and`: the button is only created when decoding succeeded.
    if waveform is not None and st.button("Predict"):
        with st.spinner("Predicting..."):
            genre, probs = predict(waveform)

        st.success(f"Predicted Genre: {genre.upper()}")
        st.subheader("Confidence Scores")
        # One line per genre, in class-index order.
        for i, g in enumerate(GENRES):
            st.write(f"{g}: {probs[i]:.4f}")