Zuitebiechan committed on
Commit
3af23ae
·
1 Parent(s): f73b4ad

feat: add music genre classifier using DistilHuBERT

Browse files

- Load lewtun/distilhubert-finetuned-gtzan model
- Create Gradio interface for audio classification
- Return top 5 genre predictions in JSON format
- Enable API endpoint at /api/predict

Files changed (2) hide show
  1. app.py +39 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+
5
# Load the music genre classification model once, at import time, so every
# request reuses the same pipeline object.
model_name = "lewtun/distilhubert-finetuned-gtzan"
classifier = pipeline(task="audio-classification", model=model_name)
8
+
9
def classify_audio(audio_path):
    """Classify an audio file and return the top 5 genre predictions.

    Args:
        audio_path: Path of the audio file to classify. A falsy value
            (``None`` or an empty string) is rejected before the model
            is invoked.

    Returns:
        On success, a dict with two keys: ``"top1"`` (the first entry of
        the classifier output) and ``"top5"`` (the full classifier output
        requested with ``top_k=5``).
        On failure, a dict with a single ``"error"`` key holding a
        human-readable message. This function never raises.
    """
    # A missing path cannot be classified; answer with a clear message
    # instead of surfacing whatever internal error the pipeline would raise.
    if not audio_path:
        return {"error": "No audio file provided."}

    try:
        # Run the prediction through the module-level pipeline.
        predictions = classifier(audio_path, top_k=5)

        # Guard the [0] lookup below so an empty result is reported
        # explicitly rather than as "list index out of range".
        if not predictions:
            return {"error": "The classifier returned no predictions."}

        # Shape the result for the JSON output component.
        return {"top1": predictions[0], "top5": predictions}
    except Exception as e:
        # Deliberate best-effort boundary: any failure is reported to the
        # caller as a JSON error payload instead of propagating.
        return {"error": str(e)}
26
+
27
# Build the Gradio interface: one audio upload (handed to the handler as a
# file path) in, one JSON panel out.
demo = gr.Interface(
    classify_audio,
    gr.Audio(type="filepath", label="Upload Music File"),
    gr.JSON(label="Classification Results"),
    examples=[],
    title="Music Genre Classifier",
    description="Upload a music file to classify its genre using DistilHuBERT fine-tuned on GTZAN dataset.",
    api_name="predict",  # intended to expose the handler as the /api/predict endpoint
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ transformers==4.45.0
3
+ torch==2.4.0
4
+ torchaudio==2.4.0
5
+ librosa==0.10.2
6
+ soundfile==0.12.1