Spaces:
Sleeping
Sleeping
Zuitebiechan committed on
Commit ·
3af23ae
1
Parent(s): f73b4ad
feat: add music genre classifier using DistilHuBERT
Browse files
- Load lewtun/distilhubert-finetuned-gtzan model
- Create Gradio interface for audio classification
- Return top 5 genre predictions in JSON format
- Enable API endpoint at /api/predict
- app.py +39 -0
- requirements.txt +6 -0
app.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import pipeline
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
# Load the music classification model.
# The pipeline is created once at module level, so classify_audio reuses the
# same classifier object on every call.
|
| 6 |
+
model_name = "lewtun/distilhubert-finetuned-gtzan"
|
| 7 |
+
classifier = pipeline("audio-classification", model=model_name)
|
| 8 |
+
|
| 9 |
+
def classify_audio(audio_path):
    """Classify an audio file and return the top-5 predictions.

    Args:
        audio_path: Path of the audio file to classify. May be ``None`` or
            empty when no file was supplied.

    Returns:
        On success, a dict with ``"top1"`` (the first prediction) and
        ``"top5"`` (the full list returned by the classifier). On failure, a
        dict with a single ``"error"`` key holding a message. Errors are
        reported in the payload instead of being raised, so the caller
        always receives a JSON-serializable result.
    """
    import logging

    # No upload: answer with a clear message rather than forwarding an
    # invalid input to the classifier.
    if not audio_path:
        return {"error": "No audio file provided"}

    try:
        # Run the prediction through the shared pipeline.
        predictions = classifier(audio_path, top_k=5)
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, but still
        # report the failure to the caller as data.
        logging.getLogger(__name__).exception("Audio classification failed")
        return {"error": str(e)}

    # Guard the predictions[0] access below with an explicit message.
    if not predictions:
        return {"error": "Model returned no predictions"}

    # Format the result.
    return {
        "top1": predictions[0],
        "top5": predictions,
    }
|
| 26 |
+
|
| 27 |
+
# Create the Gradio interface: a file-path audio input wired to
# classify_audio, with the result shown as JSON.
|
| 28 |
+
demo = gr.Interface(
|
| 29 |
+
fn=classify_audio,
|
| 30 |
+
inputs=gr.Audio(type="filepath", label="Upload Music File"),
|
| 31 |
+
outputs=gr.JSON(label="Classification Results"),
|
| 32 |
+
title="Music Genre Classifier",
|
| 33 |
+
description="Upload a music file to classify its genre using DistilHuBERT fine-tuned on GTZAN dataset.",
|
| 34 |
+
examples=[],
|
| 35 |
+
api_name="predict" # intended to expose an /api/predict endpoint -- NOTE(review): confirm the route for the pinned Gradio version
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
|
| 39 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.0
|
| 2 |
+
transformers==4.45.0
|
| 3 |
+
torch==2.4.0
|
| 4 |
+
torchaudio==2.4.0
|
| 5 |
+
librosa==0.10.2
|
| 6 |
+
soundfile==0.12.1
|