Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,11 +13,11 @@ import base64
|
|
| 13 |
import sys
|
| 14 |
import csv
|
| 15 |
import os
|
| 16 |
-
|
| 17 |
HF_TOKEN = os.getenv("hf_token")
|
| 18 |
csv.field_size_limit(sys.maxsize)
|
| 19 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 20 |
-
|
| 21 |
# ── Load classification model ─────────────────────────────────────────────────
|
| 22 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 23 |
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
|
@@ -25,14 +25,14 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
| 25 |
model = AutoModelForSequenceClassification.from_pretrained(
|
| 26 |
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
| 27 |
).to(device)
|
| 28 |
-
|
| 29 |
pred = transformers.pipeline(
|
| 30 |
"text-classification", model=model, tokenizer=tokenizer,
|
| 31 |
top_k=None, device=device
|
| 32 |
)
|
| 33 |
-
|
| 34 |
explainer = shap.Explainer(pred)
|
| 35 |
-
|
| 36 |
# ── Load NER model ────────────────────────────────────────────────────────────
|
| 37 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 38 |
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
|
@@ -40,8 +40,8 @@ ner_pipe = pipeline(
|
|
| 40 |
"ner", model=ner_model, tokenizer=ner_tokenizer,
|
| 41 |
aggregation_strategy="simple"
|
| 42 |
)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
# ── Custom SHAP bar-chart renderer ────────────────────────────────────────────
|
| 46 |
def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
|
| 47 |
"""
|
|
@@ -53,59 +53,63 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
|
|
| 53 |
# .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
|
| 54 |
values = shap_values.values # (n_tokens, n_classes)
|
| 55 |
tokens = shap_values.data # list/array of token strings
|
| 56 |
-
|
| 57 |
if values.ndim == 2:
|
| 58 |
sv = values[:, class_idx] # SHAP values for "Severe Reaction"
|
| 59 |
else:
|
| 60 |
sv = values
|
| 61 |
-
|
| 62 |
# Sort by absolute magnitude and keep top-N for readability
|
| 63 |
TOP_N = 20
|
| 64 |
order = np.argsort(np.abs(sv))[::-1][:TOP_N]
|
| 65 |
sv_top = sv[order]
|
| 66 |
tok_top = np.array(tokens)[order]
|
| 67 |
-
|
| 68 |
# Re-sort so the chart reads top-to-bottom by value (positive on top)
|
| 69 |
plot_order = np.argsort(sv_top)
|
| 70 |
sv_plot = sv_top[plot_order]
|
| 71 |
tok_plot = tok_top[plot_order]
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
fig_height = max(4, len(sv_plot) * 0.38)
|
| 76 |
fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
|
| 77 |
ax.set_facecolor("white")
|
| 78 |
-
|
| 79 |
y_pos = np.arange(len(sv_plot))
|
| 80 |
bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
|
| 81 |
-
|
| 82 |
# Zero line
|
| 83 |
ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
|
| 84 |
-
|
| 85 |
ax.set_yticks(y_pos)
|
| 86 |
ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
|
| 87 |
ax.set_xlabel("SHAP Value β impact on ADR prediction", fontsize=10, color="#444444")
|
| 88 |
ax.set_title(
|
| 89 |
-
"Token-
|
| 90 |
-
|
| 91 |
-
fontsize=11, color="#222222", pad=10
|
| 92 |
)
|
| 93 |
-
|
| 94 |
-
# Legend patches
|
| 95 |
-
red_patch = mpatches.Patch(color=
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
loc="lower right", framealpha=0.7)
|
| 99 |
-
|
| 100 |
ax.spines["top"].set_visible(False)
|
| 101 |
ax.spines["right"].set_visible(False)
|
| 102 |
ax.spines["left"].set_visible(False)
|
| 103 |
ax.tick_params(axis="y", length=0)
|
| 104 |
ax.tick_params(axis="x", colors="#555555")
|
| 105 |
ax.xaxis.label.set_color("#555555")
|
| 106 |
-
|
| 107 |
plt.tight_layout()
|
| 108 |
-
|
| 109 |
buf = io.BytesIO()
|
| 110 |
fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
|
| 111 |
facecolor="white")
|
|
@@ -118,22 +122,22 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
|
|
| 118 |
f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
|
| 119 |
f"</div>"
|
| 120 |
)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
# ── Main prediction function ──────────────────────────────────────────────────
|
| 124 |
def adr_predict(x):
|
| 125 |
text_input = str(x).lower()
|
| 126 |
encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
|
| 127 |
output = model(**encoded_input)
|
| 128 |
scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
|
| 129 |
-
|
| 130 |
# ── SHAP (bar chart) ──────────────────────────────────────────────────────
|
| 131 |
try:
|
| 132 |
shap_values = explainer([text_input])
|
| 133 |
shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
|
| 134 |
except Exception as e:
|
| 135 |
shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
|
| 136 |
-
|
| 137 |
# ── NER ───────────────────────────────────────────────────────────────────
|
| 138 |
try:
|
| 139 |
res = ner_pipe(text_input)
|
|
@@ -146,7 +150,7 @@ def adr_predict(x):
|
|
| 146 |
"Diagnostic_procedure": "#eeeeee",
|
| 147 |
"Biological_structure": "#d9d9d9",
|
| 148 |
}
|
| 149 |
-
|
| 150 |
htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
|
| 151 |
prev_end = 0
|
| 152 |
res = sorted(res, key=lambda e: e["start"])
|
|
@@ -166,15 +170,15 @@ def adr_predict(x):
|
|
| 166 |
htext += text_input[prev_end:] + "</div>"
|
| 167 |
except Exception:
|
| 168 |
htext = "<p style='color:black;'>NER processing error.</p>"
|
| 169 |
-
|
| 170 |
label_output = {
|
| 171 |
"Severe Reaction": float(scores[1]),
|
| 172 |
"Non-severe Reaction": float(scores[0]),
|
| 173 |
}
|
| 174 |
-
|
| 175 |
return label_output, shap_html, htext
|
| 176 |
-
|
| 177 |
-
|
| 178 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
| 179 |
custom_css = """
|
| 180 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
|
@@ -188,16 +192,16 @@ custom_css = """
|
|
| 188 |
}
|
| 189 |
footer { visibility: hidden; }
|
| 190 |
"""
|
| 191 |
-
|
| 192 |
with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 193 |
-
|
| 194 |
with gr.Column(elem_classes="main-header"):
|
| 195 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 196 |
gr.Markdown(
|
| 197 |
"Analyze clinical text for potential medication-related severity "
|
| 198 |
"and key medical entities."
|
| 199 |
)
|
| 200 |
-
|
| 201 |
with gr.Row():
|
| 202 |
# ── Left column: input ────────────────────────────────────────────────
|
| 203 |
with gr.Column(scale=1):
|
|
@@ -209,7 +213,7 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
|
|
| 209 |
elem_id="input-text",
|
| 210 |
)
|
| 211 |
submit_btn = gr.Button("Run Analysis", variant="primary")
|
| 212 |
-
|
| 213 |
gr.Markdown("### Examples")
|
| 214 |
gr.Examples(
|
| 215 |
examples=[
|
|
@@ -220,28 +224,28 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
|
|
| 220 |
],
|
| 221 |
inputs=[prob1],
|
| 222 |
)
|
| 223 |
-
|
| 224 |
# ── Right column: outputs ─────────────────────────────────────────────
|
| 225 |
with gr.Column(scale=1):
|
| 226 |
gr.Markdown("### Classification")
|
| 227 |
label = gr.Label(label="Severity Probability")
|
| 228 |
-
|
| 229 |
gr.Markdown("### Medical Entities")
|
| 230 |
htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
|
| 231 |
-
|
| 232 |
gr.Markdown("### Model Logic (SHAP)")
|
| 233 |
shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
|
| 234 |
-
|
| 235 |
gr.Markdown("---")
|
| 236 |
gr.Markdown(
|
| 237 |
"Disclaimer: This tool is for research purposes only and does not "
|
| 238 |
"constitute medical advice."
|
| 239 |
)
|
| 240 |
-
|
| 241 |
submit_btn.click(
|
| 242 |
fn=adr_predict,
|
| 243 |
inputs=[prob1],
|
| 244 |
outputs=[label, shap_out, htext_out],
|
| 245 |
)
|
| 246 |
-
|
| 247 |
demo.launch()
|
|
|
|
| 13 |
import sys
|
| 14 |
import csv
|
| 15 |
import os
|
| 16 |
+
|
| 17 |
HF_TOKEN = os.getenv("hf_token")
|
| 18 |
csv.field_size_limit(sys.maxsize)
|
| 19 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
|
| 21 |
# ── Load classification model ─────────────────────────────────────────────────
|
| 22 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 23 |
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
|
|
|
| 25 |
model = AutoModelForSequenceClassification.from_pretrained(
|
| 26 |
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
|
| 27 |
).to(device)
|
| 28 |
+
|
| 29 |
pred = transformers.pipeline(
|
| 30 |
"text-classification", model=model, tokenizer=tokenizer,
|
| 31 |
top_k=None, device=device
|
| 32 |
)
|
| 33 |
+
|
| 34 |
explainer = shap.Explainer(pred)
|
| 35 |
+
|
| 36 |
# ── Load NER model ────────────────────────────────────────────────────────────
|
| 37 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 38 |
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
|
|
|
| 40 |
"ner", model=ner_model, tokenizer=ner_tokenizer,
|
| 41 |
aggregation_strategy="simple"
|
| 42 |
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
# ── Custom SHAP bar-chart renderer ────────────────────────────────────────────
|
| 46 |
def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
|
| 47 |
"""
|
|
|
|
| 53 |
# .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
|
| 54 |
values = shap_values.values # (n_tokens, n_classes)
|
| 55 |
tokens = shap_values.data # list/array of token strings
|
| 56 |
+
|
| 57 |
if values.ndim == 2:
|
| 58 |
sv = values[:, class_idx] # SHAP values for "Severe Reaction"
|
| 59 |
else:
|
| 60 |
sv = values
|
| 61 |
+
|
| 62 |
# Sort by absolute magnitude and keep top-N for readability
|
| 63 |
TOP_N = 20
|
| 64 |
order = np.argsort(np.abs(sv))[::-1][:TOP_N]
|
| 65 |
sv_top = sv[order]
|
| 66 |
tok_top = np.array(tokens)[order]
|
| 67 |
+
|
| 68 |
# Re-sort so the chart reads top-to-bottom by value (positive on top)
|
| 69 |
plot_order = np.argsort(sv_top)
|
| 70 |
sv_plot = sv_top[plot_order]
|
| 71 |
tok_plot = tok_top[plot_order]
|
| 72 |
+
|
| 73 |
+
COLOR_POSITIVE = "#cc1111" # bold red — increases severe ADR probability
|
| 74 |
+
COLOR_NEGATIVE = "#1a6fcc" # strong blue — decreases severe ADR probability
|
| 75 |
+
|
| 76 |
+
colors = [COLOR_POSITIVE if v > 0 else COLOR_NEGATIVE for v in sv_plot]
|
| 77 |
+
|
| 78 |
fig_height = max(4, len(sv_plot) * 0.38)
|
| 79 |
fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
|
| 80 |
ax.set_facecolor("white")
|
| 81 |
+
|
| 82 |
y_pos = np.arange(len(sv_plot))
|
| 83 |
bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
|
| 84 |
+
|
| 85 |
# Zero line
|
| 86 |
ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
|
| 87 |
+
|
| 88 |
ax.set_yticks(y_pos)
|
| 89 |
ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
|
| 90 |
ax.set_xlabel("SHAP Value β impact on ADR prediction", fontsize=10, color="#444444")
|
| 91 |
ax.set_title(
|
| 92 |
+
"Token-Feature Importance: Words Driving Prediction",
|
| 93 |
+
fontsize=12, fontweight="bold", color="#222222", pad=12
|
|
|
|
| 94 |
)
|
| 95 |
+
|
| 96 |
+
# Legend patches — colors match the bars exactly
|
| 97 |
+
red_patch = mpatches.Patch(color=COLOR_POSITIVE,
|
| 98 |
+
label="Increases severe ADR probability")
|
| 99 |
+
blue_patch = mpatches.Patch(color=COLOR_NEGATIVE,
|
| 100 |
+
label="Decreases severe ADR probability")
|
| 101 |
+
ax.legend(handles=[red_patch, blue_patch], fontsize=9,
|
| 102 |
loc="lower right", framealpha=0.7)
|
| 103 |
+
|
| 104 |
ax.spines["top"].set_visible(False)
|
| 105 |
ax.spines["right"].set_visible(False)
|
| 106 |
ax.spines["left"].set_visible(False)
|
| 107 |
ax.tick_params(axis="y", length=0)
|
| 108 |
ax.tick_params(axis="x", colors="#555555")
|
| 109 |
ax.xaxis.label.set_color("#555555")
|
| 110 |
+
|
| 111 |
plt.tight_layout()
|
| 112 |
+
|
| 113 |
buf = io.BytesIO()
|
| 114 |
fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
|
| 115 |
facecolor="white")
|
|
|
|
| 122 |
f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
|
| 123 |
f"</div>"
|
| 124 |
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
# ── Main prediction function ──────────────────────────────────────────────────
|
| 128 |
def adr_predict(x):
|
| 129 |
text_input = str(x).lower()
|
| 130 |
encoded_input = tokenizer(text_input, return_tensors="pt").to(device)
|
| 131 |
output = model(**encoded_input)
|
| 132 |
scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
|
| 133 |
+
|
| 134 |
# ── SHAP (bar chart) ──────────────────────────────────────────────────────
|
| 135 |
try:
|
| 136 |
shap_values = explainer([text_input])
|
| 137 |
shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
|
| 138 |
except Exception as e:
|
| 139 |
shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
|
| 140 |
+
|
| 141 |
# ── NER ───────────────────────────────────────────────────────────────────
|
| 142 |
try:
|
| 143 |
res = ner_pipe(text_input)
|
|
|
|
| 150 |
"Diagnostic_procedure": "#eeeeee",
|
| 151 |
"Biological_structure": "#d9d9d9",
|
| 152 |
}
|
| 153 |
+
|
| 154 |
htext = "<div style='line-height:2.0; font-size:1.1em; color:black;'>"
|
| 155 |
prev_end = 0
|
| 156 |
res = sorted(res, key=lambda e: e["start"])
|
|
|
|
| 170 |
htext += text_input[prev_end:] + "</div>"
|
| 171 |
except Exception:
|
| 172 |
htext = "<p style='color:black;'>NER processing error.</p>"
|
| 173 |
+
|
| 174 |
label_output = {
|
| 175 |
"Severe Reaction": float(scores[1]),
|
| 176 |
"Non-severe Reaction": float(scores[0]),
|
| 177 |
}
|
| 178 |
+
|
| 179 |
return label_output, shap_html, htext
|
| 180 |
+
|
| 181 |
+
|
| 182 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
| 183 |
custom_css = """
|
| 184 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
|
|
|
| 192 |
}
|
| 193 |
footer { visibility: hidden; }
|
| 194 |
"""
|
| 195 |
+
|
| 196 |
with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 197 |
+
|
| 198 |
with gr.Column(elem_classes="main-header"):
|
| 199 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 200 |
gr.Markdown(
|
| 201 |
"Analyze clinical text for potential medication-related severity "
|
| 202 |
"and key medical entities."
|
| 203 |
)
|
| 204 |
+
|
| 205 |
with gr.Row():
|
| 206 |
# ── Left column: input ────────────────────────────────────────────────
|
| 207 |
with gr.Column(scale=1):
|
|
|
|
| 213 |
elem_id="input-text",
|
| 214 |
)
|
| 215 |
submit_btn = gr.Button("Run Analysis", variant="primary")
|
| 216 |
+
|
| 217 |
gr.Markdown("### Examples")
|
| 218 |
gr.Examples(
|
| 219 |
examples=[
|
|
|
|
| 224 |
],
|
| 225 |
inputs=[prob1],
|
| 226 |
)
|
| 227 |
+
|
| 228 |
# ── Right column: outputs ─────────────────────────────────────────────
|
| 229 |
with gr.Column(scale=1):
|
| 230 |
gr.Markdown("### Classification")
|
| 231 |
label = gr.Label(label="Severity Probability")
|
| 232 |
+
|
| 233 |
gr.Markdown("### Medical Entities")
|
| 234 |
htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
|
| 235 |
+
|
| 236 |
gr.Markdown("### Model Logic (SHAP)")
|
| 237 |
shap_out = gr.HTML(label="Feature Importance", elem_classes="output-box")
|
| 238 |
+
|
| 239 |
gr.Markdown("---")
|
| 240 |
gr.Markdown(
|
| 241 |
"Disclaimer: This tool is for research purposes only and does not "
|
| 242 |
"constitute medical advice."
|
| 243 |
)
|
| 244 |
+
|
| 245 |
submit_btn.click(
|
| 246 |
fn=adr_predict,
|
| 247 |
inputs=[prob1],
|
| 248 |
outputs=[label, shap_out, htext_out],
|
| 249 |
)
|
| 250 |
+
|
| 251 |
demo.launch()
|