willwim committed on
Commit
fd765d2
·
verified ·
1 Parent(s): 842a7db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -46
app.py CHANGED
@@ -13,11 +13,11 @@ import base64
13
  import sys
14
  import csv
15
  import os
16
-
17
  HF_TOKEN = os.getenv("hf_token")
18
  csv.field_size_limit(sys.maxsize)
19
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
20
-
21
  # ── Load classification model ──────────────────────────────────────────────────
22
  tokenizer = AutoTokenizer.from_pretrained(
23
  "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
@@ -25,14 +25,14 @@ tokenizer = AutoTokenizer.from_pretrained(
25
  model = AutoModelForSequenceClassification.from_pretrained(
26
  "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
27
  ).to(device)
28
-
29
  pred = transformers.pipeline(
30
  "text-classification", model=model, tokenizer=tokenizer,
31
  top_k=None, device=device
32
  )
33
-
34
  explainer = shap.Explainer(pred)
35
-
36
  # ── Load NER model ─────────────────────────────────────────────────────────────
37
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
38
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
@@ -40,8 +40,8 @@ ner_pipe = pipeline(
40
  "ner", model=ner_model, tokenizer=ner_tokenizer,
41
  aggregation_strategy="simple"
42
  )
43
-
44
-
45
  # ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
46
  def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
47
  """
@@ -53,59 +53,63 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
53
  # .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
54
  values = shap_values.values # (n_tokens, n_classes)
55
  tokens = shap_values.data # list/array of token strings
56
-
57
  if values.ndim == 2:
58
  sv = values[:, class_idx] # SHAP values for "Severe Reaction"
59
  else:
60
  sv = values
61
-
62
  # Sort by absolute magnitude and keep top-N for readability
63
  TOP_N = 20
64
  order = np.argsort(np.abs(sv))[::-1][:TOP_N]
65
  sv_top = sv[order]
66
  tok_top = np.array(tokens)[order]
67
-
68
  # Re-sort so the chart reads top-to-bottom by value (positive on top)
69
  plot_order = np.argsort(sv_top)
70
  sv_plot = sv_top[plot_order]
71
  tok_plot = tok_top[plot_order]
72
-
73
- colors = ["#e05c5c" if v > 0 else "#3dbdb0" for v in sv_plot]
74
-
 
 
 
75
  fig_height = max(4, len(sv_plot) * 0.38)
76
  fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
77
  ax.set_facecolor("white")
78
-
79
  y_pos = np.arange(len(sv_plot))
80
  bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
81
-
82
  # Zero line
83
  ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
84
-
85
  ax.set_yticks(y_pos)
86
  ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
87
  ax.set_xlabel("SHAP Value — impact on ADR prediction", fontsize=10, color="#444444")
88
  ax.set_title(
89
- "Token-Level Feature Importance\n"
90
- "■ Red = pushes toward ADR ■ Teal = pushes away",
91
- fontsize=11, color="#222222", pad=10
92
  )
93
-
94
- # Legend patches
95
- red_patch = mpatches.Patch(color="#e05c5c", label="Pushes toward ADR")
96
- teal_patch = mpatches.Patch(color="#3dbdb0", label="Pushes away from ADR")
97
- ax.legend(handles=[red_patch, teal_patch], fontsize=9,
 
 
98
  loc="lower right", framealpha=0.7)
99
-
100
  ax.spines["top"].set_visible(False)
101
  ax.spines["right"].set_visible(False)
102
  ax.spines["left"].set_visible(False)
103
  ax.tick_params(axis="y", length=0)
104
  ax.tick_params(axis="x", colors="#555555")
105
  ax.xaxis.label.set_color("#555555")
106
-
107
  plt.tight_layout()
108
-
109
  buf = io.BytesIO()
110
  fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
111
  facecolor="white")
@@ -118,22 +122,22 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
118
  f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
119
  f"</div>"
120
  )
121
-
122
-
123
  # ── Main prediction function ───────────────────────────────────────────────────
124
  def adr_predict(x):
125
  text_input = str(x).lower()
126
  encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
127
  output = model(**encoded_input)
128
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
129
-
130
  # ── SHAP (bar chart) ──────────────────────────────────────────────────────
131
  try:
132
  shap_values = explainer([text_input])
133
  shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
134
  except Exception as e:
135
  shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
136
-
137
  # ── NER ───────────────────────────────────────────────────────────────────
138
  try:
139
  res = ner_pipe(text_input)
@@ -146,7 +150,7 @@ def adr_predict(x):
146
  "Diagnostic_procedure": "#eeeeee",
147
  "Biological_structure": "#d9d9d9",
148
  }
149
-
150
  htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
151
  prev_end = 0
152
  res = sorted(res, key=lambda e: e["start"])
@@ -166,15 +170,15 @@ def adr_predict(x):
166
  htext += text_input[prev_end:] + "</div>"
167
  except Exception:
168
  htext = "<p style='color:black;'>NER processing error.</p>"
169
-
170
  label_output = {
171
  "Severe Reaction": float(scores[1]),
172
  "Non-severe Reaction": float(scores[0]),
173
  }
174
-
175
  return label_output, shap_html, htext
176
-
177
-
178
  # ── UI ─────────────────────────────────────────────────────────────────────────
179
  custom_css = """
180
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
@@ -188,16 +192,16 @@ custom_css = """
188
  }
189
  footer { visibility: hidden; }
190
  """
191
-
192
  with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
193
-
194
  with gr.Column(elem_classes="main-header"):
195
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
196
  gr.Markdown(
197
  "Analyze clinical text for potential medication-related severity "
198
  "and key medical entities."
199
  )
200
-
201
  with gr.Row():
202
  # ── Left column: input ────────────────────────────────────────────────
203
  with gr.Column(scale=1):
@@ -209,7 +213,7 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
209
  elem_id="input-text",
210
  )
211
  submit_btn = gr.Button("Run Analysis", variant="primary")
212
-
213
  gr.Markdown("### Examples")
214
  gr.Examples(
215
  examples=[
@@ -220,28 +224,28 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
220
  ],
221
  inputs=[prob1],
222
  )
223
-
224
  # ── Right column: outputs ─────────────────────────────────────────────
225
  with gr.Column(scale=1):
226
  gr.Markdown("### Classification")
227
  label = gr.Label(label="Severity Probability")
228
-
229
  gr.Markdown("### Medical Entities")
230
  htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
231
-
232
  gr.Markdown("### Model Logic (SHAP)")
233
  shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
234
-
235
  gr.Markdown("---")
236
  gr.Markdown(
237
  "Disclaimer: This tool is for research purposes only and does not "
238
  "constitute medical advice."
239
  )
240
-
241
  submit_btn.click(
242
  fn=adr_predict,
243
  inputs=[prob1],
244
  outputs=[label, shap_out, htext_out],
245
  )
246
-
247
  demo.launch()
 
13
  import sys
14
  import csv
15
  import os
16
+
17
  HF_TOKEN = os.getenv("hf_token")
18
  csv.field_size_limit(sys.maxsize)
19
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
20
+
21
  # ── Load classification model ──────────────────────────────────────────────────
22
  tokenizer = AutoTokenizer.from_pretrained(
23
  "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
 
25
  model = AutoModelForSequenceClassification.from_pretrained(
26
  "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
27
  ).to(device)
28
+
29
  pred = transformers.pipeline(
30
  "text-classification", model=model, tokenizer=tokenizer,
31
  top_k=None, device=device
32
  )
33
+
34
  explainer = shap.Explainer(pred)
35
+
36
  # ── Load NER model ─────────────────────────────────────────────────────────────
37
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
38
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
 
40
  "ner", model=ner_model, tokenizer=ner_tokenizer,
41
  aggregation_strategy="simple"
42
  )
43
+
44
+
45
  # ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
46
  def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
47
  """
 
53
  # .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
54
  values = shap_values.values # (n_tokens, n_classes)
55
  tokens = shap_values.data # list/array of token strings
56
+
57
  if values.ndim == 2:
58
  sv = values[:, class_idx] # SHAP values for "Severe Reaction"
59
  else:
60
  sv = values
61
+
62
  # Sort by absolute magnitude and keep top-N for readability
63
  TOP_N = 20
64
  order = np.argsort(np.abs(sv))[::-1][:TOP_N]
65
  sv_top = sv[order]
66
  tok_top = np.array(tokens)[order]
67
+
68
  # Re-sort so the chart reads top-to-bottom by value (positive on top)
69
  plot_order = np.argsort(sv_top)
70
  sv_plot = sv_top[plot_order]
71
  tok_plot = tok_top[plot_order]
72
+
73
+ COLOR_POSITIVE = "#cc1111" # bold red — increases severe ADR probability
74
+ COLOR_NEGATIVE = "#1a6fcc" # strong blue — decreases severe ADR probability
75
+
76
+ colors = [COLOR_POSITIVE if v > 0 else COLOR_NEGATIVE for v in sv_plot]
77
+
78
  fig_height = max(4, len(sv_plot) * 0.38)
79
  fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
80
  ax.set_facecolor("white")
81
+
82
  y_pos = np.arange(len(sv_plot))
83
  bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
84
+
85
  # Zero line
86
  ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
87
+
88
  ax.set_yticks(y_pos)
89
  ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
90
  ax.set_xlabel("SHAP Value — impact on ADR prediction", fontsize=10, color="#444444")
91
  ax.set_title(
92
+ "Token-Feature Importance: Words Driving Prediction",
93
+ fontsize=12, fontweight="bold", color="#222222", pad=12
 
94
  )
95
+
96
+ # Legend patches — colors match the bars exactly
97
+ red_patch = mpatches.Patch(color=COLOR_POSITIVE,
98
+ label="Increases severe ADR probability")
99
+ blue_patch = mpatches.Patch(color=COLOR_NEGATIVE,
100
+ label="Decreases severe ADR probability")
101
+ ax.legend(handles=[red_patch, blue_patch], fontsize=9,
102
  loc="lower right", framealpha=0.7)
103
+
104
  ax.spines["top"].set_visible(False)
105
  ax.spines["right"].set_visible(False)
106
  ax.spines["left"].set_visible(False)
107
  ax.tick_params(axis="y", length=0)
108
  ax.tick_params(axis="x", colors="#555555")
109
  ax.xaxis.label.set_color("#555555")
110
+
111
  plt.tight_layout()
112
+
113
  buf = io.BytesIO()
114
  fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
115
  facecolor="white")
 
122
  f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
123
  f"</div>"
124
  )
125
+
126
+
127
  # ── Main prediction function ───────────────────────────────────────────────────
128
  def adr_predict(x):
129
  text_input = str(x).lower()
130
  encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
131
  output = model(**encoded_input)
132
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
133
+
134
  # ── SHAP (bar chart) ──────────────────────────────────────────────────────
135
  try:
136
  shap_values = explainer([text_input])
137
  shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
138
  except Exception as e:
139
  shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
140
+
141
  # ── NER ───────────────────────────────────────────────────────────────────
142
  try:
143
  res = ner_pipe(text_input)
 
150
  "Diagnostic_procedure": "#eeeeee",
151
  "Biological_structure": "#d9d9d9",
152
  }
153
+
154
  htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
155
  prev_end = 0
156
  res = sorted(res, key=lambda e: e["start"])
 
170
  htext += text_input[prev_end:] + "</div>"
171
  except Exception:
172
  htext = "<p style='color:black;'>NER processing error.</p>"
173
+
174
  label_output = {
175
  "Severe Reaction": float(scores[1]),
176
  "Non-severe Reaction": float(scores[0]),
177
  }
178
+
179
  return label_output, shap_html, htext
180
+
181
+
182
  # ── UI ─────────────────────────────────────────────────────────────────────────
183
  custom_css = """
184
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
 
192
  }
