# Team3_Mod4 / app.py
# Copied from the Hugging Face file viewer: last change "Update app.py" by willwim
# (commit ba72a5a, verified), file size 23.7 kB.
import base64
import csv
import html
import io
import os
import re
import string
import sys

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import shap
import torch
import transformers
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
# Access token for the Hugging Face Hub, read from the "hf_token" environment
# variable. os.getenv returns None when the variable is not set.
HF_TOKEN = os.getenv("hf_token")
# Raise the csv module's maximum field size to sys.maxsize.
# NOTE(review): no csv parsing is visible in this part of the file — confirm it is still needed.
csv.field_size_limit(sys.maxsize)
# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# ── Load classification model ──────────────────────────────────────────────────
# Tokenizer and sequence-classification model come from the same Hub repository
# and are fetched with HF_TOKEN; the model is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
)
model = AutoModelForSequenceClassification.from_pretrained(
"willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN
).to(device)
# Text-classification pipeline built on the objects above.
# NOTE(review): top_k=None presumably makes the pipeline return a score for
# every label (which SHAP needs) — confirm against the transformers docs.
pred = transformers.pipeline(
"text-classification", model=model, tokenizer=tokenizer,
top_k=None, device=device
)
# SHAP explainer wrapping the pipeline; adr_predict calls it once per request.
explainer = shap.Explainer(pred)
# ── Load NER model ─────────────────────────────────────────────────────────────
# Token-classification model used to highlight medical entities in the input.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# NOTE(review): unlike `pred`, this pipeline is created without a `device`
# argument and the model is not moved with .to(device) — confirm that is intended.
ner_pipe = pipeline(
"ner", model=ner_model, tokenizer=ner_tokenizer,
aggregation_strategy="simple"
)
# ── Severity rating renderer ──────────────────────────────────────────────────
def build_severity_html(severe_prob: float) -> str:
    """
    Render the Severe Reaction probability as an HTML severity card.

    The card contains a Low / Medium / High rating, the probability as a
    percentage rounded to one decimal place, a colour-coded progress bar of
    the same width, and a plain-language description.

    Thresholds (tunable):
        severe_prob < 0.35           → Low
        0.35 <= severe_prob < 0.65   → Medium
        anything else                → High

    NOTE(review): this file defines ``build_severity_html`` a second time
    further down. That later definition rebinds the name, so callers that run
    after module import get the later version, not this one.
    """
    # Each tier: (exclusive upper bound, rating, accent colour, background, description).
    # The accent colour is used for the border, the rating text and the bar fill.
    tiers = (
        (
            0.35, "Low", "#2eaa5a", "#eafaf1",
            "The clinical text shows limited indicators of a serious adverse drug reaction. "
            "Routine monitoring is advised.",
        ),
        (
            0.65, "Medium", "#e8a020", "#fffbea",
            "The clinical text contains some features associated with a significant adverse "
            "drug reaction. Further clinical assessment is recommended.",
        ),
    )

    # Start from the top tier and downgrade if the value falls under a bound.
    level, accent, background = "High", "#cc1111", "#fff0f0"
    blurb = (
        "The clinical text shows strong indicators of a severe adverse drug reaction. "
        "Prompt clinical review is strongly recommended."
    )
    for upper_bound, tier_level, tier_accent, tier_background, tier_blurb in tiers:
        if severe_prob < upper_bound:
            level, accent, background, blurb = (
                tier_level, tier_accent, tier_background, tier_blurb
            )
            break

    # The same rounded figure drives both the printed percentage and the bar width.
    percent = round(severe_prob * 100, 1)

    pieces = [
        # Outer card
        f"<div style='background:{background}; border:1.5px solid {accent}; ",
        "border-radius:10px; padding:18px 20px; font-family:system-ui, sans-serif;'>",
        # Rating badge + percentage on the same row
        "<div style='display:flex; align-items:center; justify-content:space-between; ",
        "margin-bottom:12px;'>",
        f"<span style='font-size:1.5em; font-weight:800; color:{accent};'>{level}</span>",
        "<span style='font-size:1.0em; font-weight:600; color:#555;'>",
        f"{percent}% severe probability</span>",
        "</div>",
        # Progress bar track and fill
        "<div style='background:#e0e0e0; border-radius:6px; height:14px; ",
        "overflow:hidden; margin-bottom:14px;'>",
        f"<div style='background:{accent}; width:{percent}%; ",
        "height:100%; border-radius:6px; transition:width 0.4s ease;'></div>",
        "</div>",
        # Description
        "<p style='margin:0; font-size:0.9em; color:#444; line-height:1.6;'>",
        blurb,
        "</p>",
        "</div>",
    ]
    return "".join(pieces)
# ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
def render_shap_bar_chart(shap_values, class_idx=1, top_n=20):
    """Render the most influential tokens of one SHAP explanation as an HTML bar chart.

    Args:
        shap_values: Explanation of a single input. Only its ``values`` and
            ``data`` attributes are read. ``data`` holds the tokens; ``values``
            holds one SHAP value per token (1-D) or one row per token with one
            column per class (2-D).
        class_idx: Column of ``values`` to plot when ``values`` is 2-D.
        top_n: Maximum number of tokens shown, ranked by absolute SHAP value.

    Returns:
        An HTML string: a ``<div>`` wrapping the chart as a base64-encoded PNG,
        or a short ``<p>`` message when no token survives filtering.
    """
    stopwords = frozenset({
        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "by", "from", "is", "was", "are", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "that", "this", "it", "its",
        "he", "she", "they", "we", "i", "my", "his", "her", "their", "our",
        "not", "no", "so", "if", "as", "up", "out", "about", "after",
        "before", "then", "than", "also", "into", "over", "such", "further",
        "while", "which", "who", "whom", "what", "when", "where", "how",
    })
    # Built once instead of once per character inside is_meaningful.
    ignorable_chars = frozenset(string.punctuation + " \t\n")

    def is_meaningful(tok):
        """Return True for tokens worth plotting (not blank, punctuation,
        single characters, stopwords, '##' word pieces or pure digits)."""
        t = tok.strip().lower()
        if len(t) <= 1:
            return False
        if all(c in ignorable_chars for c in t):
            return False
        if t in stopwords or t.startswith("##"):
            return False
        return re.fullmatch(r"\d+", t) is None

    values = shap_values.values
    sv = values[:, class_idx] if values.ndim == 2 else values
    tok_arr = np.array(shap_values.data)

    # dtype=bool keeps the mask usable as an index even when there are no
    # tokens at all (an empty list would otherwise become a float array).
    mask = np.array([is_meaningful(t) for t in tok_arr], dtype=bool)
    if not mask.any():
        return (
            "<p style='color:#555;'>No informative tokens to display "
            "for this input.</p>"
        )
    sv_f = sv[mask]
    tok_f = tok_arr[mask]

    # Deduplicate: keep the occurrence with the highest |SHAP| per unique token.
    best_index = {}
    for i, tok in enumerate(tok_f):
        key = tok.strip().lower()
        if key not in best_index or abs(sv_f[i]) > abs(sv_f[best_index[key]]):
            best_index[key] = i
    keep = sorted(best_index.values())
    sv_f = sv_f[keep]
    tok_f = tok_f[keep]

    # Select the top_n tokens by magnitude, then order them by signed value so
    # the most negative bar is drawn at the bottom of the chart.
    order = np.argsort(np.abs(sv_f))[::-1][:top_n]
    sv_top = sv_f[order]
    tok_top = tok_f[order]
    plot_order = np.argsort(sv_top)
    sv_plot = sv_top[plot_order]
    tok_plot = tok_top[plot_order]

    color_positive = "#cc1111"
    color_negative = "#1a6fcc"
    colors = [color_positive if v > 0 else color_negative for v in sv_plot]

    fig_height = max(4, len(sv_plot) * 0.38)
    fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
    try:
        ax.set_facecolor("white")
        y_pos = np.arange(len(sv_plot))
        ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
        ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
        ax.set_xlabel(
            "SHAP Value — impact on ADR prediction", fontsize=10, color="#444444"
        )
        ax.set_title(
            "Token-Feature Importance: Words Driving Prediction",
            fontsize=12, fontweight="bold", color="#222222", pad=12
        )
        red_patch = mpatches.Patch(
            color=color_positive, label="Increases severe ADR probability"
        )
        blue_patch = mpatches.Patch(
            color=color_negative, label="Decreases severe ADR probability"
        )
        ax.legend(
            handles=[red_patch, blue_patch], fontsize=9,
            loc="lower right", framealpha=0.7
        )
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.tick_params(axis="y", length=0)
        ax.tick_params(axis="x", colors="#555555")
        # Lay out this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor="white")
    finally:
        # pyplot keeps a reference to every open figure; always release it,
        # even when drawing or saving fails.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        "<div style='background:white; padding:12px; border-radius:8px;'>"
        "<img src='data:image/png;base64," + b64 + "' "
        "style='width:100%; max-width:760px; display:block; margin:auto;' /></div>"
    )
