File size: 1,395 Bytes
39d077e
 
e01013c
39d077e
 
 
5b541ac
39d077e
 
93e78ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39d077e
93e78ed
 
e01013c
93e78ed
 
39d077e
 
93e78ed
 
39d077e
93e78ed
39d077e
 
 
93e78ed
39d077e
93e78ed
39d077e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Fine-tuned classifier and its matching preprocessor, both fetched from the
# Hugging Face Hub by repository id.
repo_name = "m7k-run/ast-genre-classifier"
model = AutoModelForAudioClassification.from_pretrained(repo_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model.to(device)
# Evaluation mode: this module only ever runs inference.
model.eval()


# Output class names, ordered by the classifier's output index
# (assumed to match the label order the checkpoint was trained with).
genres = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
# Map each output index back to its genre name.
idx_genre = dict(enumerate(genres))

# Fixed clip length in samples: 10 seconds at the 16 kHz rate predict() loads at.
MAX_LEN = 160000


def pad_truncate(audio, max_len=MAX_LEN):
    """Force *audio* to exactly ``max_len`` samples.

    Inputs longer than ``max_len`` are cut down to their first ``max_len``
    samples. Shorter or equal-length inputs are right-padded with zeros via
    ``np.pad``, which always returns a new array.

    Args:
        audio: Sample sequence, presumably the 1-D array returned by
            ``librosa.load``.
        max_len: Target length in samples; defaults to ``MAX_LEN``.

    Returns:
        The truncated slice, or a zero-padded array of length ``max_len``.
    """
    shortfall = max_len - len(audio)
    if shortfall < 0:
        return audio[:max_len]
    return np.pad(audio, (0, shortfall))

def predict(audio):
    """Classify the music genre of an audio clip.

    The clip is loaded and resampled with librosa, then padded or truncated to
    ``MAX_LEN`` samples. It is run through the feature extractor and model, and
    the highest-scoring class index is mapped to a genre name.

    Args:
        audio: Path to the audio file, as supplied by the
            ``gr.Audio(type="filepath")`` input. May be ``None`` when the user
            submits without providing a clip.

    Returns:
        The predicted genre name, looked up in ``idx_genre``.

    Raises:
        gr.Error: If no audio was provided. Gradio shows this message to the
            user instead of an opaque failure from ``librosa.load(None)``.
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")

    # Load at the rate the feature extractor is configured for, so the
    # resampling rate and the rate declared to the extractor can never drift
    # apart. Fall back to the previously hard-coded 16 kHz.
    target_sr = getattr(feature_extractor, "sampling_rate", 16000)

    audio_array, _ = librosa.load(audio, sr=target_sr)
    audio_array = pad_truncate(audio_array)

    inputs = feature_extractor(audio_array, sampling_rate=target_sr, return_tensors="pt")
    input_values = inputs["input_values"].to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    # Single-clip batch: argmax over dim 1 yields one index; .item() unwraps it.
    pred_idx = logits.argmax(dim=1).item()
    return idx_genre[pred_idx]

# Gradio UI: one audio input (handed to predict as a file path) and a label
# output showing the predicted genre.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Label(),
    title="Music Genre Classifier",
)

# Start the web server only when executed as a script, so importing this module
# (tests, tooling, Gradio reload mode) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()