# Cecile / app.py (Space source file)
# Last commit 5c41c46 by 7sunshine7: "Update app.py"
import logging

import gradio as gr
import librosa
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
# Load the fine-tuned model and its feature extractor once, at import time.
# Folder holding the uploaded model checkpoint; keep this name identical to the
# folder in the Space. The path is relative, so it resolves against the current
# working directory — NOTE(review): assumes the app is launched from the repo root.
model_path = "./music_genre_classifier/best_model"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
model = AutoModelForAudioClassification.from_pretrained(model_path)
# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Switch to evaluation mode (e.g. dropout disabled) since the model is only used for inference.
model.eval()
# Inference function wired into the Gradio interface.
def predict_genre(audio_path):
    """Predict the music genre of an uploaded audio file.

    Args:
        audio_path: Filesystem path of the uploaded audio, as produced by
            ``gr.Audio(type="filepath")``. ``None`` (or empty) when the user
            submits without uploading a file.

    Returns:
        A user-facing string: the predicted genre label on success, a hint
        when no file was provided, or an error message if inference failed.
    """
    # gr.Audio(type="filepath") hands us None when nothing was uploaded;
    # reply with a clear hint instead of surfacing librosa's low-level error.
    if not audio_path:
        return "⚠️ Please upload an audio file first."
    try:
        # Resample to the rate the feature extractor was configured with and
        # pass that same rate to it, so loading and feature extraction can
        # never disagree. Falls back to 16 kHz (the previous hard-coded value)
        # if the extractor exposes no sampling_rate attribute.
        target_sr = getattr(feature_extractor, "sampling_rate", 16000)
        audio, _ = librosa.load(audio_path, sr=target_sr)
        inputs = feature_extractor(
            audio, sampling_rate=target_sr, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Single input clip -> one row of logits; take the top-scoring class id.
        predicted_label = torch.argmax(logits, dim=-1).item()
        genre = model.config.id2label.get(predicted_label, "Unknown")
        return f"🎧 Predicted Genre: {genre}"
    except Exception as e:
        # UI boundary: keep the app responsive with a readable message, but
        # record the full traceback in the server log instead of losing it.
        logging.getLogger(__name__).exception(
            "Genre prediction failed for %r", audio_path
        )
        return f"❌ Prediction failed: {e}"
# Assemble the Gradio UI: one audio-file upload in, one text box out,
# with predict_genre doing the work in between.
demo = gr.Interface(
    title="🎶 Music Genre Classifier",
    description=(
        "Upload a music file and this model will predict its genre "
        "using a fine-tuned Transformer model."
    ),
    fn=predict_genre,
    inputs=gr.Audio(
        label="🎵 Upload a music file (.mp3, .wav, .flac)",
        type="filepath",
    ),
    outputs=gr.Textbox(label="Predicted Genre"),
)
# Launch the Gradio interface when this file is run as a script.
if __name__ == "__main__":
    demo.launch()  # pass share=True if a public link is needed