# ── Severity rating widget ─────────────────────────────────────────────────────
def build_severity_html(severe_prob):
    """Build the HTML "ADR Severity Rating" widget for a severe-reaction probability.

    Rating bands:
        severe_prob >= 0.70  -> HIGH
        severe_prob >= 0.40  -> MEDIUM
        otherwise            -> LOW

    The widget shows an icon, the rating, the probability as a whole-number
    percentage and a progress bar of that width. The percentage is truncated
    (not rounded up) so the number shown never crosses a band boundary that the
    rating itself has not crossed, and it is clamped to 0-100 for display.

    NOTE(review): an earlier function with the same name is defined above in
    this file; this later definition is the one bound to the name afterwards.

    Args:
        severe_prob: Probability of the "severe" class, expected in [0, 1].

    Returns:
        The widget as an HTML string.
    """
    if severe_prob >= 0.70:
        rating = "HIGH"
        rating_color = "#cc1111"
        rating_bg = "#fff0f0"
        rating_border = "#e07070"
        rating_desc = "Strong indicators of a serious adverse drug reaction are present."
        rating_icon = "\u26a0\ufe0f"
    elif severe_prob >= 0.40:
        rating = "MEDIUM"
        rating_color = "#b07000"
        rating_bg = "#fffbe6"
        rating_border = "#d4a800"
        rating_desc = "Some indicators of an adverse drug reaction detected. Clinical review recommended."
        rating_icon = "\U0001f536"
    else:
        rating = "LOW"
        rating_color = "#2a7a2a"
        rating_bg = "#f0fff0"
        rating_border = "#5aaa5a"
        rating_desc = "Few or no indicators of a severe adverse drug reaction detected."
        rating_icon = "\u2705"

    # 0.29 * 100 evaluates to 28.999999999999996, which a bare int() would show
    # as 28%. Rounding to 6 decimals first strips that binary floating-point
    # noise while still truncating genuine fractions (69.9 -> 69).
    whole_pct = int(round(severe_prob * 100, 6))
    # Keep the printed figure and the CSS width inside 0-100%.
    bar_pct = max(0, min(100, whole_pct))

    return (
        "<div style='background:" + rating_bg + "; border:2px solid " + rating_border + "; "
        "border-radius:10px; padding:16px 20px; font-family:system-ui, sans-serif;'>"
        # Header row: icon, rating caption + value, probability on the right
        "<div style='display:flex; align-items:center; gap:12px; margin-bottom:10px;'>"
        "<span style='font-size:1.8em;'>" + rating_icon + "</span>"
        "<div>"
        "<div style='font-size:0.75em; font-weight:700; letter-spacing:0.1em; "
        "text-transform:uppercase; color:#666;'>ADR Severity Rating</div>"
        "<div style='font-size:1.6em; font-weight:900; color:" + rating_color + "; "
        "letter-spacing:0.04em;'>" + rating + "</div>"
        "</div>"
        "<div style='margin-left:auto; text-align:right;'>"
        "<div style='font-size:0.72em; color:#888; font-weight:600;'>Severe reaction probability</div>"
        "<div style='font-size:1.4em; font-weight:800; color:" + rating_color + ";'>"
        + str(bar_pct) + "%</div>"
        "</div>"
        "</div>"
        # Progress bar track and fill
        "<div style='background:#e0e0e0; border-radius:999px; height:10px; margin-bottom:10px;'>"
        "<div style='background:" + rating_color + "; width:" + str(bar_pct) + "%; "
        "height:10px; border-radius:999px;'></div>"
        "</div>"
        # Description
        "<div style='font-size:0.88em; color:#555; margin-top:4px;'>" + rating_desc + "</div>"
        "</div>"
    )
