Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,63 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from lime.lime_text import LimeTextExplainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
# --- Load Saved Model and Tokenizer ---
|
| 8 |
-
MODEL_PATH = './roberta-depression-classifier/'
|
| 9 |
tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
model.eval() # Set model to evaluation mode
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def classify_and_explain(text):
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
text,
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
iface = gr.Interface(
|
| 48 |
fn=classify_and_explain,
|
| 49 |
inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
|
| 50 |
outputs=[
|
| 51 |
gr.Label(label="Prediction Probabilities"),
|
| 52 |
gr.HighlightedText(
|
| 53 |
-
label="Explanation (
|
| 54 |
-
color_map={"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
)
|
| 56 |
],
|
| 57 |
-
title="🔬 RoBERTa Depression
|
| 58 |
-
description="This tool uses a fine-tuned RoBERTa model to classify text
|
| 59 |
-
examples=[
|
|
|
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
if __name__ == "__main__":
|
| 63 |
-
iface.launch()
|
|
|
|
| 1 |
+
# import gradio as gr
|
| 2 |
+
# import torch
|
| 3 |
+
# import numpy as np
|
| 4 |
+
# from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
| 5 |
+
# from lime.lime_text import LimeTextExplainer
|
| 6 |
+
|
| 7 |
+
# # --- Load Saved Model and Tokenizer ---
|
| 8 |
+
# MODEL_PATH = './roberta-depression-classifier/'
|
| 9 |
+
# tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
|
| 10 |
+
# model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
| 11 |
+
# model.eval() # Set model to evaluation mode
|
| 12 |
+
|
| 13 |
+
# # --- Define Labels and Explainer ---
|
| 14 |
+
# CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
|
| 15 |
+
# explainer = LimeTextExplainer(class_names=CLASS_NAMES)
|
| 16 |
+
|
| 17 |
+
# # --- Create a Prediction Function for LIME ---
|
| 18 |
+
# def predictor(texts):
|
| 19 |
+
# inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
|
| 20 |
+
# with torch.no_grad():
|
| 21 |
+
# logits = model(**inputs).logits
|
| 22 |
+
# # Convert logits to probabilities
|
| 23 |
+
# probs = torch.nn.functional.softmax(logits, dim=-1).detach().numpy()
|
| 24 |
+
# return probs
|
| 25 |
+
|
| 26 |
+
# # --- Main Function for Gradio Interface ---
|
| 27 |
+
# def classify_and_explain(text):
|
| 28 |
+
# # Get prediction probabilities
|
| 29 |
+
# prediction_probs = predictor([text])[0]
|
| 30 |
+
# # Get the index of the highest probability
|
| 31 |
+
# prediction_index = np.argmax(prediction_probs)
|
| 32 |
+
|
| 33 |
+
# # Generate LIME explanation for the top predicted class
|
| 34 |
+
# explanation = explainer.explain_instance(
|
| 35 |
+
# text,
|
| 36 |
+
# predictor,
|
| 37 |
+
# num_features=10, # Show top 10 most influential words
|
| 38 |
+
# labels=(prediction_index,)
|
| 39 |
+
# )
|
| 40 |
+
|
| 41 |
+
# # Format the explanation for Gradio's HighlightedText component
|
| 42 |
+
# highlighted_words = explanation.as_list(label=prediction_index)
|
| 43 |
+
|
| 44 |
+
# return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(prediction_probs)}, highlighted_words
|
| 45 |
+
|
| 46 |
+
# # --- Create and Launch the Gradio Interface ---
|
| 47 |
+
# iface = gr.Interface(
|
| 48 |
+
# fn=classify_and_explain,
|
| 49 |
+
# inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
|
| 50 |
+
# outputs=[
|
| 51 |
+
# gr.Label(label="Prediction Probabilities"),
|
| 52 |
+
# gr.HighlightedText(
|
| 53 |
+
# label="Explanation (Word Importance)",
|
| 54 |
+
# color_map={"POS": "green", "NEG": "red"} # Words supporting/contradicting the prediction
|
| 55 |
+
# )
|
| 56 |
+
# ],
|
| 57 |
+
# title="🔬 RoBERTa Depression Severity Classifier & Explainer",
|
| 58 |
+
# description="This tool uses a fine-tuned RoBERTa model to classify text into four depression categories. It also uses LIME to highlight the words that most influenced the prediction.",
|
| 59 |
+
# examples=[["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."]]
|
| 60 |
+
# )
|
| 61 |
+
|
| 62 |
+
# if __name__ == "__main__":
|
| 63 |
+
# iface.launch()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ==============================================================================
|
| 68 |
+
# APP.PY - DEPRESSION CLASSIFIER WITH LIME & SHAP EXPLAINABILITY
|
| 69 |
+
# ==============================================================================
|
| 70 |
import gradio as gr
|
| 71 |
import torch
|
| 72 |
import numpy as np
|
| 73 |
+
import pandas as pd
|
| 74 |
+
from transformers import (
|
| 75 |
+
RobertaTokenizer,
|
| 76 |
+
RobertaForSequenceClassification,
|
| 77 |
+
pipeline
|
| 78 |
+
)
|
| 79 |
from lime.lime_text import LimeTextExplainer
|
| 80 |
+
import shap
|
| 81 |
+
import warnings
|
| 82 |
+
import os # <-- Added os module to handle file paths
|
| 83 |
+
import traceback # <-- Added for detailed error logging
|
| 84 |
+
|
| 85 |
+
# Suppress warnings for cleaner output
|
| 86 |
+
warnings.filterwarnings("ignore")
|
| 87 |
+
|
| 88 |
+
# --- 1. Load Saved Model and Tokenizer ---
print("Loading fine-tuned RoBERTa model and tokenizer...")

# Build an absolute path to the model directory so the script works no matter
# which working directory it is launched from.  The model folder is expected
# to sit next to this app.py script.
try:
    SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is undefined in interactive environments (e.g. notebooks).
    SCRIPT_DIR = os.getcwd()

MODEL_PATH = os.path.join(SCRIPT_DIR, 'roberta-depression-classifier')

# Fail fast with an actionable message if the model directory is missing.
# FIX: the message previously named a different folder
# ('roberta-base-finetuned-depression') than the one actually looked up.
if not os.path.isdir(MODEL_PATH):
    raise OSError(
        f"Model directory not found at the calculated path: {MODEL_PATH}\n"
        f"Please make sure the 'roberta-depression-classifier' folder, "
        f"containing your trained model files, is in the same directory as this app.py script."
    )

# --- Define Global Variables ---
CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
label2id = {label: i for i, label in enumerate(CLASS_NAMES)}
id2label = {i: label for i, label in enumerate(CLASS_NAMES)}

tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)

# Load the model WITH the label mappings so the pipeline emits the proper
# string class labels instead of generic 'LABEL_0'-style names.
model = RobertaForSequenceClassification.from_pretrained(
    MODEL_PATH,
    id2label=id2label,
    label2id=label2id
)

model.eval()  # Set model to evaluation mode (disables dropout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded successfully.")
|
| 130 |
|
| 131 |
+
# ==============================================================================
# SETUP FOR SHAP EXPLAINABILITY
# ==============================================================================
# SHAP works best with the Hugging Face `pipeline` object, which bundles
# tokenization, prediction, and device placement into a single callable.
print("Creating Hugging Face pipeline for SHAP...")
classifier_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # GPU when available
    # NOTE(review): return_all_scores is deprecated in newer transformers
    # releases in favour of top_k=None; kept as-is to preserve behavior.
    return_all_scores=True
)

# Build the SHAP explainer on top of the pipeline; shap.Explainer picks its
# text-optimized algorithm automatically for pipeline callables.
print("Creating SHAP explainer...")
explainer_shap = shap.Explainer(classifier_pipeline)
print("SHAP is ready.")
|
| 150 |
+
|
| 151 |
+
# ==============================================================================
# SETUP FOR LIME EXPLAINABILITY
# ==============================================================================
print("Creating LIME explainer...")
explainer_lime = LimeTextExplainer(class_names=CLASS_NAMES)

def predictor_for_lime(texts):
    """Return an (n_texts, n_classes) probability array in CLASS_NAMES order.

    LIME perturbs the input and calls this on batches of text variants; it
    reuses the pipeline already created for SHAP for consistency.
    """
    raw = classifier_pipeline(texts, padding=True, truncation=True, max_length=512)
    # The pipeline yields one list of {label, score} dicts per input text;
    # sort each by label index so columns line up with CLASS_NAMES.
    return np.array([
        [entry['score'] for entry in sorted(scores, key=lambda s: label2id[s['label']])]
        for scores in raw
    ])

print("LIME is ready.")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# --- 3. Main Function for Gradio Interface (UPDATED) ---
def classify_and_explain(text):
    """Classify *text* and explain the prediction with both LIME and SHAP.

    Returns a tuple matching the three Gradio output components:
      (label -> probability dict,
       LIME (word, weight) pairs,
       SHAP (word, weight) pairs)

    Each explanation stage is wrapped in its own try/except so one failing
    explainer degrades to a placeholder instead of killing the request.
    """
    if not text or not text.strip():
        # Handle empty input gracefully.
        empty_probs = {label: 0.0 for label in CLASS_NAMES}
        placeholder = [("Enter text to see explanation.", 0)]
        return empty_probs, placeholder, placeholder

    # --- A. Get Prediction ---
    try:
        prediction_results = classifier_pipeline(text)[0]
        # Sort so probabilities line up with CLASS_NAMES order.
        sorted_preds = sorted(prediction_results, key=lambda x: label2id[x['label']])
        prediction_probs_dict = {p['label']: p['score'] for p in sorted_preds}
        prediction_index = np.argmax([p['score'] for p in sorted_preds])
    except Exception as e:
        print("--- ERROR DURING PREDICTION ---")
        traceback.print_exc()
        raise gr.Error(f"Failed during prediction: {e}")

    # --- B. Generate LIME Explanation ---
    try:
        lime_exp = explainer_lime.explain_instance(
            text,
            predictor_for_lime,
            num_features=10,  # show the 10 most influential words
            labels=(prediction_index,)
        )
        lime_highlighted = lime_exp.as_list(label=prediction_index)
    except Exception:
        print("--- ERROR DURING LIME EXPLANATION ---")
        traceback.print_exc()
        lime_highlighted = [("LIME failed to generate.", 0)]

    # --- C. Generate SHAP Explanation ---
    try:
        shap_values = explainer_shap([text])

        # FIX: the original looped over CLASS_NAMES guarded by a
        # `shap_explanation_for_pred_class` sentinel that was never assigned,
        # so the computed explanation was ALWAYS overwritten by the
        # "SHAP data not found for class." placeholder after the loop.
        # Index the predicted class directly instead.
        class_idx = int(prediction_index)
        # NOTE(review): relies on SHAP's cohort grouping for merged subword
        # tokens; assumes .data/.values index as [sample, token, class] —
        # confirm against the installed shap version.
        tokens = shap_values.cohorts(1).data[0, :, class_idx]
        values = shap_values.cohorts(1).values[0, :, class_idx]

        # Drop special tokens, then keep the 10 most influential words
        # (by absolute attribution) for display.
        special_tokens = {tokenizer.bos_token, tokenizer.eos_token,
                          tokenizer.sep_token, tokenizer.pad_token}
        word_attributions = [
            (token, value)
            for token, value in zip(tokens, values)
            if token not in special_tokens
        ]
        word_attributions.sort(key=lambda x: abs(x[1]), reverse=True)
        shap_highlighted = word_attributions[:10]
    except Exception:
        print("--- ERROR DURING SHAP EXPLANATION ---")
        traceback.print_exc()
        shap_highlighted = [("SHAP failed to generate.", 0)]

    return prediction_probs_dict, lime_highlighted, shap_highlighted
|
| 241 |
+
|
| 242 |
+
# --- 4. Create and Launch the Gradio Interface ---
iface = gr.Interface(
    fn=classify_and_explain,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Text for Analysis",
        placeholder="I've been feeling so alone and empty lately...",
    ),
    # One output component per element of classify_and_explain's return tuple.
    outputs=[
        gr.Label(label="Prediction Probabilities"),
        gr.HighlightedText(
            label="LIME Explanation (Local Surrogate)",
            color_map={"POSITIVE": "green", "NEGATIVE": "red"},
        ),
        gr.HighlightedText(
            label="SHAP Explanation (Game-Theoretic Attribution)",
            color_map={"POSITIVE": "blue", "NEGATIVE": "orange"},
        ),
    ],
    title="🔬 RoBERTa Depression Classifier with LIME & SHAP",
    description="This tool uses a fine-tuned RoBERTa model to classify text and provides two state-of-the-art explanations. LIME approximates the model locally, while SHAP provides theoretically grounded contribution scores for each word.",
    examples=[
        ["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."],
        ["It all feels so pointless. I've been thinking about whether it's even worth being here anymore."],
    ],
)

if __name__ == "__main__":
    iface.launch()
|