# Spaces: Sleeping (hosting-page status text captured along with this file; not code)
| from __future__ import annotations | |
| import math | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from scipy.signal import resample_poly | |
| from transformers import ASTForAudioClassification, AutoConfig, AutoFeatureExtractor | |
# All model assets (config, feature-extractor files, checkpoint) are resolved
# relative to the directory that contains this script.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR.joinpath("working_ast_best.pth")

# Inference is pinned to the CPU.
DEVICE = torch.device("cpu")

# Number of ranked predictions shown in the UI.
TOP_K = 5
# Every clip is padded or cropped to this many seconds before feature extraction.
TARGET_SECONDS = 10

# Class names in model-output order: GENRES[i] labels logit i.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]
def rms(waveform: np.ndarray) -> float:
    """Return the root-mean-square amplitude of ``waveform`` as a Python float.

    The mean of the squared samples is accumulated in float64 so that long
    float32 clips do not lose precision in the reduction.
    """
    mean_square = np.mean(np.square(waveform), dtype=np.float64)
    return float(np.sqrt(mean_square))
def pad_or_crop(waveform: np.ndarray, target_samples: int) -> np.ndarray:
    """Force a 1-D waveform to exactly ``target_samples`` samples.

    Shorter input is copied into a zero-filled float32 buffer (zero-padding at
    the end). Input that is already long enough is truncated by slicing, which
    returns a view of the original array and keeps its dtype.
    """
    length = waveform.shape[0]
    if length < target_samples:
        buffer = np.zeros(target_samples, dtype=np.float32)
        buffer[:length] = waveform
        return buffer
    return waveform[:target_samples]
