# import gradio as gr # import torch # import numpy as np # from transformers import RobertaTokenizer, RobertaForSequenceClassification # from lime.lime_text import LimeTextExplainer # # --- Load Saved Model and Tokenizer --- # MODEL_PATH = './roberta-depression-classifier/' # tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH) # model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH) # model.eval() # Set model to evaluation mode # # --- Define Labels and Explainer --- # CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal'] # explainer = LimeTextExplainer(class_names=CLASS_NAMES) # # --- Create a Prediction Function for LIME --- # def predictor(texts): # inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256) # with torch.no_grad(): # logits = model(**inputs).logits # # Convert logits to probabilities # probs = torch.nn.functional.softmax(logits, dim=-1).detach().numpy() # return probs # # --- Main Function for Gradio Interface --- # def classify_and_explain(text): # # Get prediction probabilities # prediction_probs = predictor([text])[0] # # Get the index of the highest probability # prediction_index = np.argmax(prediction_probs) # # Generate LIME explanation for the top predicted class # explanation = explainer.explain_instance( # text, # predictor, # num_features=10, # Show top 10 most influential words # labels=(prediction_index,) # ) # # Format the explanation for Gradio's HighlightedText component # highlighted_words = explanation.as_list(label=prediction_index) # return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(prediction_probs)}, highlighted_words # # --- Create and Launch the Gradio Interface --- # iface = gr.Interface( # fn=classify_and_explain, # inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."), # outputs=[ # gr.Label(label="Prediction Probabilities"), # gr.HighlightedText( # 
label="Explanation (Word Importance)", # color_map={"POS": "green", "NEG": "red"} # Words supporting/contradicting the prediction # ) # ], # title="🔬 RoBERTa Depression Severity Classifier & Explainer", # description="This tool uses a fine-tuned RoBERTa model to classify text into four depression categories. It also uses LIME to highlight the words that most influenced the prediction.", # examples=[["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."]] # ) # if __name__ == "__main__": # iface.launch() # ============================================================================== # APP.PY - DEPRESSION CLASSIFIER WITH LIME & SHAP EXPLAINABILITY # ============================================================================== import gradio as gr import torch import numpy as np import pandas as pd from transformers import ( RobertaTokenizer, RobertaForSequenceClassification, pipeline ) from lime.lime_text import LimeTextExplainer import shap import warnings import os # <-- Added os module to handle file paths import traceback # <-- Added for detailed error logging # Suppress warnings for cleaner output warnings.filterwarnings("ignore") # --- 1. Load Saved Model and Tokenizer --- print("Loading fine-tuned RoBERTa model and tokenizer...") # --- FIX: Create a robust, absolute path to the model directory --- # This ensures the script finds the model folder correctly. It assumes the # model folder is in the same directory as this app.py script. 
try:
    SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is undefined in interactive environments (e.g. a notebook);
    # fall back to the current working directory.
    SCRIPT_DIR = os.getcwd()

MODEL_PATH = os.path.join(SCRIPT_DIR, 'roberta-depression-classifier')

# Fail fast with a clear message if the model directory is missing.
# FIX: the error message previously told the user to look for a
# 'roberta-base-finetuned-depression' folder, which does not match the
# folder name MODEL_PATH actually points at.
if not os.path.isdir(MODEL_PATH):
    raise OSError(
        f"Model directory not found at the calculated path: {MODEL_PATH}\n"
        f"Please make sure the 'roberta-depression-classifier' folder, "
        f"containing your trained model files, is in the same directory as this app.py script."
    )

# --- Define Global Variables ---
CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
label2id = {label: i for i, label in enumerate(CLASS_NAMES)}
id2label = {i: label for i, label in enumerate(CLASS_NAMES)}

tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)

# Load the model WITH the label mappings so the pipeline emits the proper
# string labels ('no depression', ...) instead of generic 'LABEL_0'..'LABEL_3'.
model = RobertaForSequenceClassification.from_pretrained(
    MODEL_PATH,
    id2label=id2label,
    label2id=label2id
)
model.eval()  # inference mode: disables dropout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded successfully.")

# ==============================================================================
# SETUP FOR SHAP EXPLAINABILITY
# ==============================================================================
# SHAP works best with the Hugging Face `pipeline` object, which handles
# tokenization, prediction, and moving data to the GPU for us.
print("Creating Hugging Face pipeline for SHAP...")
classifier_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # Use GPU if available
    # NOTE(review): return_all_scores is deprecated in newer transformers
    # releases; top_k=None is the modern equivalent (output nesting differs —
    # verify before switching).
    return_all_scores=True
)

# The 'text' explainer is optimized for NLP models.
print("Creating SHAP explainer...")
explainer_shap = shap.Explainer(classifier_pipeline)
print("SHAP is ready.")

print("Creating LIME explainer...")
explainer_lime = LimeTextExplainer(class_names=CLASS_NAMES)


