|
|
import html
import io

import gradio as gr
import lime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from datasets import load_dataset
from lime.lime_text import LimeTextExplainer
from PIL import Image
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
|
|
|
|
|
|
|
|
|
|
|
# Load the IMDB reviews dataset; 'train' and 'test' splits each expose
# 'text' and 'label' columns.
dataset = load_dataset('imdb')

# Column access fetches a whole column in one call instead of materialising a
# dict per row; list() keeps these plain lists across `datasets` versions.
text_train = list(dataset['train']['text'])
y_train = list(dataset['train']['label'])
text_test = list(dataset['test']['text'])
y_test = list(dataset['test']['label'])

# Hold out 20% of the training reviews *before* fitting the vectorizer so the
# validation texts cannot influence the TF-IDF vocabulary or IDF weights.
# (Same test_size/random_state as before, so the same rows are held out.)
text_fit, text_val, y_train, y_val = train_test_split(
    text_train, y_train, test_size=0.2, random_state=42
)

vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X_train = vectorizer.fit_transform(text_fit)
X_val = vectorizer.transform(text_val)
X_test = vectorizer.transform(text_test)

model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Report held-out accuracy so the validation and test splits are actually used.
print(f"Validation accuracy: {model.score(X_val, y_val):.4f}")
print(f"Test accuracy: {model.score(X_test, y_test):.4f}")

# Explainers used by explain_text(). LIME works on raw text via a probability
# callback; SHAP's linear explainer uses the training matrix as background data.
lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
shap_explainer = shap.LinearExplainer(model, X_train)
|
|
|
|
|
# Number of features shown by both the LIME and the SHAP explanations.
_MAX_NUM_FEATURES = 10


def _predict_proba_for_lime(texts):
    """Return ``model`` class probabilities for a batch of raw texts.

    Passed to LIME as its ``classifier_fn``; the texts are vectorized with the
    fitted TF-IDF ``vectorizer`` before prediction.
    """
    return model.predict_proba(vectorizer.transform(texts))


def _highlight_text_shap(text, word_importances, feature_names, max_num_features):
    """Return *text* as HTML with its most influential words colored.

    Each whitespace-separated word is normalized (non-alphanumeric characters
    dropped, lower-cased) and looked up in *feature_names*. Among the words
    that match a feature, the ``max_num_features`` with the largest absolute
    importance are wrapped in a red (importance > 0) or blue (otherwise)
    span. Every word is HTML-escaped because the result is rendered as HTML.
    """
    words = text.split()
    tokens = [''.join(filter(str.isalnum, word)).lower() for word in words]

    feature_index = {name: idx for idx, name in enumerate(feature_names)}
    # Match whole tokens only: a substring test would let features such as
    # "act" (inside "actually") claim top-N slots yet never be highlighted.
    present = {
        token: word_importances[feature_index[token]]
        for token in dict.fromkeys(tokens)  # ordered de-duplication
        if token in feature_index
    }
    ranked = sorted(present.items(), key=lambda item: abs(item[1]), reverse=True)
    top_words = dict(ranked[:max_num_features])

    highlighted = []
    for word, token in zip(words, tokens):
        safe_word = html.escape(word)
        if token in top_words:
            color = 'red' if top_words[token] > 0 else 'blue'
            highlighted.append(f'<span style="color:{color}">{safe_word}</span>')
        else:
            highlighted.append(safe_word)
    return ' '.join(highlighted)


def _shap_bar_image(shap_values, feature_names, max_num_features):
    """Plot the features with the largest absolute SHAP value; return pixels.

    Bars are red for positive and blue for non-positive contributions, with
    the largest magnitude at the top. The figure is always closed afterwards.
    """
    contributions = pd.Series(shap_values, index=feature_names)
    # Rank by magnitude so strong negative contributions are not dropped.
    top = contributions.loc[contributions.abs().nlargest(max_num_features).index]

    # An explicit figure (not pyplot's implicit "current" one) keeps concurrent
    # requests from drawing onto each other's axes.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(
            list(top.index),
            top.values,
            color=['red' if val > 0 else 'blue' for val in top.values],
        )
        ax.invert_yaxis()  # most important feature on top
        ax.set_xlabel('SHAP Value')
        ax.set_title(f'Top {max_num_features} Feature Importance')
        fig.tight_layout()
        return fig_to_nparray(fig)
    finally:
        plt.close(fig)