# ── Main prediction function ───────────────────────────────────────────────────
def _render_entities_html(text, entities):
    """Return legend + annotated-text HTML for the NER entities found in ``text``.

    ``entities`` is the list returned by the NER pipeline; each item is read
    for its ``start``, ``end`` and ``entity_group`` keys. All text taken from
    ``text`` is HTML-escaped before it is inserted into the markup.
    """
    entity_config = {
        "Severity": {"bg": "#ffe0de", "border": "#e07070", "label": "Severity"},
        "Sign_symptom": {"bg": "#d4f5d4", "border": "#5aaa5a", "label": "Symptom"},
        "Medication": {"bg": "#d0e8ff", "border": "#4a90d9", "label": "Medication"},
        "Age": {"bg": "#fff3cc", "border": "#d4a800", "label": "Age"},
        "Sex": {"bg": "#ffe8fb", "border": "#c070b0", "label": "Sex"},
        "Diagnostic_procedure": {"bg": "#e8e8e8", "border": "#888888", "label": "Diagnostic"},
        "Biological_structure": {"bg": "#ddeeff", "border": "#6699cc", "label": "Body Part"},
    }
    default_cfg = {"bg": "#f0f0f0", "border": "#aaaaaa", "label": "Other"}

    ordered = sorted(entities, key=lambda e: e["start"])

    # One legend chip per displayed label, in order of first appearance. Keying
    # on the label (not the raw group) stops every unmapped group from adding
    # its own identical "Other" chip.
    legend_cfgs = {}
    for entity in ordered:
        cfg = entity_config.get(entity["entity_group"], default_cfg)
        legend_cfgs.setdefault(cfg["label"], cfg)

    parts = [
        "<div style='display:flex; flex-wrap:wrap; gap:10px; "
        "margin-bottom:16px; padding-bottom:12px; border-bottom:1px solid #e0e0e0;'>"
    ]
    for lbl, cfg in legend_cfgs.items():
        parts.append(
            "<span style='display:inline-flex; align-items:center; gap:6px; "
            "font-size:0.8em; font-weight:600; color:#444; font-family:system-ui, sans-serif;'>"
            "<span style='display:inline-block; width:13px; height:13px; border-radius:3px; "
            "background:" + cfg["bg"] + "; border:2px solid " + cfg["border"] + ";'></span>"
            + lbl + "</span>"
        )
    parts.append("</div>")

    parts.append(
        "<div style='line-height:3.4; font-size:1.05em; color:#111; "
        "font-family:Georgia, serif; letter-spacing:0.01em;'>"
    )
    prev_end = 0
    for entity in ordered:
        start, end = entity["start"], entity["end"]
        if start < prev_end:
            # Overlaps an entity that was already rendered; emitting it would
            # repeat part of the text.
            continue
        cfg = entity_config.get(entity["entity_group"], default_cfg)
        bg = cfg["bg"]
        border = cfg["border"]
        lbl = cfg["label"]
        # Plain text between the previous entity and this one.
        parts.append(
            "<span style='color:#111;'>" + html.escape(text[prev_end:start]) + "</span>"
        )
        # Entity chip: small label above the highlighted word.
        parts.append(
            "<span style='display:inline-block; position:relative; "
            "vertical-align:middle; margin:0 2px; text-align:center;'>"
            "<span style='display:block; font-size:0.6em; font-weight:800; "
            "letter-spacing:0.08em; text-transform:uppercase; color:" + border + "; "
            "font-family:system-ui, sans-serif; line-height:1.1; margin-bottom:2px;'>"
            + lbl + "</span>"
            "<span style='background:" + bg + "; border:1.5px solid " + border + "; "
            "color:#111; padding:3px 8px; border-radius:6px; "
            "font-weight:600; white-space:nowrap;'>" + html.escape(text[start:end]) + "</span>"
            "</span>"
        )
        prev_end = end
    parts.append(
        "<span style='color:#111;'>" + html.escape(text[prev_end:]) + "</span></div>"
    )
    return "".join(parts)


