"""Gradio demo that predicts a Spotify genre from a track name.

Loads a fine-tuned sequence-classification checkpoint from ``MODEL_PATH``
(if it exists) and serves a small web UI around :func:`predict_genre`.
"""

from pathlib import Path

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Directory of the fine-tuned checkpoint (tokenizer + model weights/config).
MODEL_PATH = "outputs/final_model"
# Upper bound on tokens per input; longer inputs are truncated.
MAX_LENGTH = 256

# Load the model once at import time so every request reuses it.
if Path(MODEL_PATH).exists():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.eval()  # inference mode: disables dropout
    MODEL_LOADED = True
else:
    MODEL_LOADED = False
    model = None
    tokenizer = None


def predict_genre(track_name):
    """Predict the genre for a single track name.

    Args:
        track_name: Text entered by the user; may be empty, whitespace-only,
            or ``None``.

    Returns:
        A plain-text string with the predicted genre label and its softmax
        confidence, or a message explaining why no prediction was made.
    """
    if not MODEL_LOADED:
        return "Model not found. Please train first."
    # Treat None and whitespace-only input the same as empty input.
    track_name = (track_name or "").strip()
    if not track_name:
        return "Please enter a track name"

    # A single sequence needs no padding. Requesting padding would raise for
    # tokenizers without a pad token (e.g. a stock GPT-2 tokenizer).
    inputs = tokenizer(
        track_name,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)
    pred_id = torch.argmax(probs, dim=-1).item()
    confidence = probs[0, pred_id].item()

    # Fall back to a generic name if the config has no label for this id.
    pred_label = model.config.id2label.get(pred_id, f"Class_{pred_id}")
    # Plain text on purpose: the output component is a Textbox, which does
    # not render Markdown, so "**bold**" markers would appear literally.
    return f"Genre: {pred_label}\n\nConfidence: {confidence:.2%}"


# Build the Gradio interface.
with gr.Blocks(title="Spotify Genre Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎵 Spotify Genre Classifier")
    gr.Markdown("Enter a song track name to predict its genre using a fine-tuned GPT-2 model.")

    with gr.Row():
        with gr.Column():
            track_input = gr.Textbox(
                label="Track Name",
                placeholder="e.g., Bohemian Rhapsody",
                lines=1,
            )
            predict_btn = gr.Button("🔮 Predict Genre", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Prediction")

    # Clickable sample inputs that populate the track-name box.
    gr.Examples(
        examples=[
            "Bohemian Rhapsody",
            "Shape of You",
            "Old Town Road",
            "Blinding Lights",
            "Bad Guy",
            "Stairway to Heaven",
            "Smells Like Teen Spirit",
            "Billie Jean",
        ],
        inputs=track_input,
    )

    # Predict on button click or when Enter is pressed in the textbox.
    predict_btn.click(fn=predict_genre, inputs=track_input, outputs=output)
    track_input.submit(fn=predict_genre, inputs=track_input, outputs=output)

if __name__ == "__main__":
    demo.launch()