fahin-one committed on
Commit
1e6ce2d
·
verified ·
1 Parent(s): 6657221

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -42
app.py CHANGED
@@ -1,63 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- from transformers import RobertaTokenizer, RobertaForSequenceClassification
 
 
 
 
 
5
  from lime.lime_text import LimeTextExplainer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # --- Load Saved Model and Tokenizer ---
8
- MODEL_PATH = './roberta-depression-classifier/'
9
  tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
10
- model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
 
 
 
 
 
 
 
 
 
11
  model.eval() # Set model to evaluation mode
 
 
 
12
 
13
- # --- Define Labels and Explainer ---
14
- CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
15
- explainer = LimeTextExplainer(class_names=CLASS_NAMES)
16
-
17
- # --- Create a Prediction Function for LIME ---
18
- def predictor(texts):
19
- inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
20
- with torch.no_grad():
21
- logits = model(**inputs).logits
22
- # Convert logits to probabilities
23
- probs = torch.nn.functional.softmax(logits, dim=-1).detach().numpy()
24
- return probs
25
-
26
- # --- Main Function for Gradio Interface ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def classify_and_explain(text):
28
- # Get prediction probabilities
29
- prediction_probs = predictor([text])[0]
30
- # Get the index of the highest probability
31
- prediction_index = np.argmax(prediction_probs)
32
-
33
- # Generate LIME explanation for the top predicted class
34
- explanation = explainer.explain_instance(
35
- text,
36
- predictor,
37
- num_features=10, # Show top 10 most influential words
38
- labels=(prediction_index,)
39
- )
40
-
41
- # Format the explanation for Gradio's HighlightedText component
42
- highlighted_words = explanation.as_list(label=prediction_index)
43
-
44
- return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(prediction_probs)}, highlighted_words
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # --- Create and Launch the Gradio Interface ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  iface = gr.Interface(
48
  fn=classify_and_explain,
49
  inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
50
  outputs=[
51
  gr.Label(label="Prediction Probabilities"),
52
  gr.HighlightedText(
53
- label="Explanation (Word Importance)",
54
- color_map={"POS": "green", "NEG": "red"} # Words supporting/contradicting the prediction
 
 
 
 
55
  )
56
  ],
57
- title="🔬 RoBERTa Depression Severity Classifier & Explainer",
58
- description="This tool uses a fine-tuned RoBERTa model to classify text into four depression categories. It also uses LIME to highlight the words that most influenced the prediction.",
59
- examples=[["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."]]
 
 
 
60
  )
61
 
62
  if __name__ == "__main__":
63
- iface.launch()
 
1
+ # import gradio as gr
2
+ # import torch
3
+ # import numpy as np
4
+ # from transformers import RobertaTokenizer, RobertaForSequenceClassification
5
+ # from lime.lime_text import LimeTextExplainer
6
+
7
+ # # --- Load Saved Model and Tokenizer ---
8
+ # MODEL_PATH = './roberta-depression-classifier/'
9
+ # tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)
10
+ # model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
11
+ # model.eval() # Set model to evaluation mode
12
+
13
+ # # --- Define Labels and Explainer ---
14
+ # CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
15
+ # explainer = LimeTextExplainer(class_names=CLASS_NAMES)
16
+
17
+ # # --- Create a Prediction Function for LIME ---
18
+ # def predictor(texts):
19
+ # inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
20
+ # with torch.no_grad():
21
+ # logits = model(**inputs).logits
22
+ # # Convert logits to probabilities
23
+ # probs = torch.nn.functional.softmax(logits, dim=-1).detach().numpy()
24
+ # return probs
25
+
26
+ # # --- Main Function for Gradio Interface ---
27
+ # def classify_and_explain(text):
28
+ # # Get prediction probabilities
29
+ # prediction_probs = predictor([text])[0]
30
+ # # Get the index of the highest probability
31
+ # prediction_index = np.argmax(prediction_probs)
32
+
33
+ # # Generate LIME explanation for the top predicted class
34
+ # explanation = explainer.explain_instance(
35
+ # text,
36
+ # predictor,
37
+ # num_features=10, # Show top 10 most influential words
38
+ # labels=(prediction_index,)
39
+ # )
40
+
41
+ # # Format the explanation for Gradio's HighlightedText component
42
+ # highlighted_words = explanation.as_list(label=prediction_index)
43
+
44
+ # return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(prediction_probs)}, highlighted_words
45
+
46
+ # # --- Create and Launch the Gradio Interface ---
47
+ # iface = gr.Interface(
48
+ # fn=classify_and_explain,
49
+ # inputs=gr.Textbox(lines=5, label="Enter Text for Analysis", placeholder="I've been feeling so alone and empty lately..."),
50
+ # outputs=[
51
+ # gr.Label(label="Prediction Probabilities"),
52
+ # gr.HighlightedText(
53
+ # label="Explanation (Word Importance)",
54
+ # color_map={"POS": "green", "NEG": "red"} # Words supporting/contradicting the prediction
55
+ # )
56
+ # ],
57
+ # title="🔬 RoBERTa Depression Severity Classifier & Explainer",
58
+ # description="This tool uses a fine-tuned RoBERTa model to classify text into four depression categories. It also uses LIME to highlight the words that most influenced the prediction.",
59
+ # examples=[["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."]]
60
+ # )
61
+
62
+ # if __name__ == "__main__":
63
+ # iface.launch()
64
+
65
+
66
+
67
+ # ==============================================================================
68
+ # APP.PY - DEPRESSION CLASSIFIER WITH LIME & SHAP EXPLAINABILITY
69
+ # ==============================================================================
70
  import gradio as gr
