# Spaces: Sleeping
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
# Hub repository from which both the classifier and its feature extractor
# are loaded.
repo_name = "m7k-run/ast-genre-classifier"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = AutoModelForAudioClassification.from_pretrained(repo_name).to(device)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)

# The app only runs inference, so switch the module to evaluation mode.
model.eval()
# Genre names; a genre's position in this list is its class index.
# NOTE(review): the order presumably mirrors the label ids the model was
# trained with -- confirm against the model's config before editing.
genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]

# Lookup table: class index -> genre name.
idx_genre = dict(enumerate(genres))
# Fixed clip length handed to the feature extractor, in samples. predict()
# loads audio at 16 kHz, so this corresponds to 10 seconds.
MAX_LEN = 160_000


def pad_truncate(audio, max_len=MAX_LEN):
    """Force a 1-D waveform to exactly ``max_len`` samples.

    Input longer than ``max_len`` is cut off after ``max_len`` samples;
    shorter input is right-padded with zeros.

    Args:
        audio: 1-D sequence of samples (NumPy array or any array-like).
        max_len: Target length in samples. Defaults to ``MAX_LEN``.

    Returns:
        A NumPy array holding ``max_len`` samples, with the dtype of the
        (array-converted) input.
    """
    # Requires ``import numpy as np`` at module level; the original file used
    # ``np`` without importing it, so short clips raised NameError.
    audio = np.asarray(audio)
    n_samples = len(audio)
    if n_samples >= max_len:
        return audio[:max_len]
    # np.pad defaults to mode="constant" with value 0, i.e. trailing silence.
    return np.pad(audio, (0, max_len - n_samples))
def predict(audio):
    """Predict the music genre of an audio file.

    Args:
        audio: Path to an audio file, as delivered by the ``gr.Audio``
            input configured with ``type="filepath"``. ``None`` when the
            form is submitted without any audio.

    Returns:
        The genre name that ``idx_genre`` maps the highest-scoring class
        index to.

    Raises:
        gr.Error: If no audio was provided.
    """
    # Gradio hands over None for an empty audio input; report that clearly
    # instead of letting librosa fail with an opaque error.
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")

    # One sampling rate for both loading and feature extraction so the two
    # cannot drift apart.
    sample_rate = 16000

    waveform, _ = librosa.load(audio, sr=sample_rate)
    waveform = pad_truncate(waveform)

    inputs = feature_extractor(
        waveform, sampling_rate=sample_rate, return_tensors="pt"
    )
    input_values = inputs["input_values"].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits

    # A single clip is classified per call, so argmax over the class axis
    # yields exactly one index.
    pred_idx = logits.argmax(dim=1).item()
    return idx_genre[pred_idx]
# Web UI: a single audio input (passed to predict() as a file path) and a
# label output showing the predicted genre.
demo = gr.Interface(
    title="Music Genre Classifier",
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Label(),
)

# Start the Gradio server.
demo.launch()