"""Gradio Space serving the ``Shrishti03/messy-mashup-ast`` genre classifier.

Each uploaded track is scored with a sliding-window test-time-augmentation
(TTA) scheme:
- the waveform is circularly shifted a few times;
- several fixed-length crops are taken from each shifted copy;
- per-crop logits are averaged, weighted by each crop's top-class probability,
  before the final softmax.
"""

import gradio as gr
import torch
import librosa
import numpy as np
from transformers import ASTFeatureExtractor, ASTForAudioClassification

# The free Space tier has no GPU, so inference runs on the CPU.
device = torch.device("cpu")

# Hugging Face Hub repository holding the fine-tuned AST weights and the
# matching feature-extractor config.
MODEL_ID = "Shrishti03/messy-mashup-ast"

# Rate (Hz) the audio is resampled to; also passed to the feature extractor.
SAMPLE_RATE = 16000
# Length of each analysis crop, in seconds.
CROP_SECONDS = 10
# Number of circular shifts of the waveform used for TTA.
TTA_SHIFTS = 3
# Number of crops taken from each shifted copy.
CROPS_PER_SHIFT = 12

# Loaded once at import time so every request reuses the same weights.
feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID)
model.to(device)
model.eval()


def _load_clip(audio_filepath, crop_len):
    """Load *audio_filepath* at SAMPLE_RATE, zero-padded to >= *crop_len* samples."""
    audio, _ = librosa.load(audio_filepath, sr=SAMPLE_RATE)
    if len(audio) < crop_len:
        # Pad at the end so even very short uploads yield one full crop.
        audio = np.pad(audio, (0, crop_len - len(audio)))
    return audio


def _crop_starts(num_samples, crop_len, num_crops):
    """Return *num_crops* evenly stepped crop offsets that stay inside the clip.

    The step is an integer number of samples (at least 1). Each offset is
    clamped to ``num_samples - crop_len`` so every crop is a full *crop_len*
    samples long, even when the clip is only slightly longer than one crop.
    """
    last_start = num_samples - crop_len
    step = max(last_start // max(num_crops - 1, 1), 1)
    return [min(i * step, last_start) for i in range(num_crops)]


def _segment_logits(segment):
    """Peak-normalise one crop and return the model's logits with the batch dim removed."""
    # The epsilon keeps an all-silent crop from dividing by zero.
    segment = segment / (np.max(np.abs(segment)) + 1e-6)
    inputs = feature_extractor(
        segment, sampling_rate=SAMPLE_RATE, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(inputs["input_values"].to(device))
    return outputs.logits.squeeze(0)


def predict_audio(audio_filepath):
    """Predict genre probabilities for an uploaded audio file.

    Args:
        audio_filepath: Path supplied by ``gr.Audio(type="filepath")``, or
            ``None`` when nothing was uploaded.

    Returns:
        ``None`` if no file was given. Otherwise a dict mapping every label in
        ``model.config.id2label`` to its probability, as consumed by ``gr.Label``.
    """
    if audio_filepath is None:
        return None

    crop_len = CROP_SECONDS * SAMPLE_RATE
    audio = _load_clip(audio_filepath, crop_len)
    # The crop schedule depends only on the clip length, so compute it once.
    starts = _crop_starts(len(audio), crop_len, CROPS_PER_SHIFT)

    logits_sum = None
    total_weight = 0.0
    for s in range(TTA_SHIFTS):
        # Circularly shift the clip so each pass sees different crop boundaries.
        shift_offset = int((len(audio) / TTA_SHIFTS) * s)
        shifted_audio = np.roll(audio, shift_offset)
        for start in starts:
            logits = _segment_logits(shifted_audio[start:start + crop_len])
            # Weight each crop by its top-class probability so less confident
            # crops contribute less to the ensemble.
            weight = torch.max(torch.softmax(logits, dim=0)).item()
            weighted = logits * weight
            logits_sum = weighted if logits_sum is None else logits_sum + weighted
            total_weight += weight

    final_probs = torch.softmax(logits_sum / total_weight, dim=0).numpy()
    # Use the config's own label map rather than assuming a fixed class count.
    return {
        label: float(final_probs[idx])
        for idx, label in model.config.id2label.items()
    }


# Build the web UI.
demo = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio (.wav)"),
    outputs=gr.Label(num_top_classes=3, label="Predicted Genre"),
    title="Messy Mashup: AST Audio Classifier",
    description=(
        "Upload a noisy music mashup. This Audio Spectrogram Transformer uses a "
        "10-second sliding-window strategy to analyze the track and predict the genre."
    ),
)

if __name__ == "__main__":
    demo.launch()