71
  import torch
72
  import numpy as np
73
+ import pandas as pd
74
+ from transformers import (
75
+ RobertaTokenizer,
76
+ RobertaForSequenceClassification,
77
+ pipeline
78
+ )
79
  from lime.lime_text import LimeTextExplainer
80
+ import shap
81
+ import warnings
82
+ import os # <-- Added os module to handle file paths
83
+ import traceback # <-- Added for detailed error logging
84
+
85
+ # Suppress warnings for cleaner output
86
+ warnings.filterwarnings("ignore")
87
+
88
# --- 1. Load Saved Model and Tokenizer ---
print("Loading fine-tuned RoBERTa model and tokenizer...")

# Resolve the model directory relative to this script so the app works no
# matter what the current working directory is when it is launched.
try:
    SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is undefined in interactive environments (e.g. notebooks).
    SCRIPT_DIR = os.getcwd()

MODEL_PATH = os.path.join(SCRIPT_DIR, 'roberta-depression-classifier')

# Fail fast with an actionable message if the model directory is missing.
# BUG FIX: the message previously named a different folder
# ('roberta-base-finetuned-depression') than the one MODEL_PATH actually
# points at, sending users hunting for the wrong directory.
if not os.path.isdir(MODEL_PATH):
    raise OSError(
        f"Model directory not found at the calculated path: {MODEL_PATH}\n"
        f"Please make sure the 'roberta-depression-classifier' folder, "
        f"containing your trained model files, is in the same directory as this app.py script."
    )

# --- Define Global Variables ---
CLASS_NAMES = ['no depression', 'moderate depression', 'severe depression', 'suicidal']
label2id = {label: i for i, label in enumerate(CLASS_NAMES)}
id2label = {i: label for i, label in enumerate(CLASS_NAMES)}

tokenizer = RobertaTokenizer.from_pretrained(MODEL_PATH)

# Load the model WITH explicit label mappings so its config carries the
# human-readable class names and the HF pipeline emits them instead of the
# generic LABEL_0..LABEL_3 placeholders.
model = RobertaForSequenceClassification.from_pretrained(
    MODEL_PATH,
    id2label=id2label,
    label2id=label2id
)

model.eval()  # Inference only: disables dropout / training-time behavior.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded successfully.")
130
 
131
# ==============================================================================
# SHAP EXPLAINABILITY SETUP
# ==============================================================================
# SHAP's text explainer consumes a Hugging Face `pipeline`, which bundles
# tokenization, batching, and device placement for us.
print("Creating Hugging Face pipeline for SHAP...")
classifier_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    # GPU index 0 when CUDA is available, otherwise CPU (-1).
    device=0 if torch.cuda.is_available() else -1,
    # NOTE(review): `return_all_scores` is deprecated in newer transformers in
    # favor of `top_k=None`, but the two differ in output nesting — keep as-is.
    return_all_scores=True
)

# Build the SHAP explainer on top of the pipeline; the text explainer is
# optimized for NLP models.
print("Creating SHAP explainer...")
explainer_shap = shap.Explainer(classifier_pipeline)
print("SHAP is ready.")

# ==============================================================================
# LIME EXPLAINABILITY SETUP
# ==============================================================================
print("Creating LIME explainer...")
explainer_lime = LimeTextExplainer(class_names=CLASS_NAMES)
156
+
157
# LIME requires a prediction callable returning an (n_texts, n_classes)
# numpy array of probabilities, ordered consistently with CLASS_NAMES.
def predictor_for_lime(texts):
    """Return class probabilities for *texts* as a numpy array for LIME.

    Reuses the pipeline built for SHAP so both explainers see identical
    model behavior. Scores are re-sorted via ``label2id`` so column order
    always matches ``CLASS_NAMES``.
    """
    batch_predictions = classifier_pipeline(
        texts, padding=True, truncation=True, max_length=512
    )
    return np.array([
        [entry['score']
         for entry in sorted(pred_set, key=lambda e: label2id[e['label']])]
        for pred_set in batch_predictions
    ])
