willwim commited on
Commit
842a7db
Β·
verified Β·
1 Parent(s): 22b54b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -73
app.py CHANGED
@@ -1,133 +1,247 @@
1
  import gradio as gr
2
  import shap
3
  import numpy as np
4
- import scipy as sp
5
  import torch
6
  import transformers
7
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
 
 
8
  import matplotlib.pyplot as plt
 
 
 
9
  import sys
10
  import csv
11
  import os
12
-
13
  HF_TOKEN = os.getenv("hf_token")
14
  csv.field_size_limit(sys.maxsize)
15
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
-
17
- # Load models and tokenizer
18
- tokenizer = AutoTokenizer.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN)
19
- model = AutoModelForSequenceClassification.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN).to(device)
20
-
21
- # Build a pipeline object for predictions
22
- pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
23
-
24
- # SHAP explainer
 
 
 
 
 
25
  explainer = shap.Explainer(pred)
26
-
27
- # NER pipeline
28
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
29
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
- ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def adr_predict(x):
33
  text_input = str(x).lower()
34
- encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
35
  output = model(**encoded_input)
36
-
37
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
38
-
 
39
  try:
40
  shap_values = explainer([text_input])
41
- local_plot = shap.plots.text(shap_values[0], display=False)
42
  except Exception as e:
43
- local_plot = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
44
-
 
45
  try:
46
  res = ner_pipe(text_input)
47
  entity_colors = {
48
- 'Severity': '#ffcccb',
49
- 'Sign_symptom': '#bcf5bc',
50
- 'Medication': '#cfe2f3',
51
- 'Age': '#fff2cc',
52
- 'Sex':'#fff2cc',
53
- 'Diagnostic_procedure':'#eeeeee',
54
- 'Biological_structure':'#d9d9d9'
55
  }
56
-
57
- # FIX: Added inline "color: black;" to force all un-highlighted text to be black
58
- htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black;'>"
59
  prev_end = 0
60
- res = sorted(res, key=lambda x: x['start'])
61
  for entity in res:
62
- start, end = entity['start'], entity['end']
63
- word = text_input[start:end]
64
- color = entity_colors.get(entity['entity_group'], '#f3f3f3')
65
-
66
- htext += f"{text_input[prev_end:start]}"
67
- # Highlighted text is also explicitly set to black
68
- htext += f"<mark style='background-color:{color}; color: black; padding: 2px 4px; border-radius: 4px; font-weight: 500;'>{word} <small style='opacity: 0.7;'>[{entity['entity_group']}]</small></mark>"
 
 
 
 
69
  prev_end = end
70
  htext += text_input[prev_end:] + "</div>"
71
- except:
72
- htext = "<p style='color: black;'>NER processing error.</p>"
73
-
74
- label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
75
-
76
- return label_output, local_plot, htext
77
-
78
- # FIX: Added !important tags to ensure Gradio's dark mode doesn't override the white background and black text
 
 
 
 
79
  custom_css = """
80
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
81
  .main-header { text-align: center; margin-bottom: 2rem; }
82
- .output-box { border-radius: 8px; border: 1px solid #e0e0e0; padding: 15px; background: white !important; color: black !important; }
 
 
 
 
 
 
83
  footer { visibility: hidden; }
84
  """
85
-
86
- with gr.Blocks(title="ADR Detector") as demo:
 
87
  with gr.Column(elem_classes="main-header"):
88
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
89
- gr.Markdown("Analyze clinical text for potential medication-related severity and key medical entities.")
90
-
 
 
 
91
  with gr.Row():
 
92
  with gr.Column(scale=1):
93
  gr.Markdown("### Input")
94
  prob1 = gr.Textbox(
95
- label="Clinical Observations",
96
- lines=4,
97
  placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
98
- elem_id="input-text"
99
  )
100
  submit_btn = gr.Button("Run Analysis", variant="primary")
101
-
102
  gr.Markdown("### Examples")
103
  gr.Examples(
104
  examples=[
105
- ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
106
- ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
 
 
107
  ],
108
- inputs=[prob1]
109
  )
110
-
 
111
  with gr.Column(scale=1):
112
  gr.Markdown("### Classification")
113
  label = gr.Label(label="Severity Probability")
114
-
115
- # --- TABS REMOVED HERE ---
116
- # Both components are now stacked sequentially in the column
117
-
118
  gr.Markdown("### Medical Entities")
119
- htext = gr.HTML(label="NER Mapping", elem_classes="output-box")
120
-
121
  gr.Markdown("### Model Logic (SHAP)")
122
- local_plot = gr.HTML(label='Feature Importance', elem_classes="output-box")
123
-
124
  gr.Markdown("---")
