# spotify / app.py
# Uploaded by maxxcarl: "Upload folder using huggingface_hub" (commit be08283, verified)
# File size: 2.39 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
# Load model
MODEL_PATH = "outputs/final_model"


def _load_model(path):
    """Load the fine-tuned tokenizer and classifier stored at *path*.

    Returns a ``(tokenizer, model, loaded)`` tuple. If *path* does not exist,
    returns ``(None, None, False)`` so the UI can still start and report that
    no model is available.
    """
    if not Path(path).exists():
        return None, None, False
    tok = AutoTokenizer.from_pretrained(path)
    clf = AutoModelForSequenceClassification.from_pretrained(path)
    # Put the model in evaluation mode for inference (e.g. disables dropout).
    clf.eval()
    return tok, clf, True


tokenizer, model, MODEL_LOADED = _load_model(MODEL_PATH)
def predict_genre(track_name):
    """Predict the genre of a track from its name.

    Args:
        track_name: Free-text track title entered in the UI. ``None``, empty
            and whitespace-only values are treated as missing.

    Returns:
        A Markdown-formatted string with the predicted genre label and the
        softmax probability of that label. If the model is unavailable or no
        track name was given, a plain-text message is returned instead.
    """
    if not MODEL_LOADED:
        return "Model not found. Please train first."
    # Strip surrounding whitespace so "   " counts as empty and stray spaces
    # (e.g. from copy/paste) don't alter the tokenization.
    text = (track_name or "").strip()
    if not text:
        return "Please enter a track name"
    # Tokenize one sequence. Padding is unnecessary for a single input, and
    # Hugging Face tokenizers raise when asked to pad without a pad token
    # (as with a stock GPT-2 tokenizer), so it is not requested.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # Predict without tracking gradients (inference only).
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)
    pred_id = torch.argmax(probs, dim=-1).item()
    confidence = probs[0, pred_id].item()
    # Fall back to a generic class name if the config has no label for this id.
    pred_label = model.config.id2label.get(pred_id, f"Class_{pred_id}")
    return f"**Genre:** {pred_label}\n\n**Confidence:** {confidence:.2%}"
# Create Gradio interface
with gr.Blocks(title="Spotify Genre Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎵 Spotify Genre Classifier")
    gr.Markdown("Enter a song track name to predict its genre using a fine-tuned GPT-2 model.")
    with gr.Row():
        with gr.Column():
            track_input = gr.Textbox(
                label="Track Name",
                placeholder="e.g., Bohemian Rhapsody",
                lines=1,
            )
            predict_btn = gr.Button("🔮 Predict Genre", variant="primary")
        with gr.Column():
            # predict_genre returns Markdown ("**Genre:** ..."), so render it
            # with a Markdown component; a Textbox would show the asterisks
            # literally. Its plain-text error messages render fine here too.
            output = gr.Markdown(label="Prediction")
    # Examples: clicking one fills the track-name textbox.
    gr.Examples(
        examples=[
            "Bohemian Rhapsody",
            "Shape of You",
            "Old Town Road",
            "Blinding Lights",
            "Bad Guy",
            "Stairway to Heaven",
            "Smells Like Teen Spirit",
            "Billie Jean",
        ],
        inputs=track_input,
    )
    # Run a prediction on button click and when Enter is pressed in the textbox.
    predict_btn.click(fn=predict_genre, inputs=track_input, outputs=output)
    track_input.submit(fn=predict_genre, inputs=track_input, outputs=output)
def main():
    """Start the Gradio server for the demo UI."""
    demo.launch()


if __name__ == "__main__":
    main()