def predictor_for_lime(texts):
    """Return an (n_texts, n_classes) numpy array of class probabilities.

    LIME perturbs the input many times and requires the probability columns
    to be in CLASS_NAMES order, so each prediction set is re-sorted by its
    label id before the scores are extracted.
    """
    predictions = classifier_pipeline(texts, padding=True, truncation=True, max_length=512)
    probs = []
    for prediction_set in predictions:
        sorted_preds = sorted(prediction_set, key=lambda x: label2id[x['label']])
        probs.append([p['score'] for p in sorted_preds])
    return np.array(probs)


print("LIME is ready.")


# --- 3. Main Function for Gradio Interface ---
def classify_and_explain(text):
    """Classify `text` and generate LIME and SHAP explanations.

    Returns a 3-tuple:
      - dict mapping class name -> probability (for gr.Label)
      - list of (word, weight) pairs from LIME (for gr.HighlightedText)
      - list of (token, weight) pairs from SHAP (for gr.HighlightedText)

    Raises gr.Error if the prediction itself fails; explanation failures
    degrade to placeholder text instead of crashing the request.
    """
    if not text or not text.strip():
        # Handle empty input gracefully.
        empty_probs = {label: 0.0 for label in CLASS_NAMES}
        placeholder = [("Enter text to see explanation.", 0)]
        return empty_probs, placeholder, placeholder

    # --- A. Get Prediction ---
    try:
        prediction_results = classifier_pipeline(text)[0]
        sorted_preds = sorted(prediction_results, key=lambda x: label2id[x['label']])
        prediction_probs_dict = {p['label']: p['score'] for p in sorted_preds}
        prediction_index = int(np.argmax([p['score'] for p in sorted_preds]))
        predicted_class_name = CLASS_NAMES[prediction_index]
    except Exception as e:
        print("--- ERROR DURING PREDICTION ---")
        traceback.print_exc()
        raise gr.Error(f"Failed during prediction: {e}")

    # --- B. Generate LIME Explanation ---
    try:
        lime_exp = explainer_lime.explain_instance(
            text,
            predictor_for_lime,
            num_features=10,  # Show top 10 most influential words
            labels=(prediction_index,)
        )
        lime_highlighted = lime_exp.as_list(label=prediction_index)
    except Exception:
        print("--- ERROR DURING LIME EXPLANATION ---")
        traceback.print_exc()
        lime_highlighted = [("LIME failed to generate.", 0)]

    # --- C. Generate SHAP Explanation ---
    try:
        shap_values = explainer_shap([text])

        # FIX: the original looped over CLASS_NAMES to find the predicted
        # class but never assigned its `shap_explanation_for_pred_class`
        # sentinel, so the post-loop `is None` check was always true and the
        # computed attributions were ALWAYS overwritten by the
        # "data not found" placeholder. The class index is already known
        # (prediction_index), so index the SHAP output directly.
        # NOTE(review): the .cohorts(1) indexing pattern is carried over from
        # the original — verify it against the installed shap version.
        i = prediction_index
        tokens = shap_values.cohorts(1).data[0, :, i]
        values = shap_values.cohorts(1).values[0, :, i]

        # Drop special tokens; coerce numpy scalars to plain floats for Gradio.
        special_tokens = {tokenizer.bos_token, tokenizer.eos_token,
                          tokenizer.sep_token, tokenizer.pad_token}
        word_attributions = [
            (token, float(value))
            for token, value in zip(tokens, values)
            if token not in special_tokens
        ]

        # Sort by absolute importance and take top 10 for display.
        word_attributions.sort(key=lambda x: abs(x[1]), reverse=True)
        shap_highlighted = word_attributions[:10]
    except Exception:
        print("--- ERROR DURING SHAP EXPLANATION ---")
        traceback.print_exc()
        shap_highlighted = [("SHAP failed to generate.", 0)]

    return prediction_probs_dict, lime_highlighted, shap_highlighted


# --- 4. Create and Launch the Gradio Interface ---
iface = gr.Interface(
    fn=classify_and_explain,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Text for Analysis",
        placeholder="I've been feeling so alone and empty lately..."
    ),
    outputs=[
        gr.Label(label="Prediction Probabilities"),
        gr.HighlightedText(
            label="LIME Explanation (Local Surrogate)",
            color_map={"POSITIVE": "green", "NEGATIVE": "red"}
        ),
        gr.HighlightedText(
            label="SHAP Explanation (Game-Theoretic Attribution)",
            color_map={"POSITIVE": "blue", "NEGATIVE": "orange"}
        )
    ],
    title="🔬 RoBERTa Depression Classifier with LIME & SHAP",
    description="This tool uses a fine-tuned RoBERTa model to classify text and provides two state-of-the-art explanations. LIME approximates the model locally, while SHAP provides theoretically grounded contribution scores for each word.",
    examples=[
        ["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."],
        ["It all feels so pointless. I've been thinking about whether it's even worth being here anymore."]
    ]
)

if __name__ == "__main__":
    iface.launch()