# Streamlit app: classify an uploaded audio clip into a music genre with an AST model.
import streamlit as st
import torch
import torchaudio
import numpy as np
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import librosa
# ================= CONFIG =================
# Sample rate (Hz) that uploaded audio is resampled to.
SR = 16000

# Fixed clip length fed to the classifier: seconds, then samples.
CLIP_DURATION = 10
CLIP_LEN = CLIP_DURATION * SR

# Each clip is cut into this many equal segments. Integer division means
# up to SEGMENTS_PER_CLIP - 1 trailing samples are not part of any segment.
SEGMENTS_PER_CLIP = 3
SEG_LEN = CLIP_LEN // SEGMENTS_PER_CLIP

# Genre names by class index.
# NOTE(review): presumably matches the label order the checkpoint was
# trained with -- verify against the model's id2label mapping.
GENRES = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
IDX2GENRE = dict(enumerate(GENRES))

# Inference device for the model and its inputs.
DEVICE = "cpu"
# ================= LOAD MODEL =================
@st.cache_resource
def load_model():
    """Load the AST classifier and its feature extractor.

    Both are loaded from the ``22ds2000101/MessyMashUp_AST_Final``
    checkpoint. The function is wrapped in ``st.cache_resource`` so that
    Streamlit can reuse the loaded objects instead of reloading them.

    Returns:
        A ``(model, extractor)`` tuple. The model has been moved to
        ``DEVICE`` and switched to eval mode.
    """
    checkpoint = "22ds2000101/MessyMashUp_AST_Final"

    classifier = ASTForAudioClassification.from_pretrained(checkpoint)
    classifier.to(DEVICE)
    classifier.eval()  # inference only

    preprocessor = ASTFeatureExtractor.from_pretrained(checkpoint)
    return classifier, preprocessor
# Runs at import time; predict() reads these two module-level globals.
model, feature_extractor = load_model()
def load_audio(file):
    """Decode audio into a waveform tensor of exactly ``CLIP_LEN`` samples.

    The audio is loaded with librosa at sample rate ``SR``, then either
    truncated to the first ``CLIP_LEN`` samples or zero-padded at the end
    to reach that length.

    Args:
        file: Anything ``librosa.load`` accepts; the UI passes the object
            returned by ``st.file_uploader``.

    Returns:
        A float32 ``torch.Tensor`` of length ``CLIP_LEN``, or ``None`` if
        any step raised (the error is reported with ``st.error``).
    """
    waveform = None
    try:
        # librosa down-mixes to mono by default, so no channel handling here.
        samples, _ = librosa.load(file, sr=SR)
        missing = CLIP_LEN - len(samples)
        if missing > 0:
            samples = np.pad(samples, (0, missing))
        else:
            samples = samples[:CLIP_LEN]
        waveform = torch.tensor(samples).float()
    except Exception as err:
        # Best-effort: surface the problem in the UI and signal failure.
        st.error(f"Error loading audio: {err}")
    return waveform
# ================= INFERENCE =================
def predict(waveform):
    """Predict the genre of a clip by averaging logits over its segments.

    The waveform is cut into ``SEGMENTS_PER_CLIP`` consecutive slices of
    ``SEG_LEN`` samples, the slices are run through the model as one batch,
    and the per-segment logits are averaged before taking softmax/argmax.

    Args:
        waveform: Tensor to classify; it must support ``.numpy()`` on its
            slices. Presumably the ``CLIP_LEN``-sample tensor returned by
            ``load_audio`` -- confirm at the call site.

    Returns:
        A ``(genre, probs)`` tuple: the genre name for the highest averaged
        logit, and a NumPy array holding the softmax of the averaged logits.
    """
    chunks = []
    for k in range(SEGMENTS_PER_CLIP):
        start = k * SEG_LEN
        chunks.append(waveform[start:start + SEG_LEN].numpy())

    features = feature_extractor(chunks, sampling_rate=SR, return_tensors="pt")
    batch = features.input_values.to(DEVICE)

    with torch.no_grad():
        segment_logits = model(input_values=batch).logits

    # Collapse the segment (batch) axis into one score per class.
    clip_logits = segment_logits.mean(0)
    probs = torch.softmax(clip_logits, dim=0).cpu().numpy()
    best = clip_logits.argmax().item()
    return IDX2GENRE[best], probs
# ================= UI =================
# Page header and upload widget.
st.title("AST Model for classification of Music")
st.write("Upload a 10-second audio clip")
uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

if uploaded_file is not None:
    # Let the user play back the upload, then decode it for the model.
    st.audio(uploaded_file)
    waveform = load_audio(uploaded_file)

    # load_audio returns None (after showing an error) when decoding fails;
    # the short-circuit keeps the button hidden in that case.
    if waveform is not None and st.button("Predict"):
        with st.spinner("Predicting..."):
            genre, probs = predict(waveform)

        st.success(f"Predicted Genre: {genre.upper()}")
        st.subheader("Confidence Scores")
        # probs is indexed by class index, the same keys as IDX2GENRE.
        for idx, name in IDX2GENRE.items():
            st.write(f"{name}: {probs[idx]:.4f}")