| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| from pathlib import Path |
|
|
| |
MODEL_PATH = "outputs/final_model"

# Defaults for the "no trained model yet" case; predict_genre checks
# MODEL_LOADED before using either object.
tokenizer = None
model = None

# NOTE: MODEL_PATH is relative, so it is resolved against the current working
# directory at import time, not against this script's location.
MODEL_LOADED = Path(MODEL_PATH).exists()

if MODEL_LOADED:
    # Load the fine-tuned checkpoint once at import and switch to inference
    # mode (eval() disables dropout).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.eval()
|
|
def predict_genre(track_name):
    """Predict the genre of a single track name with the loaded classifier.

    Args:
        track_name: Free-text track title from the UI. ``None``, empty and
            whitespace-only values are rejected with a prompt message.

    Returns:
        A Markdown-formatted string with the predicted genre label and the
        softmax probability of that label. If no model is loaded or the
        input is blank, returns a plain status message instead.
    """
    if not MODEL_LOADED:
        return "Model not found. Please train first."

    # Treat whitespace-only input as empty; otherwise "   " would be sent to
    # the model and yield a meaningless prediction.
    text = (track_name or "").strip()
    if not text:
        return "Please enter a track name"

    # A single sequence needs no padding. Passing padding=True makes the
    # tokenizer raise when it has no pad token (GPT-2 tokenizers ship
    # without one), so padding is deliberately not requested.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    pred_id = int(torch.argmax(probs, dim=-1).item())
    confidence = probs[0, pred_id].item()

    # Fall back to a synthetic name if the config has no label for this id.
    pred_label = model.config.id2label.get(pred_id, f"Class_{pred_id}")

    return f"**Genre:** {pred_label}\n\n**Confidence:** {confidence:.2%}"
|
|
| |
# Gradio UI: text input and predict button on the left, the rendered
# prediction on the right, and clickable example titles underneath.
with gr.Blocks(title="Spotify Genre Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎵 Spotify Genre Classifier")
    gr.Markdown("Enter a song track name to predict its genre using a fine-tuned GPT-2 model.")

    with gr.Row():
        with gr.Column():
            track_input = gr.Textbox(
                label="Track Name",
                placeholder="e.g., Bohemian Rhapsody",
                lines=1,
            )
            predict_btn = gr.Button("🔮 Predict Genre", variant="primary")

        with gr.Column():
            # predict_genre returns Markdown (**bold** labels, blank-line
            # separators). A Textbox would display the asterisks verbatim,
            # so render the result with a Markdown component instead.
            output = gr.Markdown(label="Prediction")

    # No fn is attached, so clicking an example only populates the textbox;
    # the user still triggers the prediction.
    gr.Examples(
        examples=[
            "Bohemian Rhapsody",
            "Shape of You",
            "Old Town Road",
            "Blinding Lights",
            "Bad Guy",
            "Stairway to Heaven",
            "Smells Like Teen Spirit",
            "Billie Jean",
        ],
        inputs=track_input,
    )

    # Both the button and pressing Enter in the textbox run the prediction.
    predict_btn.click(fn=predict_genre, inputs=track_input, outputs=output)
    track_input.submit(fn=predict_genre, inputs=track_input, outputs=output)


if __name__ == "__main__":
    demo.launch()
|
|