File size: 2,524 Bytes
3dd49c1
bce0b3d
 
9cfa10f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bce0b3d
 
9cfa10f
 
 
bce0b3d
9cfa10f
 
 
bce0b3d
 
 
 
 
 
9cfa10f
 
 
bce0b3d
9cfa10f
 
 
 
 
bce0b3d
9cfa10f
bce0b3d
9cfa10f
 
 
bce0b3d
 
9cfa10f
 
 
 
 
 
 
 
 
 
 
bce0b3d
 
9cfa10f
bce0b3d
9cfa10f
 
bce0b3d
9cfa10f
 
bce0b3d
9cfa10f
bce0b3d
9cfa10f
bce0b3d
9cfa10f
 
 
bce0b3d
9cfa10f
bce0b3d
 
 
 
9cfa10f
bce0b3d
9cfa10f
 
 
 
bce0b3d
3dd49c1
bce0b3d
9cfa10f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

import streamlit as st
import torch
import torchaudio
import numpy as np
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import librosa

# ================= CONFIG =================
# Audio is resampled to this rate (Hz) when decoded by librosa.
SR = 16_000
# Seconds of audio kept from each upload; longer files are truncated and
# shorter ones zero-padded to exactly this length.
CLIP_DURATION = 10
CLIP_LEN = CLIP_DURATION * SR  # samples per clip

# Each clip is cut into this many contiguous chunks whose logits are averaged.
SEGMENTS_PER_CLIP = 3
# Integer division: with the defaults, 160000 // 3 leaves 1 trailing sample.
SEG_LEN = CLIP_LEN // SEGMENTS_PER_CLIP

# Label order must match the classifier's output index order.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]

IDX2GENRE = dict(enumerate(GENRES))

DEVICE = "cpu"

# ================= LOAD MODEL =================
# Hugging Face Hub repo holding both the fine-tuned weights and the
# feature-extractor config; defined once so the two cannot drift apart.
MODEL_REPO_ID = "22ds2000101/MessyMashUp_AST_Final"


@st.cache_resource
def load_model(repo_id: str = MODEL_REPO_ID):
    """Load the AST audio classifier and its matching feature extractor.

    ``st.cache_resource`` memoises the result, so Streamlit reruns reuse the
    already-loaded objects instead of re-initialising them.

    Args:
        repo_id: Hub repository id (or local directory) passed to both
            ``from_pretrained`` calls. Defaults to ``MODEL_REPO_ID``.

    Returns:
        ``(model, extractor)``: the ``ASTForAudioClassification`` moved to
        ``DEVICE`` and switched to eval mode, and the ``ASTFeatureExtractor``
        loaded from the same repo.
    """
    model = ASTForAudioClassification.from_pretrained(repo_id)
    model.to(DEVICE)
    model.eval()  # disable dropout etc. for deterministic inference

    extractor = ASTFeatureExtractor.from_pretrained(repo_id)

    return model, extractor

model, feature_extractor = load_model()






def load_audio(file):
    """Decode an audio file into a fixed-length mono waveform tensor.

    The file is resampled to ``SR`` and then forced to exactly ``CLIP_LEN``
    samples: longer audio keeps only its first ``CLIP_LEN`` samples, shorter
    audio is right-padded with zeros.

    Args:
        file: anything ``librosa.load`` accepts (path or file-like object,
            e.g. a Streamlit upload).

    Returns:
        A float32 ``torch`` tensor of ``CLIP_LEN`` samples, or ``None`` if any
        step raised — in that case the error is shown with ``st.error``.
    """
    try:
        # librosa down-mixes to mono by default and resamples to SR.
        signal, _ = librosa.load(file, sr=SR)

        shortfall = CLIP_LEN - len(signal)
        if shortfall > 0:
            signal = np.pad(signal, (0, shortfall))  # pad tail with silence
        else:
            signal = signal[:CLIP_LEN]  # keep only the leading clip

        return torch.tensor(signal).float()

    except Exception as err:
        # UI boundary: report to the user instead of crashing the app.
        st.error(f"Error loading audio: {err}")
        return None


# ================= INFERENCE =================
def predict(waveform, segments_per_clip=SEGMENTS_PER_CLIP):
    """Predict a clip's genre by averaging logits over equal-length segments.

    The waveform is split into ``segments_per_clip`` contiguous chunks of
    ``len(waveform) // segments_per_clip`` samples each (leftover samples at
    the end are dropped). All chunks go through the feature extractor and the
    model as one batch, and their logits are averaged into a single clip-level
    score before taking the argmax and softmax.

    The segment length is derived from the input, not from the global
    ``SEG_LEN``. For a ``CLIP_LEN``-sample waveform (what ``load_audio``
    returns) the two are identical. Other lengths are now split correctly
    instead of ignoring the tail or producing empty segments.

    Args:
        waveform: 1-D float tensor sampled at ``SR`` (assumed 1-D, as produced
            by ``load_audio`` — confirm for any other caller).
        segments_per_clip: number of chunks to average over; defaults to
            ``SEGMENTS_PER_CLIP``.

    Returns:
        ``(genre, probs)``: the label from ``IDX2GENRE`` for the highest
        averaged logit, and a numpy array of softmax probabilities over the
        model's output classes.

    Raises:
        ValueError: if ``segments_per_clip`` < 1, or the waveform has fewer
            samples than segments (some segment would be empty).
    """
    if segments_per_clip < 1:
        raise ValueError(
            f"segments_per_clip must be >= 1, got {segments_per_clip}"
        )

    seg_len = len(waveform) // segments_per_clip
    if seg_len == 0:
        raise ValueError(
            f"waveform has {len(waveform)} samples; need at least "
            f"{segments_per_clip} to form {segments_per_clip} segments"
        )

    segments = [
        waveform[i * seg_len:(i + 1) * seg_len].numpy()
        for i in range(segments_per_clip)
    ]

    inputs = feature_extractor(
        segments,
        sampling_rate=SR,
        return_tensors="pt",
    )
    xb = inputs.input_values.to(DEVICE)

    with torch.no_grad():
        logits = model(input_values=xb).logits  # one row per segment

    # Average across segments to get a single clip-level score vector.
    clip_logits = logits.mean(0)
    pred = clip_logits.argmax().item()
    probs = torch.softmax(clip_logits, dim=0).cpu().numpy()

    return IDX2GENRE[pred], probs

# ================= UI =================
# Streamlit re-executes this script top-to-bottom on every interaction.
st.title("AST Model for classification of Music")
st.write("Upload a 10-second audio clip")

uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

if uploaded_file is not None:
    st.audio(uploaded_file)  # let the user play back what they uploaded
    waveform = load_audio(uploaded_file)

    # A None waveform means load_audio already showed the error; the
    # Predict button is only rendered when decoding succeeded.
    if waveform is not None and st.button("Predict"):
        with st.spinner("Predicting..."):
            genre, probs = predict(waveform)

        st.success(f"Predicted Genre: {genre.upper()}")

        st.subheader("Confidence Scores")
        for idx, label in IDX2GENRE.items():
            st.write(f"{label}: {probs[idx]:.4f}")