print("LIME is ready.")
169
+
170
+
171
# --- 3. Main Function for Gradio Interface ---
def classify_and_explain(text):
    """Classify *text* and generate both LIME and SHAP explanations.

    Returns a 3-tuple matching the Gradio outputs:
      - dict mapping each class name to its probability (gr.Label),
      - list of (word, weight) pairs from LIME (gr.HighlightedText),
      - list of (token, weight) pairs from SHAP (gr.HighlightedText).

    Each explainer is wrapped in its own try/except so a failure in one
    still lets the other (and the prediction itself) be displayed.
    """
    if not text or not text.strip():
        # Handle empty input gracefully.
        empty_probs = {label: 0.0 for label in CLASS_NAMES}
        return empty_probs, [("Enter text to see explanation.", 0)], [("Enter text to see explanation.", 0)]

    # --- A. Get Prediction ---
    try:
        prediction_results = classifier_pipeline(text)[0]
        # Sort by label2id so index positions line up with CLASS_NAMES.
        sorted_preds = sorted(prediction_results, key=lambda x: label2id[x['label']])
        prediction_probs_dict = {p['label']: p['score'] for p in sorted_preds}
        prediction_index = np.argmax([p['score'] for p in sorted_preds])
        predicted_class_name = CLASS_NAMES[prediction_index]
    except Exception as e:
        print("--- ERROR DURING PREDICTION ---")
        traceback.print_exc()
        raise gr.Error(f"Failed during prediction: {e}")

    # --- B. Generate LIME Explanation ---
    try:
        lime_exp = explainer_lime.explain_instance(
            text,
            predictor_for_lime,
            num_features=10,
            labels=(prediction_index,)
        )
        lime_highlighted = lime_exp.as_list(label=prediction_index)
    except Exception:
        print("--- ERROR DURING LIME EXPLANATION ---")
        traceback.print_exc()
        lime_highlighted = [("LIME failed to generate.", 0)]

    # --- C. Generate SHAP Explanation ---
    try:
        shap_values = explainer_shap([text])

        # BUG FIX: the previous version initialized a sentinel
        # (`shap_explanation_for_pred_class = None`) that was never reassigned
        # inside the loop, so the post-loop `is None` check ALWAYS fired and
        # overwrote a successfully built `shap_highlighted` with
        # "SHAP data not found for class.". We now use `shap_highlighted`
        # itself as the sentinel.
        shap_highlighted = None
        for i, label in enumerate(CLASS_NAMES):
            if label != predicted_class_name:
                continue
            # Use SHAP's own cohort grouping for tokens/values; assumes the
            # explanation is indexed as [sample, token, class] — TODO confirm
            # against the installed shap version.
            tokens = shap_values.cohorts(1).data[0, :, i]
            values = shap_values.cohorts(1).values[0, :, i]

            # Drop special tokens, then keep the 10 most influential words.
            special_tokens = {
                tokenizer.bos_token, tokenizer.eos_token,
                tokenizer.sep_token, tokenizer.pad_token,
            }
            word_attributions = [
                (token, value)
                for token, value in zip(tokens, values)
                if token not in special_tokens
            ]
            word_attributions.sort(key=lambda x: abs(x[1]), reverse=True)
            shap_highlighted = word_attributions[:10]
            break

        if shap_highlighted is None:
            shap_highlighted = [("SHAP data not found for class.", 0)]

    except Exception:
        print("--- ERROR DURING SHAP EXPLANATION ---")
        traceback.print_exc()
        shap_highlighted = [("SHAP failed to generate.", 0)]

    return prediction_probs_dict, lime_highlighted, shap_highlighted
241
+
242
# --- 4. Create and Launch the Gradio Interface ---
# Wiring: one textbox in; probabilities label plus two highlighted-text
# panes (LIME and SHAP) out, matching classify_and_explain's return tuple.
analysis_input = gr.Textbox(
    lines=5,
    label="Enter Text for Analysis",
    placeholder="I've been feeling so alone and empty lately..."
)
analysis_outputs = [
    gr.Label(label="Prediction Probabilities"),
    gr.HighlightedText(
        label="LIME Explanation (Local Surrogate)",
        color_map={"POSITIVE": "green", "NEGATIVE": "red"}
    ),
    gr.HighlightedText(
        label="SHAP Explanation (Game-Theoretic Attribution)",
        color_map={"POSITIVE": "blue", "NEGATIVE": "orange"}
    ),
]

iface = gr.Interface(
    fn=classify_and_explain,
    inputs=analysis_input,
    outputs=analysis_outputs,
    title="🔬 RoBERTa Depression Classifier with LIME & SHAP",
    description="This tool uses a fine-tuned RoBERTa model to classify text and provides two state-of-the-art explanations. LIME approximates the model locally, while SHAP provides theoretically grounded contribution scores for each word.",
    examples=[
        ["I have been feeling down and hopeless for weeks. Nothing brings me joy anymore."],
        ["It all feels so pointless. I've been thinking about whether it's even worth being here anymore."]
    ]
)

if __name__ == "__main__":
    iface.launch()