Spaces:
Runtime error
Runtime error
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import shap | |
| import torch | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Load the pretrained sentiment tokenizer and classifier at module import.
# Both are fetched from the same checkpoint, so the name is spelled once.
MODEL_CHECKPOINT = "nlptown/bert-base-multilingual-uncased-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of texts.

    Args:
        texts: An iterable of inputs, or a single string. Each item is either
            a string or a list of tokens; a token list is joined back into a
            string with ``tokenizer.convert_tokens_to_string``.

    Returns:
        A NumPy array with one row per input, each row being the softmax of
        the model's logits over the last axis.
    """
    # A bare string would otherwise be iterated character by character and
    # scored as one "text" per character.
    if isinstance(texts, str):
        texts = [texts]

    processed_texts = [
        tokenizer.convert_tokens_to_string(text) if isinstance(text, list) else text
        for text in texts
    ]

    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )
    # Keep the encoded batch on the same device as the model so the forward
    # pass also works if the model is moved off the CPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Tensor.numpy() only accepts CPU tensors; .cpu() is a no-op on CPU.
    return probabilities.cpu().numpy()
# Initialize SHAP components
_id2label = model.config.id2label
# Class names looked up by index 0..N-1; passed to SHAP as its output names.
output_names_list = [_id2label[idx] for idx in range(len(_id2label))]
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
explainer = shap.Explainer(
    model=predict,
    masker=masker,
    output_names=output_names_list,
)
def analyze_text(text):
    """Classify ``text`` and build a SHAP text plot for every model output.

    Args:
        text: The string to analyze.

    Returns:
        A flat tuple ``(predicted_label, confidence_scores, *html_plots)``:
        the label with the highest probability, a dict mapping each label to
        its probability as a float, and one ``shap.plots.text`` result
        (rendered with ``display=False``) per entry on the last axis of the
        SHAP explanation.

    Raises:
        gr.Error: If ``text`` is ``None``, empty, or only whitespace.
    """
    # Blank input has nothing to explain; report it to the user instead of
    # handing an empty string to the tokenizer and explainer.
    if text is None or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    id2label = model.config.id2label

    # Get predictions
    probabilities = predict([text])[0]
    # Convert to a plain int so the lookup uses an ordinary dict key.
    predicted_label = id2label[int(np.argmax(probabilities))]

    # Generate SHAP explanations
    shap_values = explainer([text])

    # Create HTML visualizations for all classes (last axis of the explanation)
    html_plots = [
        shap.plots.text(shap_values[0, :, class_idx], display=False)
        for class_idx in range(shap_values.shape[-1])
    ]

    # Format confidence scores
    confidence_scores = {
        id2label[class_idx]: float(probability)
        for class_idx, probability in enumerate(probabilities)
    }

    return (predicted_label, confidence_scores, *html_plots)
# Create Gradio interface with HTML components
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The heading previously carried a mis-encoded character; it is dropped so
    # the title renders cleanly.
    gr.Markdown("## BERT Sentiment Analysis with SHAP Explanations")

    with gr.Row():
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze...")

    with gr.Row():
        predict_btn = gr.Button("Analyze Sentiment")

    with gr.Row():
        label_output = gr.Label(label="Predicted Sentiment")
        prob_output = gr.Label(label="Confidence Scores")

    with gr.Row():
        gr.Markdown("""
        ### SHAP Explanations
        Below you can see how each word contributes to different sentiment scores (1-5 stars).
        Red text increases the score, blue decreases it.
        """)

    # Individual Explanation Rows: one HTML slot per model label, so the count
    # follows the model configuration instead of a hard-coded 5.
    plot_components = []
    for i in range(len(model.config.id2label)):
        with gr.Row():
            plot_components.append(
                gr.HTML(
                    label=f"Explanation for {model.config.id2label[i]}",
                    elem_classes=f"shap-plot-{i+1}",
                )
            )

    # Outputs are positional: label, confidence dict, then one plot per class,
    # in the same order analyze_text returns them.
    predict_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[label_output, prob_output] + plot_components,
    )

    # Sample sentences that fill the input textbox.
    examples = gr.Examples(
        examples=[
            ["This product exceeded all my expectations!"],
            ["Terrible customer service experience."],
            ["The movie was okay, nothing special."],
            ["You are kinda cool"],
        ],
        inputs=input_text,
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)