# fahin-one's picture
# Upload app.py
# 1e6ce2d verified
# (NOTE: the three lines above are Hugging Face upload-page residue; they have
# been commented out because as bare text they make the file invalid Python.)
# import gradio as gr
# import torch
# import numpy as np
# from transformers import RobertaTokenizer, RobertaForSequenceClassification
# from lime.lime_text import LimeTextExplainer
# # --- Load Saved Model and Tokenizer ---
# MODEL_PATH = './roberta-depression-classifier/'
# tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
# model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
# model.eval() # Set model to evaluation mode
# # --- Define Labels and Explainer ---
# CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
# explainer = LimeTextExplainer(class_names=CLASS_NAMES)
# # --- Create a Prediction Function for LIME ---
# def predictor(texts):
# inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
# with torch.no_grad():
# logits = model(**inputs).logits
# # Convert logits to probabilities
# probs = torch.nn.functional.softmax(logits, dim=-1).detach().numpy()
# return probs
# # --- Main Function for Gradio Interface ---
# def classify_and_explain(text):
# # Get prediction probabilities
# prediction_probs = predictor([text])[0]
# # Get the index of the highest probability
# prediction_index = np.argmax(prediction_probs)
# # Generate LIME explanation for the top predicted class
# explanation = explainer.explain_instance(
# text,
# predictor,
# num_features=10, # Show top 10 most influential words
# labels=(prediction_index,)
# )
# # Format the explanation for Gradio's HighlightedText component
# highlighted_words = explanation.as_list(label=prediction_index)
# return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(prediction_probs)}, highlighted_words
# # --- Create and Launch the Gradio Interface ---
# iface = gr.Interface(
# fn=classify_and_explain,
# inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
# outputs=[
# gr.Label(label="Prediction Probabilities"),
# gr.HighlightedText(
# label="Explanation (Word Importance)",
# color_map={"POS": "green", "NEG": "red"} # Words supporting/contradicting the prediction
# )
# ],
# title="🔬 RoBERTa Depression Severity Classifier & Explainer",
# description="This tool uses a fine-tuned RoBERTa model to classify text into four depression categories. It also uses LIME to highlight the words that most influenced the prediction.",
# examples=[["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."]]
# )
# if __name__ == "__main__":
# iface.launch()
# ==============================================================================
# APP.PY - DEPRESSION CLASSIFIER WITH LIME & SHAP EXPLAINABILITY
# ==============================================================================
import gradio as gr
import torch
import numpy as np
import pandas as pd
from transformers import (
RobertaTokenizer,
RobertaForSequenceClassification,
pipeline
)
from lime.lime_text import LimeTextExplainer
import shap
import warnings
import os # <-- Added os module to handle file paths
import traceback # <-- Added for detailed error logging
# Silence library warnings (deprecations, user warnings) so console output stays readable.
warnings.filterwarnings("ignore")
# --- 1. Load Saved Model and Tokenizer ---
print("Loading fine-tuned RoBERTa model and tokenizer...")
# Build an absolute path to the model directory so the script finds the model
# folder regardless of the current working directory. The folder is expected
# to sit next to this app.py script.
try:
    SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is undefined in interactive environments (e.g. notebooks).
    SCRIPT_DIR = os.getcwd()
MODEL_PATH = os.path.join(SCRIPT_DIR, 'roberta-depression-classifier')

# Fail fast with an actionable message if the model directory is missing.
# FIX: the message previously named a different folder
# ('roberta-base-finetuned-depression') than the one actually looked up,
# which sent users hunting for the wrong directory.
if not os.path.isdir(MODEL_PATH):
    raise OSError(
        f"Model directory not found at the calculated path: {MODEL_PATH}\n"
        f"Please make sure the 'roberta-depression-classifier' folder, "
        f"containing your trained model files, is in the same directory as this app.py script."
    )

# --- Define Global Variables ---
CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
label2id = {label: i for i, label in enumerate(CLASS_NAMES)}
id2label = {i: label for i, label in enumerate(CLASS_NAMES)}

tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
# Load the model WITH the label mappings: this writes them into the model's
# config so the text-classification pipeline emits the proper string labels
# instead of generic LABEL_0 / LABEL_1 names.
model = RobertaForSequenceClassification.from_pretrained(
    MODEL_PATH,
    id2label=id2label,
    label2id=label2id
)
model.eval()  # inference only: disables dropout, etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded successfully.")
# ==============================================================================
# NEW: SETUP FOR SHAP EXPLAINABILITY
# ==============================================================================
# The SHAP library works best with the Hugging Face `pipeline` object.
# This pipeline handles tokenization, prediction, and moving data to the GPU for us.
print("Creating Hugging Face pipeline for SHAP...")
classifier_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    # Device index: 0 = first CUDA device, -1 = CPU.
    device=0 if torch.cuda.is_available() else -1,
    # Emit scores for every class, not just the argmax. NOTE(review):
    # `return_all_scores` is deprecated in recent transformers releases in
    # favor of `top_k=None` — but switching changes the output nesting that
    # the code below relies on; confirm the installed version before changing.
    return_all_scores=True
)
# Create the SHAP explainer directly from the pipeline; shap auto-selects
# its text explainer for transformers pipelines.
print("Creating SHAP explainer...")
explainer_shap = shap.Explainer(classifier_pipeline)
print("SHAP is ready.")
# ==============================================================================
# SETUP FOR LIME EXPLAINABILITY (Your existing code)
# ==============================================================================
print("Creating LIME explainer...")
# class_names fixes the label order LIME uses when reporting explanations.
explainer_lime = LimeTextExplainer(class_names=CLASS_NAMES)
# LIME requires a prediction function of signature list[str] -> ndarray of
# shape (n_texts, n_classes), each row a probability distribution.
def predictor_for_lime(texts):
    """Return class probabilities as an (n_texts, n_classes) numpy array.

    Reuses the pipeline built for SHAP so both explainers see the exact
    same model behavior. Scores are re-ordered via label2id so each row
    lines up with CLASS_NAMES.
    """
    raw_outputs = classifier_pipeline(texts, padding=True, truncation=True, max_length=512)
    return np.array([
        [entry['score'] for entry in sorted(scores, key=lambda s: label2id[s['label']])]
        for scores in raw_outputs
    ])
