# app.py — uploaded by jeevitha-app; last commit "Update app.py" (cdc1936, verified); 4.12 kB
# ============================================================
# 🌍 Multi-Lingual Sentiment Analysis (English + Persian)
# With SHAP Interpretability
# ============================================================
import gradio as gr
import joblib
import numpy as np
import shap
import matplotlib.pyplot as plt
import os
# ------------------------------------------------------------
# 1️⃣ Load Pretrained Models and Vectorizers
# ------------------------------------------------------------
# Deserialize the four artifacts in a single pass, in this order:
# English model, English vectorizer, Persian model, Persian vectorizer.
(
    english_model,
    english_vectorizer,
    persian_model,
    persian_vectorizer,
) = (
    joblib.load(artifact_path)
    for artifact_path in (
        "best_model.pkl",
        "tfidf_vectorizer.pkl",
        "logistic_regression.pkl",
        "tfidf_vectorizer_persian.pkl",
    )
)

# Display names for the sentiment classes; each list is indexed by the
# position of the predicted class.
persian_labels = ["منفی", "خنثی", "مثبت"]
english_labels = ["Negative", "Neutral", "Positive"]
# ------------------------------------------------------------
# 2️⃣ SHAP Visualization Function
# ------------------------------------------------------------
def get_shap_plot(model, vectorizer, text, class_index, class_name):
    """Explain one prediction with SHAP and save a bar chart of the top words.

    Args:
        model: Fitted classifier handed to ``shap.Explainer``.
        vectorizer: Fitted vectorizer exposing ``transform`` and
            ``get_feature_names_out``.
        text: The raw comment to explain.
        class_index: Column of the SHAP output to explain (the predicted
            class).
        class_name: Human-readable class label, used in the chart title.

    Returns:
        A ``(top_words, image_path)`` tuple: up to ten feature names ordered
        by decreasing absolute SHAP value (only words with a non-zero
        contribution), and the path of the PNG that was written.
    """
    X_input = vectorizer.transform([text])

    # Use the empty document as the baseline. SHAP attributes the difference
    # between the input and the background, so a background built from the
    # input text itself leaves nothing to attribute and every bar ends up
    # at (or near) zero.
    background = vectorizer.transform([""])
    explainer = shap.Explainer(model, background)
    shap_values = explainer(X_input)

    # values[0] is indexed as (feature, class); keep the requested class only.
    shap_for_class = shap_values.values[0][:, class_index]
    feature_names = np.array(vectorizer.get_feature_names_out())

    top_idx = np.argsort(-np.abs(shap_for_class))[:10]
    # Short comments contain fewer than ten known words; without this filter
    # the remaining slots would be filled with arbitrary zero-impact
    # vocabulary entries that never appeared in the comment.
    top_idx = top_idx[shap_for_class[top_idx] != 0]
    top_words = feature_names[top_idx]
    top_impacts = shap_for_class[top_idx]

    # Red bars push towards the explained class, blue bars push away from it.
    colors = ["crimson" if v > 0 else "steelblue" for v in top_impacts]

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        if top_words.size:
            ax.barh(top_words, top_impacts, color=colors)
        ax.set_title(f"Top Words driving {class_name} prediction")
        ax.set_xlabel("SHAP Value (Impact)")
        ax.invert_yaxis()  # strongest contribution at the top
        fig.tight_layout()
        # NOTE(review): fixed file name — concurrent requests overwrite each
        # other's image.
        fig.savefig("shap_plot.png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    return top_words.tolist(), "shap_plot.png"
# ------------------------------------------------------------
# 3️⃣ Prediction + Interpretability Function
# ------------------------------------------------------------
def predict_sentiment(text, language):
    """Classify a comment and explain the prediction with SHAP.

    Args:
        text: The comment to classify. ``None``, empty and whitespace-only
            values are rejected with a prompt instead of being classified.
        language: ``"English"`` selects the English model; any other value
            selects the Persian model.

    Returns:
        A ``(markdown, image_path)`` tuple. ``markdown`` holds the predicted
        label, its probability and the most influential words; ``image_path``
        is the SHAP bar chart, or ``None`` when no comment was supplied.
    """
    # `not text` also covers None, which would otherwise crash on .strip().
    if not text or not text.strip():
        return "Please enter a comment.", None

    if language == "English":
        model, vectorizer, labels = english_model, english_vectorizer, english_labels
    else:
        model, vectorizer, labels = persian_model, persian_vectorizer, persian_labels

    X_input = vectorizer.transform([text])
    probs = model.predict_proba(X_input)[0]
    # NOTE(review): assumes the label lists are ordered like the columns of
    # predict_proba — confirm against the trained models' classes_.
    pred_idx = int(np.argmax(probs))
    pred_class = labels[pred_idx]
    conf = probs[pred_idx]

    # SHAP interpretation of the predicted class.
    top_words, shap_plot = get_shap_plot(model, vectorizer, text, pred_idx, pred_class)

    # Blank lines between the fields: Markdown ignores single newlines, which
    # would otherwise merge the three fields into one run-on paragraph.
    explanation = (
        f"\n**Predicted Sentiment:** {pred_class}\n\n"
        f"**Confidence:** {conf:.2f}\n\n"
        f"**Top Influential Words:** {', '.join(top_words)}\n"
    )
    return explanation, shap_plot
# ------------------------------------------------------------
# 4️⃣ Gradio Interface
# ------------------------------------------------------------
# Heading and blurb shown above the form.
title = "🌐 Multi-Lingual Sentiment Analysis (English + Persian)"
description = (
    "\n"
    "Select a language, type a comment, and see both the sentiment prediction and SHAP interpretability."
    "\n"
)

# Sample inputs as [comment, language] pairs: three English, three Persian.
examples = [
    ["I love this product! Highly recommend.", "English"],
    ["Worst experience ever, totally disappointed.", "English"],
    ["The service was okay, nothing special.", "English"],
    ["این محصول فوق‌العاده است", "Persian"],
    ["تجربه‌ی بدی بود، ناراضی‌ام", "Persian"],
    ["کیفیتش متوسط بود", "Persian"],
]

# Two inputs (comment text, language choice) feed predict_sentiment, whose
# two return values fill the Markdown panel and the image panel in order.
demo = gr.Interface(
    title=title,
    description=description,
    fn=predict_sentiment,
    inputs=[
        gr.Textbox(lines=3, label="Enter comment"),
        gr.Radio(
            ["English", "Persian"],
            label="Choose Dataset/Language",
            value="English",
        ),
    ],
    outputs=[
        gr.Markdown(label="Prediction & Explanation"),
        gr.Image(label="Top Word Contributions"),
    ],
    examples=examples,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()