# Author: Shrishti03
# Commit 247063e: "Create app.py" (verified)
import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification
# Inference runs on the CPU (the free Space tier this app targets).
device = torch.device("cpu")

# Hugging Face Hub repository holding the fine-tuned AST checkpoint.
MODEL_ID = "Shrishti03/messy-mashup-ast"

# Pull the preprocessing config and the fine-tuned weights from the Hub, move
# the network onto the chosen device and switch it to evaluation mode.
feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID).to(device)
model.eval()
def predict_audio(audio_filepath, *, crop_seconds=10, shifts=3, crops_per_shift=12):
    """Predict genre probabilities for an audio file using sliding-window TTA.

    The file is loaded at 16 kHz and zero-padded to at least one crop. It is
    then scored on ``shifts * crops_per_shift`` fixed-length crops. For each
    shift the waveform is circularly rolled, and crops are taken at evenly
    stepped offsets. Each crop's logits are weighted by that crop's highest
    softmax probability. The weighted logits are averaged, and a final
    softmax converts them into class probabilities.

    Args:
        audio_filepath: Path to the uploaded audio file, or ``None`` when the
            Gradio input is empty.
        crop_seconds: Length of each crop in seconds (default 10).
        shifts: Number of circular shifts; shift ``s`` rolls the waveform by
            ``s / shifts`` of its length (default 3).
        crops_per_shift: Number of crops taken per shift (default 12).

    Returns:
        ``None`` if no file was given. Otherwise, a dict mapping every label
        in ``model.config.id2label`` to its probability, as consumed by the
        ``gr.Label`` output.

    Raises:
        ValueError: If ``crop_seconds`` is not positive, or ``shifts`` /
            ``crops_per_shift`` is less than 1.
    """
    if crop_seconds <= 0:
        raise ValueError("crop_seconds must be positive")
    if shifts < 1 or crops_per_shift < 1:
        raise ValueError("shifts and crops_per_shift must both be >= 1")

    if audio_filepath is None:
        return None

    # One sampling rate is used for both loading and feature extraction.
    target_sr = 16000
    audio, sr = librosa.load(audio_filepath, sr=target_sr)

    crop_len = int(crop_seconds * sr)
    if len(audio) < crop_len:
        audio = np.pad(audio, (0, crop_len - len(audio)))
    num_samples = len(audio)

    # Largest start index that still yields a full crop_len-sample window.
    max_start = num_samples - crop_len
    # max(..., 1) on the divisor avoids dividing by zero when crops_per_shift == 1.
    step = max(max_start // max(crops_per_shift - 1, 1), 1)

    weighted_logits = []
    weights = []
    with torch.no_grad():
        for s in range(shifts):
            # np.roll wraps samples around, so each shift sees the same track
            # with a different alignment.
            shift_offset = int((num_samples / shifts) * s)
            shifted_audio = np.roll(audio, shift_offset)
            for i in range(crops_per_shift):
                # Without this clamp, audio only slightly longer than one crop
                # (step floored to 1) would yield crops shorter than crop_len.
                start = min(i * step, max_start)
                segment = shifted_audio[start:start + crop_len]
                # Peak-normalise; the epsilon avoids division by zero on silence.
                segment = segment / (np.max(np.abs(segment)) + 1e-6)

                inputs = feature_extractor(
                    segment,
                    sampling_rate=sr,
                    return_tensors="pt",
                )
                logits = model(inputs["input_values"].to(device)).logits.squeeze(0)

                # Confidence weight: the highest class probability for this crop.
                weight = torch.softmax(logits, dim=0).max().item()
                weighted_logits.append(logits * weight)
                weights.append(weight)

    # Each weight is a max-softmax value (> 0), so the denominator is never zero.
    final_logits = torch.stack(weighted_logits).sum(dim=0) / sum(weights)
    # .cpu() keeps .numpy() valid even if `device` is switched to a GPU.
    final_probs = torch.softmax(final_logits, dim=0).cpu().numpy()

    # Report every class the checkpoint defines instead of assuming exactly 10.
    id2label = model.config.id2label
    return {id2label[idx]: float(prob) for idx, prob in enumerate(final_probs)}
# Gradio web UI: one audio upload in, the three most likely genres out.
# type="filepath" makes Gradio hand predict_audio a path on disk.
audio_input = gr.Audio(type="filepath", label="Upload Audio (.wav)")
genre_output = gr.Label(num_top_classes=3, label="Predicted Genre")

demo = gr.Interface(
    fn=predict_audio,
    inputs=audio_input,
    outputs=genre_output,
    title="Messy Mashup: AST Audio Classifier",
    description=(
        "Upload a noisy music mashup. "
        "This Audio Spectrogram Transformer uses a 10-second sliding-window "
        "strategy to analyze the track and predict the genre."
    ),
)

# Start the web server only when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    demo.launch()