Huggbottle's picture
Update app.py to show confidence scores for genres
be34fa8 verified
import gradio as gr
import pickle
import os
import numpy as np
# Every pickled artifact lives in a 'pkl_files' folder next to this script;
# anchoring on __file__ makes the paths independent of the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PKL_DIR = os.path.join(BASE_DIR, 'pkl_files')
(
    SENTIMENT_MODEL_PATH,
    SENTIMENT_VECTORIZER_PATH,
    GENRE_MODEL_PATH,
    GENRE_VECTORIZER_PATH,
) = (
    os.path.join(PKL_DIR, file_name)
    for file_name in (
        'sentiment_model.pkl',
        'sentiment_vectorizer.pkl',
        'genre_model.pkl',
        'genre_vectorizer.pkl',
    )
)
# Load models and vectorizers
def load_pickle(path):
    """Open *path* in binary mode and return the object unpickled from it.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as pickle_file:
        loaded = pickle.load(pickle_file)
    return loaded
# Load every model/vectorizer once at import time; predict_review reads these
# module-level names on each request.
# NOTE(review): pickle.load can execute arbitrary code; these files are assumed
# to be trusted artifacts shipped with the app — confirm before loading others.
print("Loading models...")
try:
    sentiment_model = load_pickle(SENTIMENT_MODEL_PATH)
    sentiment_vectorizer = load_pickle(SENTIMENT_VECTORIZER_PATH)
    genre_model = load_pickle(GENRE_MODEL_PATH)
    genre_vectorizer = load_pickle(GENRE_VECTORIZER_PATH)
except Exception as e:
    print(f"Error loading models: {e}")
    # The app cannot serve predictions without its models, so fail startup.
    # A bare `raise` re-raises the active exception with its original
    # traceback instead of appending this line as `raise e` would.
    raise
else:
    print("Models loaded successfully.")
def _softmax(scores):
    """Return the softmax of a 1-D array of scores (non-negative, sums to 1).

    The maximum score is subtracted before exponentiating. This leaves the
    result mathematically unchanged but stops ``np.exp`` from overflowing to
    ``inf`` on large scores, which would otherwise turn the output into ``nan``.
    """
    scores = np.asarray(scores, dtype=float)
    exp_scores = np.exp(scores - np.max(scores))
    return exp_scores / np.sum(exp_scores)


def _genre_confidences(model, features):
    """Return a ``{genre_label: confidence}`` dict for the first row of *features*.

    Uses ``model.decision_function`` when the model has one, converting the raw
    scores to probabilities with a softmax. Otherwise falls back to
    ``model.predict`` and reports the single predicted label with confidence 1.0.
    """
    if not hasattr(model, "decision_function"):
        # Fallback if decision_function is not available
        pred = model.predict(features)[0]
        return {pred: 1.0}
    scores = np.atleast_1d(
        np.asarray(model.decision_function(features)[0], dtype=float)
    )
    if scores.size == 1:
        # A binary classifier yields a single signed score per sample rather
        # than one per class, and zipping that scalar with classes_ would raise.
        # Assuming the scikit-learn convention (score > 0 favours classes_[1]),
        # a softmax over [0, score] gives classes_[1] probability sigmoid(score).
        scores = np.array([0.0, scores[0]])
    probabilities = _softmax(scores)
    return {
        label: float(prob)
        for label, prob in zip(model.classes_, probabilities)
    }


def predict_review(review_text):
    """Predict the sentiment and genre confidences for a movie review.

    Args:
        review_text: The review text entered in the Gradio textbox.

    Returns:
        A ``(sentiment, genres)`` tuple for the two Gradio outputs.

        - ``sentiment`` is the sentiment model's predicted label, or an error
          message string if that prediction failed.
        - ``genres`` maps genre labels to confidence scores, or is
          ``{"Error": message}`` if genre prediction failed.
        - For empty, ``None`` or whitespace-only input, both elements are the
          prompt ``"Please enter a review."``.
    """
    # Whitespace-only text carries nothing to classify; treat it like empty input.
    if not review_text or not review_text.strip():
        return "Please enter a review.", "Please enter a review."
    # Sentiment Prediction
    try:
        # transform expects an iterable of documents, so wrap the single review in a list
        sentiment_features = sentiment_vectorizer.transform([review_text])
        sentiment_prediction = sentiment_model.predict(sentiment_features)[0]
    except Exception as e:
        # Surface the failure in the UI instead of failing the whole request.
        sentiment_prediction = f"Error in sentiment prediction: {str(e)}"
    # Genre Prediction
    try:
        genre_features = genre_vectorizer.transform([review_text])
        genre_prediction = _genre_confidences(genre_model, genre_features)
    except Exception as e:
        genre_prediction = {"Error": str(e)}
    return sentiment_prediction, genre_prediction
# Create Gradio Interface
# A single multi-line text input feeds predict_review. Its two return values are
# shown in a plain textbox (sentiment) and in a Label widget limited to the
# three highest-confidence genres (num_top_classes=3).
_review_input = gr.Textbox(lines=5, placeholder="Enter movie review here...", label="Movie Review")
_sentiment_output = gr.Textbox(label="Predicted Sentiment")
_genre_output = gr.Label(label="Predicted Genre", num_top_classes=3)
iface = gr.Interface(
    fn=predict_review,
    inputs=_review_input,
    outputs=[_sentiment_output, _genre_output],
    title="Movie Review Sentiment & Genre Classifier",
    description="Enter a movie review to predict its sentiment and the movie genre based on the text.",
)
def _main():
    """Start the Gradio web server for ``iface``."""
    iface.launch()


if __name__ == "__main__":
    # Launch only when run as a script, not when this module is imported.
    _main()