def to_float32(audio: np.ndarray) -> np.ndarray:
    """Convert raw audio samples to float32, scaling integer PCM to about [-1, 1].

    - Signed integers are divided by the magnitude of the dtype's most
      negative value (e.g. 32768 for int16), so full scale maps to -1.0.
    - Unsigned integers (e.g. 8-bit PCM WAV, whose silence level is the
      midpoint 128) are re-centred on zero before scaling. Otherwise they would
      come out as a non-negative, DC-shifted signal in [0, 1].
    - Floating-point input is cast to float32 without rescaling; no copy is
      made if it is already float32.

    Args:
        audio: Array-like of samples, any shape.

    Returns:
        A float32 array with the same shape as ``audio``.
    """
    array = np.asarray(audio)
    if np.issubdtype(array.dtype, np.unsignedinteger):
        info = np.iinfo(array.dtype)
        # Midpoint of the unsigned range (128 for uint8) is the zero level.
        midpoint = (int(info.max) + 1) / 2.0
        centred = (array.astype(np.float64) - midpoint) / midpoint
        return centred.astype(np.float32)
    if np.issubdtype(array.dtype, np.integer):
        info = np.iinfo(array.dtype)
        scale = max(abs(info.min), info.max)
        return array.astype(np.float32) / float(scale)
    return array.astype(np.float32, copy=False)
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Downmix ``audio`` to a single channel.

    1-D input is returned unchanged (same object). 2-D input is averaged over
    its channel axis, computed as float32. The channel axis is guessed from the
    shape: axis 0 if the first dimension is at most 4 and the second is larger
    than 4 (a ``(channels, samples)`` layout); otherwise axis 1. Any other rank
    raises ``gr.Error``.
    """
    rank = audio.ndim
    if rank == 1:
        return audio
    if rank == 2:
        rows, cols = audio.shape
        channels_first = rows <= 4 and cols > 4
        return audio.mean(axis=0 if channels_first else 1, dtype=np.float32)
    raise gr.Error("Expected mono or stereo audio.")
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample a 1-D waveform from ``orig_sr`` Hz to ``target_sr`` Hz as float32.

    Uses polyphase filtering with the up/down factors reduced by their
    greatest common divisor. When the rates already match, the input is only
    cast to float32, without a copy if it is float32 already.
    """
    if orig_sr != target_sr:
        divisor = math.gcd(orig_sr, target_sr)
        resampled = resample_poly(audio, target_sr // divisor, orig_sr // divisor)
        return resampled.astype(np.float32)
    return audio.astype(np.float32, copy=False)
@lru_cache(maxsize=1)
def get_feature_extractor():
    """Load the AST feature extractor saved in ``APP_DIR``, caching it after the first call.

    ``preprocess_audio`` calls this on every request. Caching means the
    preprocessor files are read from disk once per process instead of once
    per prediction.

    Note that ``functools.lru_cache`` does not serialise the first call: if two
    threads request it simultaneously before it is cached, both may run the
    loader. Only one result is kept.
    """
    return AutoFeatureExtractor.from_pretrained(APP_DIR)
@lru_cache(maxsize=1)
def get_model():
    """Build the AST classifier, load the fine-tuned weights, and cache the result.

    The architecture comes from the config saved in ``APP_DIR`` and the weights
    from ``CHECKPOINT_PATH``. ``run_model`` calls this on every prediction, so
    the result is cached: the model is built and the checkpoint is read once
    per process rather than per request.

    Returns:
        The ``ASTForAudioClassification`` model, moved to ``DEVICE`` and put
        in eval mode.

    Raises:
        RuntimeError: if the checkpoint's parameter names do not match the
            model built from the local config (the first five missing and
            unexpected keys are included in the message).
    """
    config = AutoConfig.from_pretrained(APP_DIR)
    model = ASTForAudioClassification(config)
    # NOTE(review): torch.load is pickle-based; only point this at checkpoints
    # you trust (here, the file shipped next to the app).
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
    # strict=False returns the mismatched key lists instead of raising, so the
    # mismatch can be reported below with a short, custom message.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys or unexpected_keys:
        raise RuntimeError(
            "Checkpoint did not match the local AST config. "
            f"Missing keys: {missing_keys[:5]}. "
            f"Unexpected keys: {unexpected_keys[:5]}."
        )
    model.to(DEVICE)
    model.eval()
    return model
def preprocess_audio(audio_input: tuple[int, np.ndarray] | None) -> torch.Tensor:
    """Turn a ``(sample_rate, samples)`` tuple into AST model inputs.

    Pipeline:
    1. Convert to float32.
    2. Downmix to mono.
    3. Resample to the feature extractor's sampling rate.
    4. Pad or crop to ``TARGET_SECONDS``.
    5. Divide by the RMS level and clip to [-1, 1].
    6. Run the AST feature extractor.

    Args:
        audio_input: ``(sample_rate, samples)`` as delivered by the app's
            ``gr.Audio(type="numpy")`` input, or ``None`` when nothing was
            provided.

    Returns:
        The ``"input_values"`` tensor produced by the feature extractor
        (requested with ``return_tensors="pt"``).

    Raises:
        gr.Error: if no audio was provided, if the clip has no samples, or if
            the audio is neither 1-D nor 2-D (raised by ``to_mono``).
    """
    if audio_input is None:
        raise gr.Error("Upload an audio clip or record one from the microphone.")
    sample_rate, audio = audio_input

    waveform = to_float32(audio)
    waveform = to_mono(waveform)
    # A zero-length clip (e.g. a recording stopped immediately) would at best
    # be padded to pure silence and "classified" anyway; reject it explicitly.
    if waveform.size == 0:
        raise gr.Error("The audio clip is empty. Upload or record a longer clip.")

    feature_extractor = get_feature_extractor()
    target_sr = int(feature_extractor.sampling_rate)
    target_samples = target_sr * TARGET_SECONDS

    waveform = resample_audio(waveform, int(sample_rate), target_sr)
    waveform = pad_or_crop(waveform, target_samples)
    # Normalise to unit RMS (the epsilon guards against silent clips), then
    # clip. Kept exactly as-is to match the preprocessing the UI text describes
    # for the notebook the model was trained with.
    waveform = waveform / (rms(waveform) + 1e-9)
    waveform = np.clip(waveform, -1.0, 1.0)

    features = feature_extractor(
        waveform,
        sampling_rate=target_sr,
        return_tensors="pt",
    )
    return features["input_values"]
def run_model(input_values: torch.Tensor) -> list[float]:
    """Run the cached classifier and return class probabilities for the first batch item.

    Args:
        input_values: Feature tensor from ``preprocess_audio``. It is assumed
            to be batched with the clip at index 0 (TODO confirm for
            multi-clip use).

    Returns:
        Softmax probabilities over the model's output classes, as a list of
        Python floats.
    """
    classifier = get_model()
    with torch.inference_mode():
        batch = input_values.to(DEVICE)
        logits = classifier(input_values=batch).logits
        first_row = torch.softmax(logits, dim=-1)[0]
        probabilities = first_row.cpu().tolist()
    return probabilities
def predict(audio_input: tuple[int, np.ndarray] | None):
    """Classify one audio clip and format the results for the four UI outputs.

    Args:
        audio_input: ``(sample_rate, samples)`` from the ``gr.Audio`` input, or
            ``None`` (rejected with ``gr.Error`` by ``preprocess_audio``).

    Returns:
        A 4-tuple of:
        - the highest-probability genre name (first one on ties);
        - its probability formatted as a percentage string, e.g. ``"87.25%"``;
        - a dict mapping every genre in ``GENRES`` to its probability;
        - up to ``TOP_K`` ``[genre, percent]`` rows in descending probability
          order, with percentages rounded to two decimals.
    """
    input_values = preprocess_audio(audio_input)
    probabilities = run_model(input_values)

    ranked = np.argsort(probabilities)[::-1][:TOP_K]
    scores_by_genre = dict(zip(GENRES, map(float, probabilities)))
    table = [[GENRES[idx], round(probabilities[idx] * 100.0, 2)] for idx in ranked]

    winner = int(np.argmax(probabilities))
    winner_confidence = probabilities[winner]
    return GENRES[winner], f"{winner_confidence:.2%}", scores_by_genre, table
# Gradio UI. Inside the nested Blocks/Row/Column context managers, components
# are laid out in the order they are created, so statement order here *is*
# the page layout.
with gr.Blocks(title="Messy Mashup AST Genre Classifier") as demo:
    gr.Markdown(
        """
# Messy Mashup AST Genre Classifier
Upload a mashup clip and the app predicts one of 10 music genres using a
fine-tuned Audio Spectrogram Transformer checkpoint.
The preprocessing mirrors your notebook flow:
resample to 16 kHz, convert to mono, pad or crop to 10 seconds, RMS
normalize, clamp, then pass the waveform through the AST feature extractor.
"""
    )
    with gr.Row():
        # Left column: audio input and the trigger button.
        with gr.Column(scale=2):
            # type="numpy" delivers a (sample_rate, samples) tuple (or None
            # when empty) to the handler, matching predict's signature.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Audio Input",
            )
            submit_button = gr.Button("Predict Genre", variant="primary")
        # Right column: the four result components filled by predict.
        with gr.Column(scale=1):
            predicted_genre = gr.Textbox(label="Predicted Genre")
            confidence = gr.Textbox(label="Confidence")
            # predict supplies probabilities for every genre; this component
            # displays only the TOP_K highest.
            # (This module-level `label_scores` is the component; predict's
            # local variable of the same name is unrelated.)
            label_scores = gr.Label(label="Class Probabilities", num_top_classes=TOP_K)
            # Fixed-size table matching predict's top-TOP_K [genre, percent] rows.
            top_predictions = gr.Dataframe(
                headers=["Genre", "Confidence (%)"],
                datatype=["str", "number"],
                row_count=(TOP_K, "fixed"),
                col_count=(2, "fixed"),
                label="Top Predictions",
            )
    # The outputs list must stay in the same order as the 4-tuple predict
    # returns: (genre, confidence string, score dict, top rows).
    submit_button.click(
        fn=predict,
        inputs=audio_input,
        outputs=[predicted_genre, confidence, label_scores, top_predictions],
    )
if __name__ == "__main__":
    # Enable request queueing with at most 16 pending events, then start the server.
    queued_demo = demo.queue(max_size=16)
    queued_demo.launch()