def adr_predict(x):
    """Classify clinical text and build the HTML outputs shown in the UI.

    The input is lower-cased, scored by the classification model, explained
    with SHAP and annotated with the NER pipeline. SHAP and NER failures are
    reported inside their own panel instead of aborting the whole request.

    Args:
        x: The clinical text. ``None`` and blank input are treated as "no input".

    Returns:
        A 4-tuple ``(label_output, severity_html, shap_html, htext)``:
        ``label_output`` maps "Severe Reaction" / "Non-severe Reaction" to
        probabilities (``None`` when there is no input); the other three items
        are HTML strings.
    """
    # str(None) would otherwise be classified as the literal text "none".
    text_input = ("" if x is None else str(x)).lower()
    if not text_input.strip():
        hint = "<p style='color:#777;'>Enter clinical text to run the analysis.</p>"
        return None, hint, hint, hint

    # truncation=True keeps over-long text within the tokenizer's maximum
    # length instead of failing inside the model.
    encoded_input = tokenizer(
        text_input, return_tensors="pt", truncation=True
    ).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()

    # SHAP. Broad catch on purpose: this is the UI boundary and the error is
    # shown to the user in place of the chart.
    try:
        shap_values = explainer([text_input])
        shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
    except Exception as e:
        shap_html = (
            "<p style='color:red;'>SHAP explanation error: "
            + html.escape(str(e)) + "</p>"
        )

    # NER. Same best-effort handling as the SHAP panel.
    try:
        htext = _render_entities_html(text_input, ner_pipe(text_input))
    except Exception as ex:
        htext = (
            "<p style='color:#c00;'>NER processing error: "
            + html.escape(str(ex)) + "</p>"
        )

    # Index 1 of the model output is reported as the severe class, index 0 as
    # the non-severe class.
    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    severity_html = build_severity_html(float(scores[1]))
    return label_output, severity_html, shap_html, htext
