Zuitebiechan
feat: add music genre classifier using DistilHuBERT
3af23ae
import logging

import gradio as gr
import numpy as np
from transformers import pipeline
# Load the music classification model.
# Hugging Face Hub checkpoint: DistilHuBERT fine-tuned on the GTZAN dataset.
model_name = "lewtun/distilhubert-finetuned-gtzan"

# The pipeline is built once, at import time, so every request served by the
# UI below reuses the same loaded model instead of reloading it per call.
classifier = pipeline(task="audio-classification", model=model_name)
def classify_audio(audio_path):
    """Classify an audio file and return its top-5 predictions.

    Args:
        audio_path: Path to the audio file, as supplied by the
            ``gr.Audio(type="filepath")`` input. May be ``None`` (or empty)
            when the user submits without uploading a file.

    Returns:
        On success, ``{"top1": best, "top5": predictions}``, where
        ``predictions`` is the list returned by the audio-classification
        pipeline for ``top_k=5`` (at most 5 entries) and ``best`` is its
        first element.
        On failure, ``{"error": message}``. The function never raises, so
        the Gradio JSON output always receives a dict.
    """
    # Fail fast with a clear message instead of letting the pipeline raise
    # an opaque error when no file was provided.
    if not audio_path:
        return {"error": "No audio file provided."}

    try:
        predictions = classifier(audio_path, top_k=5)
    except Exception as e:
        # UI boundary: report the failure to the user as JSON, but keep the
        # traceback in the server log rather than silently discarding it.
        logging.getLogger(__name__).exception(
            "Audio classification failed for %r", audio_path
        )
        return {"error": str(e)}

    # Guard explicitly instead of relying on an IndexError from predictions[0].
    if not predictions:
        return {"error": "Model returned no predictions."}

    return {"top1": predictions[0], "top5": predictions}
# Create the Gradio interface.
def _build_demo():
    """Return the Gradio Interface that exposes ``classify_audio``.

    The UI takes one uploaded audio file (passed to the function as a file
    path) and shows the returned dict as JSON.
    """
    return gr.Interface(
        fn=classify_audio,
        inputs=gr.Audio(type="filepath", label="Upload Music File"),
        outputs=gr.JSON(label="Classification Results"),
        title="Music Genre Classifier",
        description="Upload a music file to classify its genre using DistilHuBERT fine-tuned on GTZAN dataset.",
        examples=[],
        # Name under which this function is exposed through Gradio's API.
        # NOTE(review): the original comment says this creates a
        # /api/predict endpoint; the exact route depends on the installed
        # Gradio version — confirm.
        api_name="predict",
    )


demo = _build_demo()
def main():
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    # Only serve the UI when executed as a script, not when imported.
    main()