125
- gr.Markdown("Disclaimer: This tool is for research purposes only and does not constitute medical advice.")
126
-
 
 
 
127
  submit_btn.click(
128
  fn=adr_predict,
129
  inputs=[prob1],
130
- outputs=[label, local_plot, htext]
131
  )
132
-
133
- demo.launch(css=custom_css, theme=gr.themes.Soft())
 
1
  import gradio as gr
2
  import shap
3
  import numpy as np
 
4
  import torch
5
  import transformers
6
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
7
+ import matplotlib
8
+ matplotlib.use("Agg")
9
  import matplotlib.pyplot as plt
10
+ import matplotlib.patches as mpatches
11
+ import io
12
+ import base64
13
  import sys
14
  import csv
15
  import os
16
+
17
  HF_TOKEN = os.getenv("hf_token")
18
  csv.field_size_limit(sys.maxsize)
19
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
20
+
21
+ # ── Load classification model ──────────────────────────────────────────────────
22
+ tokenizer = AutoTokenizer.from_pretrained(
23
+ "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
24
+ )
25
+ model = AutoModelForSequenceClassification.from_pretrained(
26
+ "willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
27
+ ).to(device)
28
+
29
+ pred = transformers.pipeline(
30
+ "text-classification", model=model, tokenizer=tokenizer,
31
+ top_k=None, device=device
32
+ )
33
+
34
  explainer = shap.Explainer(pred)
35
+
36
+ # ── Load NER model ─────────────────────────────────────────────────────────────
37
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
38
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
39
+ ner_pipe = pipeline(
40
+ "ner", model=ner_model, tokenizer=ner_tokenizer,
41
+ aggregation_strategy="simple"
42
+ )
43
+
44
+
45
+ # ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
46
+ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
47
+ """
48
+ Builds a horizontal bar chart (red = pushes toward ADR, teal = pushes away)
49
+ that mirrors the style shown in the reference screenshot.
50
+ Returns an <img> HTML tag with an embedded base64 PNG.
51
+ """
52
+ # shap_values is a shap.Explanation object for a single sample
53
+ # .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
54
+ values = shap_values.values # (n_tokens, n_classes)
55
+ tokens = shap_values.data # list/array of token strings
56
+
57
+ if values.ndim == 2:
58
+ sv = values[:, class_idx] # SHAP values for "Severe Reaction"
59
+ else:
60
+ sv = values
61
+
62
+ # Sort by absolute magnitude and keep top-N for readability
63
+ TOP_N = 20
64
+ order = np.argsort(np.abs(sv))[::-1][:TOP_N]
65
+ sv_top = sv[order]
66
+ tok_top = np.array(tokens)[order]
67
+
68
+ # Re-sort so the chart reads top-to-bottom by value (positive on top)
69
+ plot_order = np.argsort(sv_top)
70
+ sv_plot = sv_top[plot_order]
71
+ tok_plot = tok_top[plot_order]
72
+
73
+ colors = ["#e05c5c" if v > 0 else "#3dbdb0" for v in sv_plot]
74
+
75
+ fig_height = max(4, len(sv_plot) * 0.38)
76
+ fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
77
+ ax.set_facecolor("white")
78
+
79
+ y_pos = np.arange(len(sv_plot))
80
+ bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
81
+
82
+ # Zero line
83
+ ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
84
+
85
+ ax.set_yticks(y_pos)
86
+ ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
87
+ ax.set_xlabel("SHAP Value β€” impact on ADR prediction", fontsize=10, color="#444444")
88
+ ax.set_title(
89
+ "Token-Level Feature Importance\n"
90
+ "β–  Red = pushes toward ADR β–  Teal = pushes away",
91
+ fontsize=11, color="#222222", pad=10
92
+ )
93
+
94
+ # Legend patches
95
+ red_patch = mpatches.Patch(color="#e05c5c", label="Pushes toward ADR")
96
+ teal_patch = mpatches.Patch(color="#3dbdb0", label="Pushes away from ADR")
97
+ ax.legend(handles=[red_patch, teal_patch], fontsize=9,
98
+ loc="lower right", framealpha=0.7)
99
+
100
+ ax.spines["top"].set_visible(False)
101
+ ax.spines["right"].set_visible(False)
102
+ ax.spines["left"].set_visible(False)
103
+ ax.tick_params(axis="y", length=0)
104
+ ax.tick_params(axis="x", colors="#555555")
105
+ ax.xaxis.label.set_color("#555555")
106
+
107
+ plt.tight_layout()
108
+
109
+ buf = io.BytesIO()
110
+ fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
111
+ facecolor="white")
112
+ plt.close(fig)
113
+ buf.seek(0)
114
+ b64 = base64.b64encode(buf.read()).decode("utf-8")
115
+ return (
116
+ f"<div style='background:white; padding:12px; border-radius:8px;'>"
117
+ f"<img src='data:image/png;base64,{b64}' "
118
+ f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
119
+ f"</div>"
120
+ )
121
+
122
+
123
+ # ── Main prediction function ───────────────────────────────────────────────────
124
  def adr_predict(x):
