File size: 2,690 Bytes
247063e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification

# Inference is pinned to the CPU, matching the free Space this app runs on.
device = torch.device("cpu")

# Hub repository id of the published fine-tuned AST checkpoint.
MODEL_ID = "Shrishti03/messy-mashup-ast"

# Fetch the matching feature extractor and classifier for MODEL_ID.
# nn.Module.to() and .eval() both return the module itself, so the model is
# moved to `device` and switched to evaluation mode in a single expression.
feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID).to(device).eval()

def predict_audio(audio_filepath):
    """Predict genre probabilities for an uploaded audio file.

    The file is loaded with librosa at 16 kHz and zero-padded to at least
    10 seconds. Test-time augmentation then runs the model on
    ``shifts * crops_per_shift`` 10-second crops. The waveform is circularly
    rolled ``shifts`` times, and ``crops_per_shift`` evenly stepped crops are
    taken from each rolled copy. Each crop is peak-normalised before feature
    extraction. The crop logits are averaged, each weighted by that crop's
    highest softmax probability, so more confident crops count more. The
    averaged logits are then converted into the final probabilities.

    Args:
        audio_filepath: Path to the audio file supplied by
            ``gr.Audio(type="filepath")``, or ``None`` when nothing was
            uploaded.

    Returns:
        ``None`` if ``audio_filepath`` is ``None``. Otherwise, a dict mapping
        the label name from ``model.config.id2label`` of every class the
        model outputs to its probability as a Python ``float``.
    """
    if audio_filepath is None:
        return None

    # Load audio at 16kHz; `sr` is reused below so the feature extractor is
    # always told the rate the audio was actually resampled to.
    audio, sr = librosa.load(audio_filepath, sr=16000)
    crop_len = 10 * sr

    # Zero-pad short clips so at least one full 10-second crop exists.
    if len(audio) < crop_len:
        audio = np.pad(audio, (0, crop_len - len(audio)))

    # Sliding-window TTA settings (from the Kaggle pipeline).
    shifts = 3
    crops_per_shift = 12
    # Largest start index that still leaves a full crop_len samples.
    max_start = len(audio) - crop_len
    # The step depends only on the clip length, so compute it once.
    step = max(max_start // (crops_per_shift - 1), 1)

    logits_sum = None
    total_weight = 0.0

    for s in range(shifts):
        # Circularly rotate the clip so each pass sees different crop
        # boundaries; samples rolled past the end wrap around to the start.
        shift_offset = int((len(audio) / shifts) * s)
        shifted_audio = np.roll(audio, shift_offset)

        for i in range(crops_per_shift):
            # Clamp the start. When the clip is (barely) longer than crop_len,
            # step is forced up to 1, and unclamped late crops would slice
            # past the end and come out shorter than crop_len.
            start = min(i * step, max_start)
            segment = shifted_audio[start:start + crop_len]
            # Peak-normalise; the epsilon avoids dividing by zero on silence.
            segment = segment / (np.max(np.abs(segment)) + 1e-6)

            inputs = feature_extractor(
                segment,
                sampling_rate=sr,
                return_tensors="pt",
            )
            input_values = inputs["input_values"].to(device)

            with torch.no_grad():
                outputs = model(input_values)

            logits = outputs.logits.squeeze(0)
            # Confidence weight: this crop's top class probability.
            weight = torch.softmax(logits, dim=0).max().item()

            weighted = logits * weight
            logits_sum = weighted if logits_sum is None else logits_sum + weighted
            total_weight += weight

    final_logits = logits_sum / total_weight
    final_probs = torch.softmax(final_logits, dim=0).numpy()

    # Map every class the model outputs to its label name, rather than
    # assuming the checkpoint has exactly 10 classes.
    return {
        model.config.id2label[idx]: float(prob)
        for idx, prob in enumerate(final_probs)
    }

# Gradio web UI: a single uploaded audio file is passed (as a file path) to
# predict_audio. The returned label->probability dict is shown as the top
# three genres.
demo = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(
        type="filepath",
        label="Upload Audio (.wav)",
    ),
    outputs=gr.Label(
        num_top_classes=3,
        label="Predicted Genre",
    ),
    title="Messy Mashup: AST Audio Classifier",
    description=(
        "Upload a noisy music mashup. "
        "This Audio Spectrogram Transformer uses a 10-second "
        "sliding-window strategy to analyze the track and predict the genre."
    ),
)

# Start the web server only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()