def explain_text(input_text):
    """Classify *input_text* and explain the prediction with LIME and SHAP.

    Returns:
        A 5-tuple matching the Gradio outputs: the predicted label name
        ('Positive' or 'Negative'), the LIME bar-chart image, the SHAP
        bar-chart image, the LIME HTML explanation (embedded in an iframe),
        and the input text as HTML with SHAP-important words colored.

    Raises:
        gr.Error: if *input_text* is empty or whitespace-only; LIME needs at
            least one word to perturb.
    """
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter some text to explain.")

    input_vector = vectorizer.transform([input_text])
    predicted_label = model.predict(input_vector)[0]
    label_name = 'Positive' if predicted_label == 1 else 'Negative'

    # --- LIME ---
    lime_exp = lime_explainer.explain_instance(
        input_text, _predict_proba_for_lime, num_features=_MAX_NUM_FEATURES
    )
    lime_fig = lime_exp.as_pyplot_figure()
    try:
        lime_img = fig_to_nparray(lime_fig)
    finally:
        # pyplot keeps figures alive until closed; close per request so a
        # long-running server does not accumulate them.
        plt.close(lime_fig)

    # LIME's HTML is a self-contained page driven by <script> tags, which are
    # not executed when markup is inserted into gr.HTML directly. Render it in
    # a sandboxed iframe instead (escaped for the srcdoc attribute).
    lime_html = (
        '<iframe sandbox="allow-scripts" '
        'style="width:100%;height:600px;border:none" '
        f'srcdoc="{html.escape(lime_exp.as_html())}"></iframe>'
    )

    # --- SHAP ---
    # Flatten defensively to one importance value per vocabulary feature.
    shap_values = np.asarray(shap_explainer.shap_values(input_vector)[0]).ravel()
    feature_names = vectorizer.get_feature_names_out()

    shap_img = _shap_bar_image(shap_values, feature_names, _MAX_NUM_FEATURES)
    shap_highlighted_text = _highlight_text_shap(
        input_text, shap_values, feature_names, _MAX_NUM_FEATURES
    )

    return label_name, lime_img, shap_img, lime_html, shap_highlighted_text
|
|
|
|
|
def fig_to_nparray(fig):
    """Render *fig* to an in-memory PNG and return the decoded pixel array.

    The figure is saved with ``savefig(format='png')``, decoded with Pillow and
    converted with ``np.array``. The figure itself is left open; closing it is
    up to the caller.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        png_buffer.seek(0)
        # np.array() fully loads the pixel data, so the image may be closed
        # immediately afterwards.
        with Image.open(png_buffer) as decoded:
            pixels = np.array(decoded)
    return pixels
|
|
|
|
|
|
|
|
|
|
|
# Input widget for the review text.
text_input = gr.Textbox(lines=2, placeholder="Enter your text here...")

# Output widgets, in the same order as the tuple returned by explain_text().
output_components = [
    gr.Label(label="Predicted Label"),
    gr.Image(type="numpy", label="LIME Explanation"),
    gr.Image(type="numpy", label="SHAP Explanation"),
    gr.HTML(label="LIME Highlighted Text Explanation"),
    gr.HTML(label="SHAP Highlighted Text Explanation"),
]

iface = gr.Interface(
    fn=explain_text,
    inputs=text_input,
    outputs=output_components,
    title="LIME and SHAP Explanations",
    description="Enter a text sample to see its prediction and explanations using LIME and SHAP.",
)

# Start the Gradio app.
iface.launch()