| import gradio as gr |
| import torch |
| import librosa |
| import numpy as np |
| from transformers import ASTFeatureExtractor, ASTForAudioClassification |
|
|
| |
# All inference in this app is pinned to the CPU.
device = torch.device("cpu")

# Checkpoint identifier handed to both `from_pretrained` calls below, so the
# feature extractor and the classifier always come from the same checkpoint.
MODEL_ID = "Shrishti03/messy-mashup-ast"

# Built once at import time; `predict_audio` reuses these module-level objects
# for every request instead of reloading them.
feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_ID)
# `.to()` and `.eval()` both return the module itself, so they can be chained;
# eval mode disables dropout for deterministic predictions.
model = ASTForAudioClassification.from_pretrained(MODEL_ID).to(device).eval()
|
|
def _crop_starts(n_samples, crop_len, n_crops):
    """Return sorted, distinct start offsets for up to *n_crops* windows.

    Windows are spaced ``(n_samples - crop_len) // (n_crops - 1)`` samples apart
    (at least 1 sample). Every start is clamped to ``n_samples - crop_len`` so
    each window of *crop_len* samples lies entirely inside the signal.
    Duplicates produced by the clamp are dropped, so a clip only slightly longer
    than *crop_len* is not scored many times over. Assumes
    ``n_samples >= crop_len``; the caller pads the signal to guarantee this.
    """
    max_start = max(n_samples - crop_len, 0)
    if n_crops <= 1:
        # A single window: avoid dividing by zero when computing the step.
        return [0]
    step = max(max_start // (n_crops - 1), 1)
    return sorted({min(i * step, max_start) for i in range(n_crops)})


def predict_audio(audio_filepath):
    """Classify an audio file with shifted, sliding-window test-time augmentation.

    The file is loaded at 16 kHz and zero-padded to at least 10 seconds. For each
    of 3 circular shifts of the signal (``np.roll``), up to 12 ten-second windows
    are peak-normalised and scored by the AST model. Each window's logits are
    weighted by that window's highest softmax probability. The weighted logits
    are averaged and converted back to probabilities.

    Args:
        audio_filepath: Path to the uploaded audio file, or ``None`` when the
            Gradio input is empty.

    Returns:
        ``None`` if no file was provided. Otherwise a dict mapping each label in
        ``model.config.id2label`` to its probability, which ``gr.Label`` can
        display directly.
    """
    if audio_filepath is None:
        return None

    target_sr = 16000
    audio, sr = librosa.load(audio_filepath, sr=target_sr)
    crop_len = 10 * sr

    # Guarantee at least one full 10-second window.
    if len(audio) < crop_len:
        audio = np.pad(audio, (0, crop_len - len(audio)))

    shifts = 3
    crops_per_shift = 12
    # Start offsets depend only on the signal length, so compute them once.
    # Clamping keeps every crop exactly `crop_len` samples long.
    starts = _crop_starts(len(audio), crop_len, crops_per_shift)

    logits_sum = None
    total_weight = 0.0

    with torch.no_grad():
        for s in range(shifts):
            # Circularly rotate the signal so each pass sees different window
            # boundaries.
            shift_offset = int((len(audio) / shifts) * s)
            shifted_audio = np.roll(audio, shift_offset)

            for start in starts:
                segment = shifted_audio[start:start + crop_len]
                # Peak-normalise; the epsilon avoids dividing by zero on silence.
                segment = segment / (np.max(np.abs(segment)) + 1e-6)

                inputs = feature_extractor(
                    segment,
                    sampling_rate=sr,
                    return_tensors="pt",
                )
                outputs = model(inputs["input_values"].to(device))
                logits = outputs.logits.squeeze(0)

                # Confident windows count more in the average.
                weight = torch.softmax(logits, dim=0).max().item()
                weighted = logits * weight
                logits_sum = weighted if logits_sum is None else logits_sum + weighted
                total_weight += weight

    final_logits = logits_sum / total_weight
    # .cpu() keeps .numpy() valid if `device` is ever switched to a GPU.
    final_probs = torch.softmax(final_logits, dim=0).cpu().numpy()

    # Report every class the model outputs instead of assuming exactly 10.
    id2label = model.config.id2label
    return {id2label[i]: float(p) for i, p in enumerate(final_probs)}
|
|
| |
# Gradio UI: the upload is passed to `predict_audio` as a temporary file path,
# and the returned label -> probability dict is shown as the top 3 genres.
demo = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(label="Upload Audio (.wav)", type="filepath"),
    outputs=gr.Label(label="Predicted Genre", num_top_classes=3),
    title="Messy Mashup: AST Audio Classifier",
    description=(
        "Upload a noisy music mashup. This Audio Spectrogram Transformer uses a "
        "10-second sliding-window strategy to analyze the track and predict the genre."
    ),
)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()