"""Streamlit app: star-rating sentiment prediction with SHAP token explanations."""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"

# Session-state key of the main text area; example buttons write to it.
INPUT_KEY = "text_input"

# Bar colours, chosen to match the red/blue legend shown in the UI.
POSITIVE_COLOR = "tab:red"
NEGATIVE_COLOR = "tab:blue"


# Load model and tokenizer with caching so Streamlit reruns do not reload them.
@st.cache_resource
def load_model():
    """Return the ``(tokenizer, model)`` pair for ``MODEL_NAME``."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    return tokenizer, model


tokenizer, model = load_model()


def predict(texts):
    """Return class probabilities for a batch of texts.

    Args:
        texts: Iterable whose items are either strings or lists of tokens.
            Token lists are joined back into a string with the tokenizer.

    Returns:
        A NumPy array with one row per input and one column per class,
        obtained by applying softmax to the model logits.
    """
    processed_texts = [
        tokenizer.convert_tokens_to_string(text) if isinstance(text, list) else str(text)
        for text in texts
    ]
    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid otherwise.
    return probabilities.cpu().numpy()


def plot_token_contributions(tokens, values, class_name):
    """Build a horizontal bar chart of per-token SHAP values.

    Bars are drawn at integer positions (not at the token strings) so that
    repeated or empty tokens each keep their own row instead of being merged
    into a single categorical tick.

    Args:
        tokens: Sequence of token strings for one example.
        values: Sequence of SHAP values, one per token.
        class_name: Label of the class the values refer to (used in the title).

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    df = pd.DataFrame(
        {
            "token": [str(token) for token in tokens],
            "shap_value": np.asarray(values, dtype=float),
        }
    )
    # Smallest magnitude first, so the most influential token ends up on top.
    df = df.sort_values("shap_value", key=np.abs, ascending=True).reset_index(drop=True)

    positions = np.arange(len(df))
    colors = np.where(df["shap_value"] >= 0, POSITIVE_COLOR, NEGATIVE_COLOR).tolist()

    fig, ax = plt.subplots(figsize=(8, max(4, len(df) * 0.3)))
    ax.barh(positions, df["shap_value"], color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(df["token"])
    ax.axvline(0, color="black", linewidth=0.8)
    ax.set_xlabel("SHAP value")
    ax.set_title(f"SHAP bar plot for class '{class_name}'")
    return fig


def use_example(text):
    """Button callback: copy an example sentence into the main text area.

    Runs before the script reruns, so the text area picks the value up from
    session state when it is instantiated.
    """
    st.session_state.last_input = text
    st.session_state[INPUT_KEY] = text


# Initialize SHAP components
output_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
explainer = shap.Explainer(predict, masker, output_names=output_names)

# Streamlit UI
st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown(
    """
**How it works:**

1. Enter text in the box below
2. See predicted sentiment (1-5 stars)
3. View confidence scores and word-level explanations
"""
)

text_input = st.text_area(
    "Input Text",
    placeholder="Enter text to analyze...",
    height=100,
    key=INPUT_KEY,
)

if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Get predictions
            probabilities = predict([text_input])[0]
            predicted_class = int(np.argmax(probabilities))

            # Display results
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])
            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    # st.progress only accepts floats within [0.0, 1.0].
                    fraction = min(max(float(score), 0.0), 1.0)
                    st.progress(fraction, text=f"{label}: {score:.1%}")

            # Generate SHAP explanations
            st.subheader("🔍 Explanation")
            st.markdown(
                """
**Feature importance (word-level impacts)**

- 🔴 Positive values → push the prediction toward the selected class
- 🔵 Negative values → push the prediction away from the selected class
"""
            )

            # Get SHAP values for the single input text.
            shap_values = explainer([text_input])
            # Take the first (only) sample before slicing by class; this works
            # whether SHAP stores the batch densely or as an array of arrays.
            sample_values = np.asarray(shap_values.values[0])  # (tokens, classes)
            sample_tokens = shap_values.data[0]

            # One tab per sentiment class
            tabs = st.tabs(output_names)
            for class_index, tab in enumerate(tabs):
                with tab:
                    fig = plot_token_contributions(
                        sample_tokens,
                        sample_values[:, class_index],
                        output_names[class_index],
                    )
                    st.pyplot(fig)
                    plt.close(fig)
    else:
        st.warning("Please enter some text to analyze")

st.markdown("---")
st.markdown("Example texts to try:")
example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]
examples = st.columns(len(example_texts))
for col, text in zip(examples, example_texts):
    with col:
        st.button(text, use_container_width=True, on_click=use_example, args=(text,))