# ShoaibSSM's picture
# Upload 6 files
# 027943a verified
from __future__ import annotations
import math
from functools import lru_cache
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from scipy.signal import resample_poly
from transformers import ASTForAudioClassification, AutoConfig, AutoFeatureExtractor
# Directory holding this script. The model config, the feature-extractor
# config and the checkpoint are all resolved relative to it.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR.joinpath("working_ast_best.pth")

# Inference is pinned to the CPU.
DEVICE = torch.device("cpu")

# How many of the highest-scoring genres are reported.
TOP_K = 5

# Every clip is padded or cropped to this many seconds before feature
# extraction.
TARGET_SECONDS = 10

# Class labels. predict() pairs them positionally with the model's output
# probabilities, so the order here must match the checkpoint's label order.
GENRES = [
    "blues", "classical",
    "country", "disco",
    "hiphop", "jazz",
    "metal", "pop",
    "reggae", "rock",
]
def rms(waveform: np.ndarray) -> float:
    """Return the root-mean-square amplitude of ``waveform``.

    The squared samples are averaged with a float64 accumulator and the
    result is returned as a plain Python float.
    """
    mean_square = np.mean(np.square(waveform), dtype=np.float64)
    return float(np.sqrt(mean_square))
def pad_or_crop(waveform: np.ndarray, target_samples: int) -> np.ndarray:
    """Force ``waveform`` to exactly ``target_samples`` samples.

    Clips that are too short are copied into a zero-filled float32 buffer
    (trailing zero padding). Clips that are long enough are truncated by
    slicing, which keeps the input's dtype.
    """
    length = waveform.shape[0]
    if length < target_samples:
        buffer = np.zeros(target_samples, dtype=np.float32)
        buffer[:length] = waveform
        return buffer
    return waveform[:target_samples]