# ── UI ─────────────────────────────────────────────────────────────────────────
# Stylesheet for the whole app; it is passed to demo.launch(css=custom_css).
# The selectors target the classes assigned through elem_classes in the layout
# ("card", "header-card", "run-btn") and the "card-label" class used in the
# inline Markdown captions, and force a light palette with dark text.
# NOTE(review): ".svelte-1f354aw" looks like a build-generated class name and
# may stop matching after a Gradio upgrade — confirm.
custom_css = """
/* ── Global: light grey page background, dark text everywhere ── */
body,
.gradio-container,
.main,
.contain,
.gap {
background-color: #f4f6f9 !important;
font-family: 'Inter', system-ui, sans-serif !important;
color: #111111 !important;
}
/* Force ALL text dark by default */
*, *::before, *::after {
color: inherit;
}
/* ── Uniform white card ── */
.card {
background: #ffffff !important;
border: 1px solid #e2e6ea !important;
border-radius: 12px !important;
padding: 20px !important;
box-shadow: 0 1px 4px rgba(0,0,0,0.08) !important;
color: #111111 !important;
margin-bottom: 12px !important;
}
/* Every text node inside a card */
.card *,
.card p,
.card span,
.card div,
.card label,
.card h1, .card h2, .card h3,
.card td, .card th,
.card button:not(.primary) {
color: #111111 !important;
}
/* Header card */
.header-card {
text-align: center !important;
margin-bottom: 18px !important;
}
/* ── Card section label (uppercase caption) ── */
.card-label {
font-size: 0.72em !important;
font-weight: 700 !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
color: #777777 !important;
margin: 0 0 10px 0 !important;
padding: 0 !important;
display: block !important;
}
/* Strip double borders / backgrounds from inner Gradio wrappers */
.card > .form,
.card > .block,
.card .wrap,
.card [data-testid="label"],
.card .label-container,
.card .label-wrap,
.card .prose,
.card .md {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
/* ── Gradio Label (probability bar) ── */
.card .label-wrap,
.card .label-wrap *,
.card [data-testid="label"] * {
background: transparent !important;
color: #111111 !important;
}
/* ── Textbox ── */
.card textarea,
.card input[type="text"],
.card input {
background: #fafafa !important;
color: #111111 !important;
border: 1px solid #dde1e7 !important;
border-radius: 8px !important;
}
.card textarea::placeholder,
.card input::placeholder {
color: #999999 !important;
}
/* ── Textbox label text ── */
.card .svelte-1f354aw,
.card label > span {
color: #333333 !important;
font-weight: 600 !important;
}
/* ── Examples table ── */
.card .examples,
.card .examples table,
.card .examples td,
.card .examples th,
.card .examples button {
background: #ffffff !important;
color: #111111 !important;
border-color: #e2e6ea !important;
}
/* ── Run button ── */
.run-btn {
width: 100% !important;
margin-top: 10px !important;
}
/* ── Processing spinner area ── */
.card .generating,
.card .eta-bar,
.card .progress-bar {
background: #f0f0f0 !important;
color: #555555 !important;
}
footer { visibility: hidden; }
"""
# Page layout: header card, input + examples beside classification + severity,
# then entities beside the SHAP chart, then the disclaimer. Every panel is a
# gr.Column tagged with the "card" class so custom_css can style it.
with gr.Blocks(title="ADR Detector") as demo:
    # ── Header ────────────────────────────────────────────────────────────────
    with gr.Column(elem_classes="card header-card"):
        gr.Markdown(
            "<h1 style='margin:0 0 6px 0; color:#1a1a2e; font-size:1.8em; text-align:center;'>"
            "Adverse Drug Reaction (ADR) Detector</h1>"
            "<p style='margin:0; color:#555; font-size:0.97em; text-align:center;'>"
            "Analyze clinical text for potential medication-related severity "
            "and key medical entities.</p>"
        )
    # ── Row 1: Input (left) | Classification + Severity (right) ──────────────
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown(
                    "<p class='card-label'>Clinical Input</p>"
                )
                prob1 = gr.Textbox(
                    label="Clinical Observations",
                    lines=4,
                    placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
                    elem_id="input-text",
                )
                submit_btn = gr.Button("Run Analysis", variant="primary", elem_classes="run-btn")
            with gr.Column(elem_classes="card"):
                gr.Markdown("<p class='card-label'>Examples</p>")
                # Clicking an example fills the input textbox.
                gr.Examples(
                    examples=[
                        ["A 42 year-old male developed a severe migraine and elevated blood pressure "
                         "shortly after taking Aspirin. He was admitted for observation."],
                        ["A 28 year-old female reported mild nausea and minor discomfort in the upper "
                         "abdomen after taking Acetaminophen 500mg. Symptoms resolved within two hours."],
                        ["A 67 year-old male with a history of renal impairment experienced acute kidney "
                         "injury and oliguria following treatment with Ibuprofen for three days."],
                        ["A 54 year-old female noted slight dizziness and dry mouth after her first dose "
                         "of Metformin. No further symptoms were reported at the follow-up visit."],
                    ],
                    inputs=[prob1],
                )
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown("<p class='card-label'>Classification</p>")
                label = gr.Label(label="Severity Probability")
            with gr.Column(elem_classes="card"):
                gr.Markdown("<p class='card-label'>Severity Rating</p>")
                severity_out = gr.HTML(label="Severity Rating")
    # ── Row 2: Medical Entities (left) | SHAP (right) ────────────────────────
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown("<p class='card-label'>Medical Entities</p>")
                htext_out = gr.HTML(label="NER Mapping")
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown("<p class='card-label'>Model Logic (SHAP)</p>")
                shap_out = gr.HTML(label="Feature Importance")
    # ── Disclaimer footer ─────────────────────────────────────────────────────
    with gr.Column(elem_classes="card"):
        gr.Markdown(
            "<div style='"
            "border-left: 4px solid #e8a020; "
            "padding: 14px 18px; "
            "background: #fffbea; "
            "border-radius: 6px; "
            "font-family: system-ui, sans-serif;'>"
            "<p style='margin:0 0 6px 0; font-size:0.85em; font-weight:800; "
            "letter-spacing:0.08em; text-transform:uppercase; color:#b07800;'>"
            "⚠ Academic Project — Not for Medical Use</p>"
            "<p style='margin:0 0 8px 0; font-size:0.92em; color:#444; line-height:1.65;'>"
            "This tool was developed strictly as an <strong>academic research project</strong> "
            "and is intended to demonstrate the application of natural language processing "
            "and explainable AI techniques to clinical text. "
            "It <strong>must not</strong> be used to make real medical decisions, diagnose conditions, "
            "guide treatment, or replace the advice of a qualified healthcare professional.</p>"
            "<p style='margin:0; font-size:0.85em; color:#777; line-height:1.5;'>"
            "Predictions are generated by a machine learning model and may be inaccurate, incomplete, "
            "or misleading. Always consult a licensed medical professional for any health-related concerns. "
            "The authors accept no liability for any use of this tool beyond its intended academic purpose.</p>"
            "</div>"
        )
    # The outputs are listed in the same order as adr_predict's return tuple.
    submit_btn.click(
        fn=adr_predict,
        inputs=[prob1],
        outputs=[label, severity_out, shap_out, htext_out],
    )
# NOTE(review): theme and css are passed to launch() here, which requires a
# Gradio release whose launch() accepts them — confirm against the installed version.
demo.launch(theme=gr.themes.Default(), css=custom_css)