ast-classifier / app.py
m7k-run's picture
Update app.py
93e78ed verified
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pull the classifier and its feature extractor from the same HF Hub repo.
repo_name = "m7k-run/ast-genre-classifier"
model = AutoModelForAudioClassification.from_pretrained(repo_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)

# Move the weights to the chosen device and switch to evaluation mode.
model.to(device).eval()
# Genre names; a name's position in this list is the class index that
# predict() looks up in idx_genre.
genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]

# Class index -> genre name.
idx_genre = dict(enumerate(genres))
# Fixed number of samples fed to the model (10 s at the 16 kHz used by predict()).
MAX_LEN = 160000


def pad_truncate(audio, max_len=MAX_LEN):
    """Return *audio* cut or zero-padded to exactly ``max_len`` samples.

    Args:
        audio: One-dimensional sequence of samples (NumPy array or any
            array-like accepted by ``np.asarray``).
        max_len: Target length in samples. Must be non-negative.

    Returns:
        A NumPy array of length ``max_len``: the first ``max_len`` samples
        when the input is long enough, otherwise the input followed by
        trailing zeros.

    Raises:
        ValueError: If ``max_len`` is negative.
    """
    if max_len < 0:
        # A negative bound would make the slice below drop samples from the
        # end instead of producing a fixed-length result.
        raise ValueError(f"max_len must be non-negative, got {max_len}")

    audio = np.asarray(audio)
    n_samples = len(audio)
    if n_samples >= max_len:
        return audio[:max_len]
    # Pad only at the end so the original samples keep their positions.
    return np.pad(audio, (0, max_len - n_samples))
def predict(audio):
    """Classify the genre of an audio file and return its label.

    Args:
        audio: Path to an audio file, as delivered by the
            ``gr.Audio(type="filepath")`` input. ``None`` is rejected.

    Returns:
        The genre name from ``idx_genre`` for the highest-scoring class.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # The audio input may hand over None when the user submits without a
        # clip; report that clearly instead of failing inside the loader.
        raise gr.Error("Please upload or record an audio clip first.")

    # One rate for both loading and feature extraction so they cannot drift.
    sample_rate = 16000

    waveform, _ = librosa.load(audio, sr=sample_rate)
    waveform = pad_truncate(waveform)

    inputs = feature_extractor(
        waveform, sampling_rate=sample_rate, return_tensors="pt"
    )
    input_values = inputs["input_values"].to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(input_values).logits

    pred_idx = logits.argmax(dim=1).item()
    return idx_genre[pred_idx]
# Web UI: one audio input (passed to predict as a file path) and a label
# output showing the returned genre.
demo = gr.Interface(
    title="Music Genre Classifier",
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Label(),
)

# Start serving the interface.
demo.launch()