def to_float32(audio: np.ndarray) -> np.ndarray:
    """Convert ``audio`` to float32 samples.

    Integer input is divided by the largest magnitude its dtype can hold
    (for example 32768 for int16). Any other dtype is simply cast to
    float32, without copying when it already is float32.
    """
    array = np.asarray(audio)
    if not np.issubdtype(array.dtype, np.integer):
        return array.astype(np.float32, copy=False)

    limits = np.iinfo(array.dtype)
    # limits.min is never positive, so negating it yields its magnitude.
    full_scale = float(max(-limits.min, limits.max))
    return array.astype(np.float32) / full_scale
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse multi-channel audio to a single channel by averaging.

    One-dimensional input is returned unchanged. For two-dimensional input
    the channel axis is guessed from the shape: axis 0 when the first
    dimension has at most 4 entries and the second has more than 4,
    otherwise axis 1. The average is computed in float32.

    Raises:
        gr.Error: if ``audio`` has more than two dimensions (or zero).
    """
    if audio.ndim == 1:
        return audio
    if audio.ndim != 2:
        raise gr.Error("Expected mono or stereo audio.")

    first_dim, second_dim = audio.shape
    channels_first = first_dim <= 4 and second_dim > 4
    channel_axis = 0 if channels_first else 1
    return np.mean(audio, axis=channel_axis, dtype=np.float32)
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` and return float32.

    When the two rates already match, the input is only cast to float32.
    Otherwise polyphase resampling is used with the up/down factors reduced
    by the greatest common divisor of the two rates.
    """
    if orig_sr != target_sr:
        divisor = math.gcd(orig_sr, target_sr)
        resampled = resample_poly(audio, target_sr // divisor, orig_sr // divisor)
        return resampled.astype(np.float32)
    return audio.astype(np.float32, copy=False)
@lru_cache(maxsize=1)
def get_feature_extractor():
    """Load the feature extractor stored in ``APP_DIR``.

    The result is memoized, so the files are read on the first call only.
    """
    extractor = AutoFeatureExtractor.from_pretrained(APP_DIR)
    return extractor
@lru_cache(maxsize=1)
def get_model():
    """Build the AST classifier and load the fine-tuned weights into it.

    The architecture comes from the config stored in ``APP_DIR``; the
    weights come from ``CHECKPOINT_PATH``. The result is memoized, so the
    checkpoint is read on the first call only.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: if the checkpoint's keys do not exactly match the
            model built from the local config.
    """
    config = AutoConfig.from_pretrained(APP_DIR)
    model = ASTForAudioClassification(config)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    # A state dict needs nothing beyond that.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)

    # strict=False is used only to collect the mismatching keys so that the
    # error below can report them; any mismatch is still fatal.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys or unexpected_keys:
        raise RuntimeError(
            "Checkpoint did not match the local AST config. "
            f"Missing keys: {missing_keys[:5]}. "
            f"Unexpected keys: {unexpected_keys[:5]}."
        )

    model.to(DEVICE)
    model.eval()
    return model
def preprocess_audio(audio_input: tuple[int, np.ndarray] | None) -> torch.Tensor:
    """Turn a Gradio audio value into model-ready input features.

    Steps, in order: convert to float32, mix down to mono, resample to the
    feature extractor's sampling rate, pad or crop to ``TARGET_SECONDS``,
    divide by the clip's RMS, clip to [-1, 1], then run the feature
    extractor.

    Args:
        audio_input: ``(sample_rate, samples)`` pair, or ``None`` when the
            user has not provided any audio.

    Returns:
        The ``"input_values"`` tensor produced by the feature extractor.

    Raises:
        gr.Error: if no audio was provided, if the clip contains no
            samples, or (via ``to_mono``) if it has an unsupported number
            of dimensions.
    """
    if audio_input is None:
        raise gr.Error("Upload an audio clip or record one from the microphone.")

    sample_rate, audio = audio_input
    waveform = to_float32(audio)

    # A zero-length clip would otherwise be zero-padded and classified as if
    # it were real audio; reject it with an actionable message instead.
    if waveform.size == 0:
        raise gr.Error("The audio clip is empty. Upload or record a clip that contains audio.")

    waveform = to_mono(waveform)

    feature_extractor = get_feature_extractor()
    target_sr = int(feature_extractor.sampling_rate)
    target_samples = target_sr * TARGET_SECONDS

    waveform = resample_audio(waveform, int(sample_rate), target_sr)
    waveform = pad_or_crop(waveform, target_samples)

    # The small epsilon keeps the division finite for an all-silent clip.
    waveform = waveform / (rms(waveform) + 1e-9)
    waveform = np.clip(waveform, -1.0, 1.0)

    features = feature_extractor(
        waveform,
        sampling_rate=target_sr,
        return_tensors="pt",
    )
    return features["input_values"]
def run_model(input_values: torch.Tensor) -> list[float]:
    """Run the classifier and return the class probabilities of the first item.

    The logits are computed under ``torch.inference_mode`` and turned into
    probabilities with a softmax over the last dimension.
    """
    classifier = get_model()
    with torch.inference_mode():
        outputs = classifier(input_values=input_values.to(DEVICE))
        distribution = torch.softmax(outputs.logits, dim=-1)
    first_item = distribution[0]
    return first_item.cpu().tolist()
def predict(audio_input: tuple[int, np.ndarray] | None):
    """Classify a clip and format the results for the four output widgets.

    Returns a 4-tuple:
        1. the name of the highest-probability genre,
        2. that probability formatted as a percentage string,
        3. a mapping from every genre to its probability,
        4. ``[genre, percent]`` rows for the ``TOP_K`` most likely genres,
           highest first, with percentages rounded to two decimals.
    """
    probabilities = run_model(preprocess_audio(audio_input))

    # Genres and probabilities are paired by position.
    label_scores = dict(zip(GENRES, map(float, probabilities)))

    ranked = np.argsort(probabilities)[::-1][:TOP_K]
    top_rows = []
    for idx in ranked:
        top_rows.append([GENRES[idx], round(probabilities[idx] * 100.0, 2)])

    winner = int(np.argmax(probabilities))
    return (
        GENRES[winner],
        f"{probabilities[winner]:.2%}",
        label_scores,
        top_rows,
    )
# Build the Gradio UI: an audio input with a submit button, and four output
# widgets that are filled from the 4-tuple returned by predict().
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost. Confirm that the Label and Dataframe are
# meant to sit below the row rather than inside the right-hand column.
with gr.Blocks(title="Messy Mashup AST Genre Classifier") as demo:
    # Introductory text shown at the top of the page.
    gr.Markdown(
        """
# Messy Mashup AST Genre Classifier
Upload a mashup clip and the app predicts one of 10 music genres using a
fine-tuned Audio Spectrogram Transformer checkpoint.
The preprocessing mirrors your notebook flow:
resample to 16 kHz, convert to mono, pad or crop to 10 seconds, RMS
normalize, clamp, then pass the waveform through the AST feature extractor.
"""
    )
    with gr.Row():
        # Left column: audio source and the button that triggers prediction.
        with gr.Column(scale=2):
            # type="numpy" matches the (sample_rate, samples) pair that
            # preprocess_audio() unpacks.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Audio Input",
            )
            submit_button = gr.Button("Predict Genre", variant="primary")
        # Right column: headline result.
        with gr.Column(scale=1):
            predicted_genre = gr.Textbox(label="Predicted Genre")
            confidence = gr.Textbox(label="Confidence")
    # Full probability mapping and the TOP_K table.
    label_scores = gr.Label(label="Class Probabilities", num_top_classes=TOP_K)
    top_predictions = gr.Dataframe(
        headers=["Genre", "Confidence (%)"],
        datatype=["str", "number"],
        row_count=(TOP_K, "fixed"),
        col_count=(2, "fixed"),
        label="Top Predictions",
    )
    # The order of `outputs` must match the order of the tuple returned by
    # predict().
    submit_button.click(
        fn=predict,
        inputs=audio_input,
        outputs=[predicted_genre, confidence, label_scores, top_predictions],
    )
# Start the app only when this file is executed as a script; the queue is
# configured with max_size=16.
if __name__ == "__main__":
    demo.queue(max_size=16).launch()