Spaces:
Sleeping
Sleeping
| # ============================================================ | |
| # 🌍 Multi-Lingual Sentiment Analysis (English + Persian) | |
| # With SHAP Interpretability | |
| # ============================================================ | |
| import gradio as gr | |
| import joblib | |
| import numpy as np | |
| import shap | |
| import matplotlib.pyplot as plt | |
| import os | |
| # ------------------------------------------------------------ | |
| # 1️⃣ Load Pretrained Models and Vectorizers | |
| # ------------------------------------------------------------ | |
# Load every serialized artifact once at import time so that requests reuse
# the same in-memory objects. The paths are relative, so they resolve against
# the process's current working directory.
english_model, english_vectorizer = map(
    joblib.load,
    ("best_model.pkl", "tfidf_vectorizer.pkl"),
)
persian_model, persian_vectorizer = map(
    joblib.load,
    ("logistic_regression.pkl", "tfidf_vectorizer_persian.pkl"),
)

# Display names for the predicted class. predict_sentiment indexes these
# lists with the argmax of predict_proba, so position i names class index i.
# NOTE(review): assumes each model's class order is negative/neutral/positive
# -- confirm against the training code.
english_labels = [
    "Negative",
    "Neutral",
    "Positive",
]
persian_labels = [
    "منفی",
    "خنثی",
    "مثبت",
]
| # ------------------------------------------------------------ | |
| # 2️⃣ SHAP Visualization Function | |
| # ------------------------------------------------------------ | |
def get_shap_plot(model, vectorizer, text, class_index, class_name):
    """Explain one prediction with SHAP and save a bar chart of the top words.

    Args:
        model: Fitted classifier accepted by ``shap.Explainer``.
        vectorizer: Fitted vectorizer exposing ``transform`` and
            ``get_feature_names_out``.
        text: Raw comment to explain.
        class_index: Index of the class whose SHAP values are charted.
        class_name: Human-readable class name, used in the chart title.

    Returns:
        A ``(top_words, image_path)`` tuple. ``top_words`` holds at most ten
        feature names ordered by decreasing absolute SHAP value (only
        features with a non-zero contribution are kept, so it may be empty).
        ``image_path`` is the path of the PNG that was written.
    """
    import tempfile

    X_input = vectorizer.transform([text])

    # Use the empty document as the background sample. The previous code
    # built the background from the comment itself, so for comments of up to
    # 50 words the background equalled the input and there was nothing left
    # to attribute. With an empty-document baseline each word is measured
    # against "word absent".
    background = vectorizer.transform([""])
    explainer = shap.Explainer(model, background)
    shap_values = explainer(X_input)

    # NOTE(review): assumes a multiclass explanation laid out as
    # (n_features, n_classes) for the single explained row -- confirm for
    # both pickled models.
    shap_for_class = shap_values.values[0][:, class_index]
    feature_names = np.array(vectorizer.get_feature_names_out())

    top_idx = np.argsort(-np.abs(shap_for_class))[:10]
    # Drop zero-impact features so that vocabulary words which do not
    # contribute to this prediction never pad the chart for short comments.
    top_idx = top_idx[shap_for_class[top_idx] != 0]
    top_words = feature_names[top_idx]
    top_impacts = shap_for_class[top_idx]

    # Work on an explicit figure instead of pyplot's implicit "current
    # figure", and always close it, even if drawing or saving fails.
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        colors = ["crimson" if v > 0 else "steelblue" for v in top_impacts]
        ax.barh(top_words.tolist(), top_impacts.tolist(), color=colors or None)
        ax.set_title(f"Top Words driving {class_name} prediction")
        ax.set_xlabel("SHAP Value (Impact)")
        ax.invert_yaxis()
        fig.tight_layout()

        # A unique file per call: a single fixed file name would let
        # overlapping requests overwrite each other's image. The file is
        # intentionally left on disk so the caller can serve it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            plot_path = handle.name
        fig.savefig(plot_path, bbox_inches="tight")
    finally:
        plt.close(fig)

    return top_words.tolist(), plot_path
| # ------------------------------------------------------------ | |
| # 3️⃣ Prediction + Interpretability Function | |
| # ------------------------------------------------------------ | |
def predict_sentiment(text, language):
    """Classify a comment and explain the prediction with SHAP.

    Args:
        text: The comment to classify. ``None``, empty, or whitespace-only
            input is rejected with a prompt instead of being classified.
        language: ``"English"`` selects the English model; any other value
            selects the Persian model.

    Returns:
        A ``(markdown, image_path)`` tuple. ``markdown`` reports the
        predicted label, its probability, and the most influential words.
        ``image_path`` is the chart path returned by ``get_shap_plot``, or
        ``None`` when the input was rejected.
    """
    # Guard against None as well as blank strings; calling .strip() on None
    # would raise AttributeError.
    if not text or not text.strip():
        return "Please enter a comment.", None

    if language == "English":
        model, vectorizer, labels = english_model, english_vectorizer, english_labels
    else:
        model, vectorizer, labels = persian_model, persian_vectorizer, persian_labels

    X_input = vectorizer.transform([text])
    probs = model.predict_proba(X_input)[0]
    # Plain int so the index is safe for list indexing and downstream use.
    pred_idx = int(np.argmax(probs))
    # NOTE(review): assumes the model's class order matches the label list
    # order -- confirm against the training code.
    pred_class = labels[pred_idx]
    conf = probs[pred_idx]

    # SHAP interpretation for the predicted class.
    top_words, shap_plot = get_shap_plot(model, vectorizer, text, pred_idx, pred_class)

    # Separate the fields with blank lines: Markdown folds single newlines
    # into one paragraph, which would run the three fields together.
    explanation = "\n\n".join(
        [
            f"**Predicted Sentiment:** {pred_class}",
            f"**Confidence:** {conf:.2f}",
            f"**Top Influential Words:** {', '.join(top_words) or '—'}",
        ]
    )
    return explanation, shap_plot
| # ------------------------------------------------------------ | |
| # 4️⃣ Gradio Interface | |
| # ------------------------------------------------------------ | |
# Static text shown by the Gradio UI.
title = "🌐 Multi-Lingual Sentiment Analysis (English + Persian)"
description = (
    "\n"
    "Select a language, type a comment, and see both the sentiment prediction and SHAP interpretability."
    "\n"
)

# Sample comments per language; each example row is [comment, language],
# matching the order of the interface inputs.
_english_samples = [
    "I love this product! Highly recommend.",
    "Worst experience ever, totally disappointed.",
    "The service was okay, nothing special.",
]
_persian_samples = [
    "این محصول فوقالعاده است",
    "تجربهی بدی بود، ناراضیام",
    "کیفیتش متوسط بود",
]
examples = [[comment, "English"] for comment in _english_samples] + [
    [comment, "Persian"] for comment in _persian_samples
]
# Input widgets: the free-text comment and the language/model selector.
comment_input = gr.Textbox(lines=3, label="Enter comment")
language_input = gr.Radio(
    ["English", "Persian"],
    label="Choose Dataset/Language",
    value="English",
)

# Output widgets: the Markdown summary and the SHAP bar chart image.
summary_output = gr.Markdown(label="Prediction & Explanation")
chart_output = gr.Image(label="Top Word Contributions")

# Wire predict_sentiment(text, language) -> (markdown, image path) into the UI.
demo = gr.Interface(
    fn=predict_sentiment,
    inputs=[comment_input, language_input],
    outputs=[summary_output, chart_output],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()