125
  text_input = str(x).lower()
126
+ encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
127
  output = model(**encoded_input)
 
128
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
129
+
130
+ # ── SHAP (bar chart) ──────────────────────────────────────────────────────
131
  try:
132
  shap_values = explainer([text_input])
133
+ shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
134
  except Exception as e:
135
+ shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
136
+
137
+ # ── NER ───────────────────────────────────────────────────────────────────
138
  try:
139
  res = ner_pipe(text_input)
140
  entity_colors = {
141
+ "Severity": "#ffcccb",
142
+ "Sign_symptom": "#bcf5bc",
143
+ "Medication": "#cfe2f3",
144
+ "Age": "#fff2cc",
145
+ "Sex": "#fff2cc",
146
+ "Diagnostic_procedure": "#eeeeee",
147
+ "Biological_structure": "#d9d9d9",
148
  }
149
+
150
+ htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
 
151
  prev_end = 0
152
+ res = sorted(res, key=lambda e: e["start"])
153
  for entity in res:
154
+ start, end = entity["start"], entity["end"]
155
+ word = text_input[start:end]
156
+ color = entity_colors.get(entity["entity_group"], "#f3f3f3")
157
+ htext += text_input[prev_end:start]
158
+ htext += (
159
+ f"<mark style='background-color:{color}; color:black; "
160
+ f"padding:2px 4px; border-radius:4px; font-weight:500;'>"
161
+ f"{word} "
162
+ f"<small style='opacity:0.7;'>[{entity['entity_group']}]</small>"
163
+ f"</mark>"
164
+ )
165
  prev_end = end
166
  htext += text_input[prev_end:] + "</div>"
167
+ except Exception:
168
+ htext = "<p style='color:black;'>NER processing error.</p>"
169
+
170
+ label_output = {
171
+ "Severe Reaction": float(scores[1]),
172
+ "Non-severe Reaction": float(scores[0]),
173
+ }
174
+
175
+ return label_output, shap_html, htext
176
+
177
+
178
+ # ── UI ─────────────────────────────────────────────────────────────────────────
179
  custom_css = """
180
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
181
  .main-header { text-align: center; margin-bottom: 2rem; }
182
+ .output-box {
183
+ border-radius: 8px;
184
+ border: 1px solid #e0e0e0;
185
+ padding: 15px;
186
+ background: white !important;
187
+ color: black !important;
188
+ }
189
  footer { visibility: hidden; }
190
  """
191
+
192
+ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
193
+
194
  with gr.Column(elem_classes="main-header"):
195
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
196
+ gr.Markdown(
197
+ "Analyze clinical text for potential medication-related severity "
198
+ "and key medical entities."
199
+ )
200
+
201
  with gr.Row():
202
+ # ── Left column: input ────────────────────────────────────────────────
203
  with gr.Column(scale=1):
204
  gr.Markdown("### Input")
205
  prob1 = gr.Textbox(
206
+ label="Clinical Observations",
207
+ lines=4,
208
  placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
209
+ elem_id="input-text",
210
  )
211
  submit_btn = gr.Button("Run Analysis", variant="primary")
212
+
213
  gr.Markdown("### Examples")
214
  gr.Examples(
215
  examples=[
216
+ ["A 35 year-old male had severe headache after taking Aspirin. "
217
+ "The lab results were normal."],
218
+ ["A 35 year-old female had minor pain in upper abdomen after "
219
+ "taking Acetaminophen."],
220
  ],
221
+ inputs=[prob1],
222
  )
223
+
224
+ # ── Right column: outputs ─────────────────────────────────────────────
225
  with gr.Column(scale=1):
226
  gr.Markdown("### Classification")
227
  label = gr.Label(label="Severity Probability")
228
+
 
 
 
229
  gr.Markdown("### Medical Entities")
230
+ htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
231
+
232
  gr.Markdown("### Model Logic (SHAP)")
233
+ shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
234
+
235
  gr.Markdown("---")
236
+ gr.Markdown(
237
+ "Disclaimer: This tool is for research purposes only and does not "
238
+ "constitute medical advice."
239
+ )
240
+
241
  submit_btn.click(
242
  fn=adr_predict,
243
  inputs=[prob1],
244
+ outputs=[label, shap_out, htext_out],
245
  )
246
+
247
+ demo.launch()