|
|
| import streamlit as st |
| import torch |
| import torchaudio |
| import numpy as np |
| from transformers import ASTForAudioClassification, ASTFeatureExtractor |
| import librosa |
|
|
| |
# Sample rate (Hz) that audio is loaded/resampled to, and the fixed clip
# length fed to the classifier, in seconds and in samples.
SR = 16000
CLIP_DURATION = 10
CLIP_LEN = CLIP_DURATION * SR

# A clip is scored as this many consecutive, equal-length windows
# (integer division: any leftover tail samples are not scored).
SEGMENTS_PER_CLIP = 3
SEG_LEN = CLIP_LEN // SEGMENTS_PER_CLIP

# Genre names in class-index order. NOTE(review): assumed to match the
# label order the fine-tuned model was trained with — confirm against
# model.config.id2label.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]

# Class index -> genre name lookup used to decode predictions.
IDX2GENRE = dict(enumerate(GENRES))

# Torch device for the model and its inputs (CPU only).
DEVICE = "cpu"
|
|
| |
# Hugging Face Hub repo providing both the fine-tuned AST weights and the
# feature-extractor config. Defined once so the model and extractor are
# always loaded from the same source.
MODEL_REPO = "22ds2000101/MessyMashUp_AST_Final"


@st.cache_resource
def load_model(repo_id: str = MODEL_REPO):
    """Load the AST classifier and its matching feature extractor.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects
    across script reruns instead of reloading them on every interaction.

    Args:
        repo_id: Identifier passed to ``from_pretrained`` for both the model
            and the feature extractor. Defaults to ``MODEL_REPO``.

    Returns:
        ``(model, extractor)``: the ``ASTForAudioClassification`` model, moved
        to ``DEVICE`` and put in eval mode, plus its ``ASTFeatureExtractor``.
    """
    model = ASTForAudioClassification.from_pretrained(repo_id)
    model.to(DEVICE)
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()

    extractor = ASTFeatureExtractor.from_pretrained(repo_id)

    return model, extractor


model, feature_extractor = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
def load_audio(file):
    """Decode ``file`` into a waveform of exactly ``CLIP_LEN`` samples.

    Audio is loaded with ``librosa.load`` at ``SR`` Hz (librosa's default
    ``mono=True`` applies). Only the first ``CLIP_DURATION`` seconds are
    decoded. The result is cut to ``CLIP_LEN`` samples if longer, or
    zero-padded at the end if shorter.

    Args:
        file: A path or file-like object accepted by ``librosa.load``. In this
            app it is the object returned by ``st.file_uploader``.

    Returns:
        A 1-D ``float32`` torch tensor of length ``CLIP_LEN``, or ``None`` if
        loading failed. In that case the error has already been shown to the
        user with ``st.error``.
    """
    try:
        # Rewind file-like inputs so decoding always starts at the beginning
        # of the stream, even if something already read from it.
        if hasattr(file, "seek"):
            file.seek(0)

        # duration= avoids decoding/resampling the whole upload when only the
        # first clip is used downstream.
        y, _ = librosa.load(file, sr=SR, duration=CLIP_DURATION)

        if len(y) >= CLIP_LEN:
            y = y[:CLIP_LEN]
        else:
            y = np.pad(y, (0, CLIP_LEN - len(y)))

        return torch.tensor(y, dtype=torch.float32)

    except Exception as e:
        # UI boundary: report any decode failure to the user instead of
        # crashing the script; the caller checks for None.
        st.error(f"Error loading audio: {e}")
        return None
|
|
|
|
| |
def predict(waveform):
    """Classify a waveform by averaging the model's logits over several windows.

    The waveform is sliced into ``SEGMENTS_PER_CLIP`` consecutive windows of
    ``SEG_LEN`` samples. All windows are featurized together as one batch and
    scored by the model. The per-window logits are then averaged into a single
    score vector for the clip.

    Args:
        waveform: 1-D torch tensor on the CPU (``.numpy()`` is called on its
            slices). Presumably the ``CLIP_LEN``-sample output of
            ``load_audio``.

    Returns:
        ``(genre, probs)``: the genre name with the highest averaged logit,
        and a NumPy array of softmax probabilities over the averaged logits,
        indexed like ``IDX2GENRE``.
    """
    bounds = [
        (k * SEG_LEN, (k + 1) * SEG_LEN) for k in range(SEGMENTS_PER_CLIP)
    ]
    windows = [waveform[lo:hi].numpy() for lo, hi in bounds]

    features = feature_extractor(
        windows,
        sampling_rate=SR,
        return_tensors="pt",
    )
    batch = features.input_values.to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        window_logits = model(input_values=batch).logits

    # Collapse the window (batch) axis into one logit vector for the clip.
    clip_logits = window_logits.mean(0)
    best_idx = clip_logits.argmax().item()
    confidences = torch.softmax(clip_logits, dim=0).cpu().numpy()

    return IDX2GENRE[best_idx], confidences
|
|
| |
# --- Streamlit UI: upload a clip, play it back, and classify it on demand ---
st.title("AST Model for classification of Music")
# Derive the advertised length from CLIP_DURATION so the instruction cannot
# drift from what load_audio actually keeps.
st.write(f"Upload a {CLIP_DURATION}-second audio clip")

uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

if uploaded_file is not None:
    # Let the user hear the upload before classifying it.
    st.audio(uploaded_file)

    waveform = load_audio(uploaded_file)

    # When waveform is None, load_audio has already shown the error. The
    # short-circuit means the Predict button is only rendered for a valid clip.
    if waveform is not None and st.button("Predict"):
        with st.spinner("Predicting..."):
            genre, probs = predict(waveform)

        st.success(f"Predicted Genre: {genre.upper()}")

        st.subheader("Confidence Scores")
        for i, g in IDX2GENRE.items():
            st.write(f"{g}: {probs[i]:.4f}")