193
  footer { visibility: hidden; }
194
  """
195
+
196
  with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
197
+
198
  with gr.Column(elem_classes="main-header"):
199
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
200
  gr.Markdown(
201
  "Analyze clinical text for potential medication-related severity "
202
  "and key medical entities."
203
  )
204
+
205
  with gr.Row():
206
  # ── Left column: input ────────────────────────────────────────────────
207
  with gr.Column(scale=1):
 
213
  elem_id="input-text",
214
  )
215
  submit_btn = gr.Button("Run Analysis", variant="primary")
216
+
217
  gr.Markdown("### Examples")
218
  gr.Examples(
219
  examples=[
 
224
  ],
225
  inputs=[prob1],
226
  )
227
+
228
  # ── Right column: outputs ─────────────────────────────────────────────
229
  with gr.Column(scale=1):
230
  gr.Markdown("### Classification")
231
  label = gr.Label(label="Severity Probability")
232
+
233
  gr.Markdown("### Medical Entities")
234
  htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
235
+
236
  gr.Markdown("### Model Logic (SHAP)")
237
  shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
238
+
239
  gr.Markdown("---")
240
  gr.Markdown(
241
  "Disclaimer: This tool is for research purposes only and does not "
242
  "constitute medical advice."
243
  )
244
+
245
  submit_btn.click(
246
  fn=adr_predict,
247
  inputs=[prob1],
248
  outputs=[label, shap_out, htext_out],
249
  )
250
+
251
  demo.launch()