Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import shap | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
# Load model and tokenizer with caching
@st.cache_resource
def load_model():
    """Load the sentiment tokenizer and classifier, cached across reruns.

    Streamlit re-executes the script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects alive between reruns so
    the checkpoint is not loaded again each time.

    Returns:
        A ``(tokenizer, model)`` tuple for the
        ``nlptown/bert-base-multilingual-uncased-sentiment`` checkpoint.
    """
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return tokenizer, model


tokenizer, model = load_model()
# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of inputs.

    Args:
        texts: Iterable of inputs. Each item is either a plain string or a
            list of tokens; token lists are joined back into a string with
            the tokenizer before encoding.

    Returns:
        NumPy array holding the softmax of the model logits over the last
        axis, one row per input.
    """
    batch = [
        tokenizer.convert_tokens_to_string(item) if isinstance(item, list) else item
        for item in texts
    ]
    encoded = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )
    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        result = model(**encoded)
    probabilities = torch.nn.functional.softmax(result.logits, dim=-1)
    return probabilities.numpy()
# Initialize SHAP components
# Label names are read from the model config in class-index order so they line
# up with the probability columns returned by predict(). The class count comes
# from the config rather than a hard-coded 5.
output_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
explainer = shap.Explainer(predict, masker, output_names=output_names)
# Streamlit UI


def _use_example(example):
    """Copy an example sentence into the main input box.

    Used as a button ``on_click`` callback. It writes to the session-state
    entry of the text area keyed ``"text_input"``, so the box shows the
    example on the rerun triggered by the click.
    """
    st.session_state["text_input"] = example


def _plot_token_impacts(tokens, values, title):
    """Build a horizontal bar chart of per-token SHAP values for one class.

    Args:
        tokens: Token strings of the explained text.
        values: SHAP values aligned with ``tokens``.
        title: Title drawn above the chart.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    df = pd.DataFrame({"token": tokens, "shap_value": values})
    # Smallest magnitude first so the most influential token ends up on top.
    df = df.sort_values("shap_value", key=np.abs, ascending=True)
    # Place bars at integer positions instead of passing the token strings as
    # y-values: matplotlib treats strings as categories, so repeated tokens
    # would be drawn on top of each other.
    positions = list(range(len(df)))
    # Colours follow the legend shown in the UI: red for positive values,
    # blue for negative ones.
    colors = ["tab:red" if v >= 0 else "tab:blue" for v in df["shap_value"]]
    fig, ax = plt.subplots(figsize=(8, max(4, len(df) * 0.3)))
    ax.barh(positions, df["shap_value"].to_numpy(), color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(df["token"].tolist())
    ax.set_xlabel("SHAP value")
    ax.set_title(title)
    return fig


st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown("""
**How it works:**
1. Enter text in the box below
2. See predicted sentiment (1-5 stars)
3. View confidence scores and word-level explanations
""")
# The key lets the example buttons below fill this box through session state.
text_input = st.text_area(
    "Input Text",
    placeholder="Enter text to analyze...",
    height=100,
    key="text_input",
)
if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Get predictions
            probabilities = predict([text_input])[0]
            predicted_class = int(np.argmax(probabilities))
            # Display results
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])
            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    st.progress(float(score), text=f"{label}: {score:.1%}")
            # Generate SHAP explanations
            st.subheader("🔍 Explanation")
            st.markdown("""
**Feature importance (word-level impacts)**
🔴 Higher positive values → Increases sentiment
🔵 Lower negative values → Decreases sentiment
""")
            # Explain the single input; the values are indexed below as
            # [example, token, class].
            shap_values = explainer([text_input])
            tokens = shap_values.data[0]
            # One tab per sentiment class
            tabs = st.tabs(output_names)
            for i, tab in enumerate(tabs):
                with tab:
                    fig = _plot_token_impacts(
                        tokens,
                        shap_values.values[0, :, i],
                        f"SHAP bar plot for class '{output_names[i]}'",
                    )
                    st.pyplot(fig)
                    # Release the figure so figures do not pile up across reruns.
                    plt.close(fig)
    else:
        st.warning("Please enter some text to analyze")
st.markdown("---")
st.markdown("Example texts to try:")
example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]
examples = st.columns(len(example_texts))
for col, text in zip(examples, example_texts):
    with col:
        # Clicking an example loads it into the main input box; the user then
        # presses "Analyze Sentiment" as usual.
        st.button(
            text,
            use_container_width=True,
            on_click=_use_example,
            args=(text,),
        )