Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,133 +1,247 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import shap
|
| 3 |
import numpy as np
|
| 4 |
-
import scipy as sp
|
| 5 |
import torch
|
| 6 |
import transformers
|
| 7 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
|
|
|
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
| 9 |
import sys
|
| 10 |
import csv
|
| 11 |
import os
|
| 12 |
-
|
| 13 |
HF_TOKEN = os.getenv("hf_token")
|
| 14 |
csv.field_size_limit(sys.maxsize)
|
| 15 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 16 |
-
|
| 17 |
-
# Load
|
| 18 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
explainer = shap.Explainer(pred)
|
| 26 |
-
|
| 27 |
-
# NER
|
| 28 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 29 |
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 30 |
-
ner_pipe = pipeline(
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def adr_predict(x):
|
| 33 |
text_input = str(x).lower()
|
| 34 |
-
encoded_input = tokenizer(text_input, return_tensors=
|
| 35 |
output = model(**encoded_input)
|
| 36 |
-
|
| 37 |
scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
|
| 38 |
-
|
|
|
|
| 39 |
try:
|
| 40 |
shap_values = explainer([text_input])
|
| 41 |
-
|
| 42 |
except Exception as e:
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
try:
|
| 46 |
res = ner_pipe(text_input)
|
| 47 |
entity_colors = {
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
}
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black;'>"
|
| 59 |
prev_end = 0
|
| 60 |
-
res = sorted(res, key=lambda
|
| 61 |
for entity in res:
|
| 62 |
-
start, end = entity[
|
| 63 |
-
word
|
| 64 |
-
color = entity_colors.get(entity[
|
| 65 |
-
|
| 66 |
-
htext +=
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
prev_end = end
|
| 70 |
htext += text_input[prev_end:] + "</div>"
|
| 71 |
-
except:
|
| 72 |
-
htext = "<p style='color:
|
| 73 |
-
|
| 74 |
-
label_output = {
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
custom_css = """
|
| 80 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
| 81 |
.main-header { text-align: center; margin-bottom: 2rem; }
|
| 82 |
-
.output-box {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
footer { visibility: hidden; }
|
| 84 |
"""
|
| 85 |
-
|
| 86 |
-
with gr.Blocks(title="ADR Detector") as demo:
|
|
|
|
| 87 |
with gr.Column(elem_classes="main-header"):
|
| 88 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 89 |
-
gr.Markdown(
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
with gr.Row():
|
|
|
|
| 92 |
with gr.Column(scale=1):
|
| 93 |
gr.Markdown("### Input")
|
| 94 |
prob1 = gr.Textbox(
|
| 95 |
-
label="Clinical Observations",
|
| 96 |
-
lines=4,
|
| 97 |
placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
|
| 98 |
-
elem_id="input-text"
|
| 99 |
)
|
| 100 |
submit_btn = gr.Button("Run Analysis", variant="primary")
|
| 101 |
-
|
| 102 |
gr.Markdown("### Examples")
|
| 103 |
gr.Examples(
|
| 104 |
examples=[
|
| 105 |
-
["A 35 year-old male had severe headache after taking Aspirin.
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
],
|
| 108 |
-
inputs=[prob1]
|
| 109 |
)
|
| 110 |
-
|
|
|
|
| 111 |
with gr.Column(scale=1):
|
| 112 |
gr.Markdown("### Classification")
|
| 113 |
label = gr.Label(label="Severity Probability")
|
| 114 |
-
|
| 115 |
-
# --- TABS REMOVED HERE ---
|
| 116 |
-
# Both components are now stacked sequentially in the column
|
| 117 |
-
|
| 118 |
gr.Markdown("### Medical Entities")
|
| 119 |
-
|
| 120 |
-
|
| 121 |
gr.Markdown("### Model Logic (SHAP)")
|
| 122 |
-
|
| 123 |
-
|
| 124 |
gr.Markdown("---")
|
| 125 |
-
gr.Markdown(
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
submit_btn.click(
|
| 128 |
fn=adr_predict,
|
| 129 |
inputs=[prob1],
|
| 130 |
-
outputs=[label,
|
| 131 |
)
|
| 132 |
-
|
| 133 |
-
demo.launch(
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import shap
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
import torch
|
| 5 |
import transformers
|
| 6 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use("Agg")
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
+
import matplotlib.patches as mpatches
|
| 11 |
+
import io
|
| 12 |
+
import base64
|
| 13 |
import sys
|
| 14 |
import csv
|
| 15 |
import os
|
| 16 |
+
|
| 17 |
HF_TOKEN = os.getenv("hf_token")
|
| 18 |
csv.field_size_limit(sys.maxsize)
|
| 19 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
|
| 21 |
+
# ββ Load classification model ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 23 |
+
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
| 24 |
+
)
|
| 25 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 26 |
+
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
| 27 |
+
).to(device)
|
| 28 |
+
|
| 29 |
+
pred = transformers.pipeline(
|
| 30 |
+
"text-classification", model=model, tokenizer=tokenizer,
|
| 31 |
+
top_k=None, device=device
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
explainer = shap.Explainer(pred)
|
| 35 |
+
|
| 36 |
+
# ββ Load NER model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 38 |
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 39 |
+
ner_pipe = pipeline(
|
| 40 |
+
"ner", model=ner_model, tokenizer=ner_tokenizer,
|
| 41 |
+
aggregation_strategy="simple"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββ Custom SHAP bar-chart renderer βββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
|
| 47 |
+
"""
|
| 48 |
+
Builds a horizontal bar chart (red = pushes toward ADR, teal = pushes away)
|
| 49 |
+
that mirrors the style shown in the reference screenshot.
|
| 50 |
+
Returns an <img> HTML tag with an embedded base64 PNG.
|
| 51 |
+
"""
|
| 52 |
+
# shap_values is a shap.Explanation object for a single sample
|
| 53 |
+
# .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
|
| 54 |
+
values = shap_values.values # (n_tokens, n_classes)
|
| 55 |
+
tokens = shap_values.data # list/array of token strings
|
| 56 |
+
|
| 57 |
+
if values.ndim == 2:
|
| 58 |
+
sv = values[:, class_idx] # SHAP values for "Severe Reaction"
|
| 59 |
+
else:
|
| 60 |
+
sv = values
|
| 61 |
+
|
| 62 |
+
# Sort by absolute magnitude and keep top-N for readability
|
| 63 |
+
TOP_N = 20
|
| 64 |
+
order = np.argsort(np.abs(sv))[::-1][:TOP_N]
|
| 65 |
+
sv_top = sv[order]
|
| 66 |
+
tok_top = np.array(tokens)[order]
|
| 67 |
+
|
| 68 |
+
# Re-sort so the chart reads top-to-bottom by value (positive on top)
|
| 69 |
+
plot_order = np.argsort(sv_top)
|
| 70 |
+
sv_plot = sv_top[plot_order]
|
| 71 |
+
tok_plot = tok_top[plot_order]
|
| 72 |
+
|
| 73 |
+
colors = ["#e05c5c" if v > 0 else "#3dbdb0" for v in sv_plot]
|
| 74 |
+
|
| 75 |
+
fig_height = max(4, len(sv_plot) * 0.38)
|
| 76 |
+
fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
|
| 77 |
+
ax.set_facecolor("white")
|
| 78 |
+
|
| 79 |
+
y_pos = np.arange(len(sv_plot))
|
| 80 |
+
bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
|
| 81 |
+
|
| 82 |
+
# Zero line
|
| 83 |
+
ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
|
| 84 |
+
|
| 85 |
+
ax.set_yticks(y_pos)
|
| 86 |
+
ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
|
| 87 |
+
ax.set_xlabel("SHAP Value β impact on ADR prediction", fontsize=10, color="#444444")
|
| 88 |
+
ax.set_title(
|
| 89 |
+
"Token-Level Feature Importance\n"
|
| 90 |
+
"β Red = pushes toward ADR β Teal = pushes away",
|
| 91 |
+
fontsize=11, color="#222222", pad=10
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Legend patches
|
| 95 |
+
red_patch = mpatches.Patch(color="#e05c5c", label="Pushes toward ADR")
|
| 96 |
+
teal_patch = mpatches.Patch(color="#3dbdb0", label="Pushes away from ADR")
|
| 97 |
+
ax.legend(handles=[red_patch, teal_patch], fontsize=9,
|
| 98 |
+
loc="lower right", framealpha=0.7)
|
| 99 |
+
|
| 100 |
+
ax.spines["top"].set_visible(False)
|
| 101 |
+
ax.spines["right"].set_visible(False)
|
| 102 |
+
ax.spines["left"].set_visible(False)
|
| 103 |
+
ax.tick_params(axis="y", length=0)
|
| 104 |
+
ax.tick_params(axis="x", colors="#555555")
|
| 105 |
+
ax.xaxis.label.set_color("#555555")
|
| 106 |
+
|
| 107 |
+
plt.tight_layout()
|
| 108 |
+
|
| 109 |
+
buf = io.BytesIO()
|
| 110 |
+
fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
|
| 111 |
+
facecolor="white")
|
| 112 |
+
plt.close(fig)
|
| 113 |
+
buf.seek(0)
|
| 114 |
+
b64 = base64.b64encode(buf.read()).decode("utf-8")
|
| 115 |
+
return (
|
| 116 |
+
f"<div style='background:white; padding:12px; border-radius:8px;'>"
|
| 117 |
+
f"<img src='data:image/png;base64,{b64}' "
|
| 118 |
+
f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
|
| 119 |
+
f"</div>"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ββ Main prediction function βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
def adr_predict(x):
|
| 125 |
text_input = str(x).lower()
|
| 126 |
+
encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
|
| 127 |
output = model(**encoded_input)
|
|
|
|
| 128 |
scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
|
| 129 |
+
|
| 130 |
+
# ββ SHAP (bar chart) ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
try:
|
| 132 |
shap_values = explainer([text_input])
|
| 133 |
+
shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
|
| 134 |
except Exception as e:
|
| 135 |
+
shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
|
| 136 |
+
|
| 137 |
+
# ββ NER βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
try:
|
| 139 |
res = ner_pipe(text_input)
|
| 140 |
entity_colors = {
|
| 141 |
+
"Severity": "#ffcccb",
|
| 142 |
+
"Sign_symptom": "#bcf5bc",
|
| 143 |
+
"Medication": "#cfe2f3",
|
| 144 |
+
"Age": "#fff2cc",
|
| 145 |
+
"Sex": "#fff2cc",
|
| 146 |
+
"Diagnostic_procedure": "#eeeeee",
|
| 147 |
+
"Biological_structure": "#d9d9d9",
|
| 148 |
}
|
| 149 |
+
|
| 150 |
+
htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
|
|
|
|
| 151 |
prev_end = 0
|
| 152 |
+
res = sorted(res, key=lambda e: e["start"])
|
| 153 |
for entity in res:
|
| 154 |
+
start, end = entity["start"], entity["end"]
|
| 155 |
+
word = text_input[start:end]
|
| 156 |
+
color = entity_colors.get(entity["entity_group"], "#f3f3f3")
|
| 157 |
+
htext += text_input[prev_end:start]
|
| 158 |
+
htext += (
|
| 159 |
+
f"<mark style='background-color:{color}; color:black; "
|
| 160 |
+
f"padding:2px 4px; border-radius:4px; font-weight:500;'>"
|
| 161 |
+
f"{word} "
|
| 162 |
+
f"<small style='opacity:0.7;'>[{entity['entity_group']}]</small>"
|
| 163 |
+
f"</mark>"
|
| 164 |
+
)
|
| 165 |
prev_end = end
|
| 166 |
htext += text_input[prev_end:] + "</div>"
|
| 167 |
+
except Exception:
|
| 168 |
+
htext = "<p style='color:black;'>NER processing error.</p>"
|
| 169 |
+
|
| 170 |
+
label_output = {
|
| 171 |
+
"Severe Reaction": float(scores[1]),
|
| 172 |
+
"Non-severe Reaction": float(scores[0]),
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
return label_output, shap_html, htext
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ββ UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 179 |
custom_css = """
|
| 180 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
| 181 |
.main-header { text-align: center; margin-bottom: 2rem; }
|
| 182 |
+
.output-box {
|
| 183 |
+
border-radius: 8px;
|
| 184 |
+
border: 1px solid #e0e0e0;
|
| 185 |
+
padding: 15px;
|
| 186 |
+
background: white !important;
|
| 187 |
+
color: black !important;
|
| 188 |
+
}
|
| 189 |
footer { visibility: hidden; }
|
| 190 |
"""
|
| 191 |
+
|
| 192 |
+
with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 193 |
+
|
| 194 |
with gr.Column(elem_classes="main-header"):
|
| 195 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 196 |
+
gr.Markdown(
|
| 197 |
+
"Analyze clinical text for potential medication-related severity "
|
| 198 |
+
"and key medical entities."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
with gr.Row():
|
| 202 |
+
# ββ Left column: input ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 203 |
with gr.Column(scale=1):
|
| 204 |
gr.Markdown("### Input")
|
| 205 |
prob1 = gr.Textbox(
|
| 206 |
+
label="Clinical Observations",
|
| 207 |
+
lines=4,
|
| 208 |
placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
|
| 209 |
+
elem_id="input-text",
|
| 210 |
)
|
| 211 |
submit_btn = gr.Button("Run Analysis", variant="primary")
|
| 212 |
+
|
| 213 |
gr.Markdown("### Examples")
|
| 214 |
gr.Examples(
|
| 215 |
examples=[
|
| 216 |
+
["A 35 year-old male had severe headache after taking Aspirin. "
|
| 217 |
+
"The lab results were normal."],
|
| 218 |
+
["A 35 year-old female had minor pain in upper abdomen after "
|
| 219 |
+
"taking Acetaminophen."],
|
| 220 |
],
|
| 221 |
+
inputs=[prob1],
|
| 222 |
)
|
| 223 |
+
|
| 224 |
+
# ββ Right column: outputs βββββββββββββββββββββββββββββββββββββββββββββ
|
| 225 |
with gr.Column(scale=1):
|
| 226 |
gr.Markdown("### Classification")
|
| 227 |
label = gr.Label(label="Severity Probability")
|
| 228 |
+
|
|
|
|
|
|
|
|
|
|
| 229 |
gr.Markdown("### Medical Entities")
|
| 230 |
+
htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
|
| 231 |
+
|
| 232 |
gr.Markdown("### Model Logic (SHAP)")
|
| 233 |
+
shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
|
| 234 |
+
|
| 235 |
gr.Markdown("---")
|
| 236 |
+
gr.Markdown(
|
| 237 |
+
"Disclaimer: This tool is for research purposes only and does not "
|
| 238 |
+
"constitute medical advice."
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
submit_btn.click(
|
| 242 |
fn=adr_predict,
|
| 243 |
inputs=[prob1],
|
| 244 |
+
outputs=[label, shap_out, htext_out],
|
| 245 |
)
|
| 246 |
+
|
| 247 |
+
demo.launch()
|