print("LIME is ready.")
# --- 3. Main Function for Gradio Interface (UPDATED) ---
def classify_and_explain(text):
    """Classify *text* and explain the prediction with both LIME and SHAP.

    Returns a 3-tuple matching the Gradio outputs:
      * dict mapping each class name to its probability (gr.Label),
      * list of (word, weight) pairs from LIME (gr.HighlightedText),
      * list of (token, weight) pairs from SHAP (gr.HighlightedText).

    Raises gr.Error if the prediction itself fails; explanation failures
    degrade to placeholder text instead of aborting the request.
    """
    if not text or not text.strip():
        # Handle empty input gracefully.
        empty_probs = {label: 0.0 for label in CLASS_NAMES}
        return empty_probs, [("Enter text to see explanation.", 0)], [("Enter text to see explanation.", 0)]

    # --- A. Get Prediction ---
    try:
        prediction_results = classifier_pipeline(text)[0]
        # Sort scores into CLASS_NAMES order so index positions line up.
        sorted_preds = sorted(prediction_results, key=lambda x: label2id[x['label']])
        prediction_probs_dict = {p['label']: p['score'] for p in sorted_preds}
        prediction_index = int(np.argmax([p['score'] for p in sorted_preds]))
    except Exception as e:
        print("--- ERROR DURING PREDICTION ---")
        traceback.print_exc()
        raise gr.Error(f"Failed during prediction: {e}")

    # --- B. Generate LIME Explanation ---
    try:
        lime_exp = explainer_lime.explain_instance(
            text,
            predictor_for_lime,
            num_features=10,  # show the 10 most influential words
            labels=(prediction_index,)
        )
        lime_highlighted = lime_exp.as_list(label=prediction_index)
    except Exception as e:
        print("--- ERROR DURING LIME EXPLANATION ---")
        traceback.print_exc()
        lime_highlighted = [("LIME failed to generate.", 0)]

    # --- C. Generate SHAP Explanation ---
    try:
        shap_values = explainer_shap([text])
        # FIX: the previous version looped over CLASS_NAMES, built the
        # attributions, then unconditionally overwrote them afterwards
        # because its `shap_explanation_for_pred_class` sentinel was never
        # reassigned. It also indexed `.data` with a class axis it does not
        # have. Index the Explanation directly instead: `.data` is shaped
        # (n_samples, n_tokens) and `.values` is (n_samples, n_tokens,
        # n_classes), so slice out the predicted class column.
        tokens = shap_values.data[0]
        values = shap_values.values[0, :, prediction_index]
        special_tokens = {tokenizer.bos_token, tokenizer.eos_token,
                          tokenizer.sep_token, tokenizer.pad_token}
        word_attributions = [
            (token, float(value))
            for token, value in zip(tokens, values)
            if token not in special_tokens
        ]
        # Sort by absolute importance and keep the top 10 for display.
        word_attributions.sort(key=lambda x: abs(x[1]), reverse=True)
        shap_highlighted = word_attributions[:10] or [("SHAP data not found for class.", 0)]
    except Exception as e:
        print("--- ERROR DURING SHAP EXPLANATION ---")
        traceback.print_exc()
        shap_highlighted = [("SHAP failed to generate.", 0)]

    return prediction_probs_dict, lime_highlighted, shap_highlighted
# --- 4. Create and Launch the Gradio Interface (UPDATED) ---
iface = gr.Interface(
    fn=classify_and_explain,
    inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
    outputs=[
        # Order must match the 3-tuple returned by classify_and_explain.
        gr.Label(label="Prediction Probabilities"),
        gr.HighlightedText(
            label="LIME Explanation (Local Surrogate)",
            # NOTE(review): the explanation values are numeric weights, not the
            # string categories "POSITIVE"/"NEGATIVE"; confirm this color_map
            # actually takes effect — Gradio may fall back to score-based shading.
            color_map={"POSITIVE": "green", "NEGATIVE": "red"}
        ),
        gr.HighlightedText(
            label="SHAP Explanation (Game-Theoretic Attribution)",
            color_map={"POSITIVE": "blue", "NEGATIVE": "orange"}
        )
    ],
    title="🔬 RoBERTa Depression Classifier with LIME & SHAP",
    description="This tool uses a fine-tuned RoBERTa model to classify text and provides two state-of-the-art explanations. LIME approximates the model locally, while SHAP provides theoretically grounded contribution scores for each word.",
    examples=[
        ["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."],
        ["It all feels so pointless. I've been thinking about whether it's even worth being here anymore."]
    ]
)

# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()