from __future__ import annotations import math from functools import lru_cache from pathlib import Path import gradio as gr import numpy as np import torch from scipy.signal import resample_poly from transformers import ASTForAudioClassification, AutoConfig, AutoFeatureExtractor APP_DIR = Path(__file__).resolve().parent CHECKPOINT_PATH = APP_DIR / "working_ast_best.pth" DEVICE = torch.device("cpu") TOP_K = 5 TARGET_SECONDS = 10 GENRES = [ "blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock", ] def rms(waveform: np.ndarray) -> float: return float(np.sqrt(np.mean(np.square(waveform), dtype=np.float64))) def pad_or_crop(waveform: np.ndarray, target_samples: int) -> np.ndarray: if waveform.shape[0] >= target_samples: return waveform[:target_samples] padded = np.zeros(target_samples, dtype=np.float32) padded[: waveform.shape[0]] = waveform return padded def to_float32(audio: np.ndarray) -> np.ndarray: array = np.asarray(audio) if np.issubdtype(array.dtype, np.integer): scale = max(abs(np.iinfo(array.dtype).min), np.iinfo(array.dtype).max) return array.astype(np.float32) / float(scale) return array.astype(np.float32, copy=False) def to_mono(audio: np.ndarray) -> np.ndarray: if audio.ndim == 1: return audio if audio.ndim != 2: raise gr.Error("Expected mono or stereo audio.") channel_axis = 0 if audio.shape[0] <= 4 and audio.shape[1] > 4 else 1 return audio.mean(axis=channel_axis, dtype=np.float32) def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: if orig_sr == target_sr: return audio.astype(np.float32, copy=False) common = math.gcd(orig_sr, target_sr) up = target_sr // common down = orig_sr // common return resample_poly(audio, up, down).astype(np.float32) @lru_cache(maxsize=1) def get_feature_extractor(): return AutoFeatureExtractor.from_pretrained(APP_DIR) @lru_cache(maxsize=1) def get_model(): config = AutoConfig.from_pretrained(APP_DIR) model = ASTForAudioClassification(config) state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu") missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) if missing_keys or unexpected_keys: raise RuntimeError( "Checkpoint did not match the local AST config. " f"Missing keys: {missing_keys[:5]}. " f"Unexpected keys: {unexpected_keys[:5]}." ) model.to(DEVICE) model.eval() return model def preprocess_audio(audio_input: tuple[int, np.ndarray] | None) -> torch.Tensor: if audio_input is None: raise gr.Error("Upload an audio clip or record one from the microphone.") sample_rate, audio = audio_input waveform = to_float32(audio) waveform = to_mono(waveform) feature_extractor = get_feature_extractor() target_sr = int(feature_extractor.sampling_rate) target_samples = target_sr * TARGET_SECONDS waveform = resample_audio(waveform, int(sample_rate), target_sr) waveform = pad_or_crop(waveform, target_samples) waveform = waveform / (rms(waveform) + 1e-9) waveform = np.clip(waveform, -1.0, 1.0) features = feature_extractor( waveform, sampling_rate=target_sr, return_tensors="pt", ) return features["input_values"] def run_model(input_values: torch.Tensor) -> list[float]: model = get_model() with torch.inference_mode(): logits = model(input_values=input_values.to(DEVICE)).logits probs = torch.softmax(logits, dim=-1)[0].cpu().tolist() return probs def predict(audio_input: tuple[int, np.ndarray] | None): input_values = preprocess_audio(audio_input) probabilities = run_model(input_values) top_indices = np.argsort(probabilities)[::-1][:TOP_K] label_scores = {genre: float(score) for genre, score in zip(GENRES, probabilities)} top_rows = [[GENRES[i], round(probabilities[i] * 100.0, 2)] for i in top_indices] best_index = int(np.argmax(probabilities)) best_genre = GENRES[best_index] best_confidence = probabilities[best_index] return ( best_genre, f"{best_confidence:.2%}", label_scores, top_rows, ) with gr.Blocks(title="Messy Mashup AST Genre Classifier") as demo: gr.Markdown( """ # Messy Mashup AST Genre Classifier Upload a mashup clip and the app predicts one of 10 music genres using a fine-tuned Audio Spectrogram Transformer checkpoint. The preprocessing mirrors your notebook flow: resample to 16 kHz, convert to mono, pad or crop to 10 seconds, RMS normalize, clamp, then pass the waveform through the AST feature extractor. """ ) with gr.Row(): with gr.Column(scale=2): audio_input = gr.Audio( sources=["upload", "microphone"], type="numpy", label="Audio Input", ) submit_button = gr.Button("Predict Genre", variant="primary") with gr.Column(scale=1): predicted_genre = gr.Textbox(label="Predicted Genre") confidence = gr.Textbox(label="Confidence") label_scores = gr.Label(label="Class Probabilities", num_top_classes=TOP_K) top_predictions = gr.Dataframe( headers=["Genre", "Confidence (%)"], datatype=["str", "number"], row_count=(TOP_K, "fixed"), col_count=(2, "fixed"), label="Top Predictions", ) submit_button.click( fn=predict, inputs=audio_input, outputs=[predicted_genre, confidence, label_scores, top_predictions], ) if __name__ == "__main__": demo.queue(max_size=16).launch()