import gradio as gr
import shap
import numpy as np
import torch
import transformers
from transformers import (
    pipeline,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForTokenClassification,
)
import matplotlib

# The non-interactive Agg backend must be selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import io
import base64
import html
import logging
import string
import re
import sys
import csv
import os

logger = logging.getLogger(__name__)

# Hugging Face access token for the classifier repo. The lower-case "hf_token" variable is
# read first; the conventional "HF_TOKEN" name is accepted as a fallback.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")

# Lift the csv module's per-field size cap. sys.maxsize does not fit in a C long on platforms
# where that type is 32 bits (e.g. Windows), where it raises OverflowError.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

CLASSIFIER_REPO = "willwim/adr_SJM_Notebook-Copy_for_T3"
NER_REPO = "d4data/biomedical-ner-all"

# Output indices of the classifier.
# NOTE(review): assumes index 1 = severe and index 0 = non-severe — TODO confirm against
# model.config.id2label.
SEVERE_IDX = 1
NON_SEVERE_IDX = 0

# ── Load classification model ──────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_REPO, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    CLASSIFIER_REPO, token=HF_TOKEN
).to(device)
model.eval()

pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # return a score for every class so SHAP can explain each of them
    truncation=True,  # clip inputs longer than the model's maximum sequence length
    device=device,
)
explainer = shap.Explainer(pred)

# ── Load NER model ─────────────────────────────────────────────────────────────
ner_tokenizer = AutoTokenizer.from_pretrained(NER_REPO)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_REPO)
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",  # merge word pieces into whole-entity spans
    device=device,
)


# ── Small HTML helpers ─────────────────────────────────────────────────────────
def _error_html(prefix, exc):
    """Return a red, HTML-escaped line reading "<prefix>: <exception text>"."""
    return (
        '<div style="color:#cc1111 !important;padding:8px 0;">'
        + html.escape(prefix + ": " + str(exc))
        + "</div>"
    )


def _notice_html(message):
    """Return a muted, HTML-escaped informational message."""
    return (
        '<div style="color:#777777 !important;padding:8px 0;font-style:italic;">'
        + html.escape(message)
        + "</div>"
    )


def _card_label(text):
    """Return the upper-case caption (styled by the .card-label CSS rule) shown atop a card."""
    return '<p class="card-label">' + text + "</p>"


# ── Severity rating widget ─────────────────────────────────────────────────────
def build_severity_html(severe_prob, *, high_threshold=0.70, medium_threshold=0.40):
    """Render the Severe Reaction probability as a LOW / MEDIUM / HIGH rating card.

    Args:
        severe_prob: Probability (0-1) assigned to the "Severe Reaction" class.
        high_threshold: Probabilities at or above this value are rated HIGH.
        medium_threshold: Probabilities at or above this value (and below
            ``high_threshold``) are rated MEDIUM; anything lower is rated LOW.

    Returns:
        An HTML snippet (icon, rating, probability bar, description) for ``gr.HTML``.
    """
    if severe_prob >= high_threshold:
        rating = "HIGH"
        rating_color = "#cc1111"
        rating_bg = "#fff0f0"
        rating_border = "#e07070"
        rating_desc = "Strong indicators of a serious adverse drug reaction are present."
        rating_icon = "\u26a0\ufe0f"
    elif severe_prob >= medium_threshold:
        rating = "MEDIUM"
        rating_color = "#b07000"
        rating_bg = "#fffbe6"
        rating_border = "#d4a800"
        rating_desc = "Some indicators of an adverse drug reaction detected. Clinical review recommended."
        rating_icon = "\U0001f536"
    else:
        rating = "LOW"
        rating_color = "#2a7a2a"
        rating_bg = "#f0fff0"
        rating_border = "#5aaa5a"
        rating_desc = "Few or no indicators of a severe adverse drug reaction detected."
        rating_icon = "\u2705"

    # Clamp for the bar width; truncate (not round) so that e.g. 0.6996 is shown as 69%
    # next to a MEDIUM rating rather than as 70%.
    bar_pct = int(min(max(severe_prob, 0.0), 1.0) * 100)

    # Colours carry !important because the page CSS forces `.card *` text to #111111.
    return (
        f'<div style="background:{rating_bg};border:1px solid {rating_border};'
        'border-radius:10px;padding:16px;">'
        # Icon + rating headline
        '<div style="display:flex;align-items:center;gap:12px;margin-bottom:14px;">'
        f'<span style="font-size:1.8em;line-height:1;">{rating_icon}</span>'
        "<div>"
        '<div style="font-size:0.72em;font-weight:700;letter-spacing:0.1em;'
        'text-transform:uppercase;color:#777777 !important;">ADR Severity Rating</div>'
        f'<div style="font-size:1.6em;font-weight:800;color:{rating_color} !important;">'
        f"{rating}</div>"
        "</div>"
        "</div>"
        # Probability caption + value
        '<div style="display:flex;justify-content:space-between;font-size:0.85em;'
        'margin-bottom:4px;">'
        '<span style="color:#444444 !important;">Severe reaction probability</span>'
        f'<span style="font-weight:700;color:{rating_color} !important;">{bar_pct}%</span>'
        "</div>"
        # Progress bar track and fill
        '<div style="background:#e6e6e6;border-radius:6px;height:10px;overflow:hidden;">'
        f'<div style="background:{rating_color};width:{bar_pct}%;height:100%;"></div>'
        "</div>"
        # Plain-language description
        '<p style="margin:12px 0 0 0;font-size:0.9em;color:#333333 !important;">'
        f"{rating_desc}</p>"
        "</div>"
    )


# ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
_SHAP_STOPWORDS = frozenset({
    "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
    "of", "with", "by", "from", "is", "was", "are", "were", "be", "been",
    "being", "have", "has", "had", "do", "does", "did", "will", "would",
    "could", "should", "may", "might", "that", "this", "it", "its", "he",
    "she", "they", "we", "i", "my", "his", "her", "their", "our", "not",
    "no", "so", "if", "as", "up", "out", "about", "after", "before", "then",
    "than", "also", "into", "over", "such", "further", "while", "which",
    "who", "whom", "what", "when", "where", "how",
})
_NON_WORD_CHARS = frozenset(string.punctuation + " \t\n")
_DIGITS_RE = re.compile(r"\d+")

COLOR_POSITIVE = "#cc1111"  # pushes towards "Severe Reaction"
COLOR_NEGATIVE = "#1a6fcc"  # pushes away from "Severe Reaction"


def _is_meaningful_token(tok):
    """Return True if a SHAP token is worth plotting.

    Rejects empty/one-character tokens, pure punctuation/whitespace, stopwords,
    "##" word-piece continuations and pure digit strings.
    """
    t = str(tok).strip().lower()
    if len(t) <= 1:  # also covers the empty string
        return False
    if all(c in _NON_WORD_CHARS for c in t):
        return False
    return (
        t not in _SHAP_STOPWORDS
        and not t.startswith("##")
        and not _DIGITS_RE.fullmatch(t)
    )


def render_shap_bar_chart(shap_values, class_idx=1, top_n=20):
    """Render a horizontal bar chart of the most influential tokens as an inline PNG.

    Args:
        shap_values: A single-example SHAP explanation exposing ``.values`` (per-token,
            optionally per-class) and ``.data`` (the tokens).
        class_idx: Column of ``.values`` to plot when it is 2-D (one column per class).
        top_n: Maximum number of tokens (largest |SHAP|) to show.

    Returns:
        HTML containing a base64-encoded PNG, or a notice when no token survives filtering.
    """
    values = np.asarray(shap_values.values)
    tokens = shap_values.data
    sv = values[:, class_idx] if values.ndim == 2 else values

    tok_arr = np.array(tokens)
    # dtype=bool keeps an empty mask a valid boolean index (an empty list becomes float64).
    mask = np.array([_is_meaningful_token(t) for t in tok_arr], dtype=bool)
    sv_f = sv[mask]
    tok_f = tok_arr[mask]
    if sv_f.size == 0:
        return _notice_html("No informative tokens to display for this input.")

    # Deduplicate: keep the occurrence with the highest |SHAP| per unique token.
    best = {}
    for i, tok in enumerate(tok_f):
        key = str(tok).strip().lower()
        if key not in best or abs(sv_f[i]) > abs(sv_f[best[key]]):
            best[key] = i
    keep = sorted(best.values())
    sv_f = sv_f[keep]
    tok_f = tok_f[keep]

    # Take the top_n by magnitude, then sort by signed value so bars read bottom-to-top.
    order = np.argsort(np.abs(sv_f))[::-1][:top_n]
    sv_top = sv_f[order]
    tok_top = tok_f[order]
    plot_order = np.argsort(sv_top)
    sv_plot = sv_top[plot_order]
    tok_plot = tok_top[plot_order]

    colors = [COLOR_POSITIVE if v > 0 else COLOR_NEGATIVE for v in sv_plot]
    fig_height = max(4, len(sv_plot) * 0.38)

    buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(8, fig_height), facecolor="white")
    try:
        ax.set_facecolor("white")
        y_pos = np.arange(len(sv_plot))
        ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
        ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(tok_plot.tolist(), fontsize=10, color="#222222")
        ax.set_xlabel("SHAP Value — impact on ADR prediction", fontsize=10, color="#444444")
        ax.set_title(
            "Token-Feature Importance: Words Driving Prediction",
            fontsize=12, fontweight="bold", color="#222222", pad=12,
        )
        red_patch = mpatches.Patch(color=COLOR_POSITIVE, label="Increases severe ADR probability")
        blue_patch = mpatches.Patch(color=COLOR_NEGATIVE, label="Decreases severe ADR probability")
        ax.legend(handles=[red_patch, blue_patch], fontsize=9, loc="lower right", framealpha=0.7)
        for side in ("top", "right", "left"):
            ax.spines[side].set_visible(False)
        ax.tick_params(axis="y", length=0)
        ax.tick_params(axis="x", colors="#555555")
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor="white")
    finally:
        # Always release the pyplot-managed figure, even if drawing or saving fails.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        '<div style="width:100%;text-align:center;">'
        f'<img src="data:image/png;base64,{b64}" alt="SHAP token importance chart" '
        'style="max-width:100%;height:auto;border-radius:8px;"/>'
        "</div>"
    )


# ── NER rendering ──────────────────────────────────────────────────────────────
ENTITY_STYLES = {
    "Severity": {"bg": "#ffe0de", "border": "#e07070", "label": "Severity"},
    "Sign_symptom": {"bg": "#d4f5d4", "border": "#5aaa5a", "label": "Symptom"},
    "Medication": {"bg": "#d0e8ff", "border": "#4a90d9", "label": "Medication"},
    "Age": {"bg": "#fff3cc", "border": "#d4a800", "label": "Age"},
    "Sex": {"bg": "#ffe8fb", "border": "#c070b0", "label": "Sex"},
    "Diagnostic_procedure": {"bg": "#e8e8e8", "border": "#888888", "label": "Diagnostic"},
    "Biological_structure": {"bg": "#ddeeff", "border": "#6699cc", "label": "Body Part"},
}
DEFAULT_ENTITY_STYLE = {"bg": "#f0f0f0", "border": "#aaaaaa", "label": "Other"}


def _build_ner_html(text):
    """Run biomedical NER on *text* and return a colour legend plus highlighted text.

    Every piece of *text* is HTML-escaped before being embedded, so user input cannot
    inject markup into the page.
    """
    entities = sorted(ner_pipe(text), key=lambda ent: ent["start"])
    parts = []

    # Legend: one chip per entity group, in order of first appearance in the text.
    seen_groups = list(dict.fromkeys(ent["entity_group"] for ent in entities))
    if seen_groups:
        parts.append('<div style="display:flex;flex-wrap:wrap;gap:8px;margin-bottom:14px;">')
        for grp in seen_groups:
            cfg = ENTITY_STYLES.get(grp, DEFAULT_ENTITY_STYLE)
            parts.append(
                f'<span style="background:{cfg["bg"]};border:1px solid {cfg["border"]};'
                'border-radius:4px;padding:2px 8px;font-size:0.8em;font-weight:600;">'
                f'{cfg["label"]}</span>'
            )
        parts.append("</div>")

    # Body: plain text between entities, entities wrapped in a labelled, coloured chip.
    parts.append('<div style="line-height:2.4;font-size:0.95em;">')
    prev_end = 0
    for ent in entities:
        # Clip overlapping spans so no character is emitted twice.
        start = max(ent["start"], prev_end)
        end = ent["end"]
        if end <= start:
            continue
        cfg = ENTITY_STYLES.get(ent["entity_group"], DEFAULT_ENTITY_STYLE)
        parts.append(f"<span>{html.escape(text[prev_end:start])}</span>")
        parts.append(
            f'<span style="background:{cfg["bg"]};border:1px solid {cfg["border"]};'
            'border-radius:4px;padding:2px 4px;margin:0 2px;white-space:nowrap;">'
            '<span style="font-size:0.65em;font-weight:700;text-transform:uppercase;'
            f'margin-right:4px;color:#555555 !important;">{cfg["label"]}</span>'
            f"<span>{html.escape(text[start:end])}</span>"
            "</span>"
        )
        prev_end = end
    parts.append(f"<span>{html.escape(text[prev_end:])}</span>")
    parts.append("</div>")
    return "".join(parts)


# ── Main prediction function ───────────────────────────────────────────────────
def adr_predict(x):
    """Run classification, SHAP explanation and NER on one clinical note.

    Args:
        x: Text from the input textbox (converted with ``str`` and lower-cased).

    Returns:
        A 4-tuple ``(label_output, severity_html, shap_html, ner_html)`` matching the
        ``outputs=[label, severity_out, shap_out, htext_out]`` wiring of the Run button.
        ``label_output`` maps "Severe Reaction" / "Non-severe Reaction" to probabilities,
        or is None when the input is empty.
    """
    text_input = "" if x is None else str(x).lower()
    if not text_input.strip():
        notice = _notice_html("Please enter some clinical text to analyse.")
        return None, notice, notice, notice

    # Classification: no gradients are needed; truncate notes longer than the model limit.
    with torch.inference_mode():
        encoded_input = tokenizer(text_input, return_tensors="pt", truncation=True).to(device)
        logits = model(**encoded_input).logits
        scores = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    # SHAP — failures are reported in the panel instead of failing the whole request.
    try:
        shap_values = explainer([text_input])
        shap_html = render_shap_bar_chart(shap_values[0], class_idx=SEVERE_IDX)
    except Exception as exc:  # UI boundary
        logger.exception("SHAP explanation failed")
        shap_html = _error_html("SHAP explanation error", exc)

    # NER — same best-effort handling as SHAP.
    try:
        ner_html = _build_ner_html(text_input)
    except Exception as exc:  # UI boundary
        logger.exception("NER processing failed")
        ner_html = _error_html("NER processing error", exc)

    label_output = {
        "Severe Reaction": float(scores[SEVERE_IDX]),
        "Non-severe Reaction": float(scores[NON_SEVERE_IDX]),
    }
    severity_html = build_severity_html(float(scores[SEVERE_IDX]))
    return label_output, severity_html, shap_html, ner_html


# ── UI ─────────────────────────────────────────────────────────────────────────
custom_css = """
/* ── Global: light grey page background, dark text everywhere ── */
body, .gradio-container, .main, .contain, .gap {
    background-color: #f4f6f9 !important;
    font-family: 'Inter', system-ui, sans-serif !important;
    color: #111111 !important;
}

/* Force ALL text dark by default */
*, *::before, *::after {
    color: inherit;
}

/* ── Uniform white card ── */
.card {
    background: #ffffff !important;
    border: 1px solid #e2e6ea !important;
    border-radius: 12px !important;
    padding: 20px !important;
    box-shadow: 0 1px 4px rgba(0,0,0,0.08) !important;
    color: #111111 !important;
    margin-bottom: 12px !important;
}

/* Every text node inside a card */
.card *, .card p, .card span, .card div, .card label,
.card h1, .card h2, .card h3, .card td, .card th,
.card button:not(.primary) {
    color: #111111 !important;
}

/* Header card */
.header-card {
    text-align: center !important;
    margin-bottom: 18px !important;
}

/* ── Card section label (uppercase caption) ── */
.card-label {
    font-size: 0.72em !important;
    font-weight: 700 !important;
    letter-spacing: 0.1em !important;
    text-transform: uppercase !important;
    color: #777777 !important;
    margin: 0 0 10px 0 !important;
    padding: 0 !important;
    display: block !important;
}

/* Strip double borders / backgrounds from inner Gradio wrappers */
.card > .form, .card > .block, .card .wrap,
.card [data-testid="label"], .card .label-container,
.card .label-wrap, .card .prose, .card .md {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
    padding: 0 !important;
}

/* ── Gradio Label (probability bar) ── */
.card .label-wrap, .card .label-wrap *, .card [data-testid="label"] * {
    background: transparent !important;
    color: #111111 !important;
}

/* ── Textbox ── */
.card textarea, .card input[type="text"], .card input {
    background: #fafafa !important;
    color: #111111 !important;
    border: 1px solid #dde1e7 !important;
    border-radius: 8px !important;
}
.card textarea::placeholder, .card input::placeholder {
    color: #999999 !important;
}

/* ── Textbox label text ── */
.card .svelte-1f354aw, .card label > span {
    color: #333333 !important;
    font-weight: 600 !important;
}

/* ── Examples table ── */
.card .examples, .card .examples table, .card .examples td,
.card .examples th, .card .examples button {
    background: #ffffff !important;
    color: #111111 !important;
    border-color: #e2e6ea !important;
}

/* ── Run button ── */
.run-btn {
    width: 100% !important;
    margin-top: 10px !important;
}

/* ── Processing spinner area ── */
.card .generating, .card .eta-bar, .card .progress-bar {
    background: #f0f0f0 !important;
    color: #555555 !important;
}

footer { visibility: hidden; }
"""

HEADER_HTML = (
    '<div style="text-align:center;">'
    '<h1 style="margin:0 0 6px 0;font-size:1.8em;font-weight:800;">'
    "Adverse Drug Reaction (ADR) Detector</h1>"
    '<p style="margin:0;color:#555555 !important;">'
    "Analyze clinical text for potential medication-related severity "
    "and key medical entities.</p>"
    "</div>"
)

DISCLAIMER_HTML = (
    '<div style="border-left:4px solid #e8a020;padding:4px 14px;">'
    '<p style="font-weight:700;margin:0 0 8px 0;">'
    "⚠ Academic Project — Not for Medical Use</p>"
    '<p style="margin:0 0 8px 0;font-size:0.9em;">'
    "This tool was developed strictly as an academic research project "
    "and is intended to demonstrate the application of natural language processing "
    "and explainable AI techniques to clinical text. "
    "It must not be used to make real medical decisions, diagnose conditions, "
    "guide treatment, or replace the advice of a qualified healthcare professional.</p>"
    '<p style="margin:0;font-size:0.9em;">'
    "Predictions are generated by a machine learning model and may be inaccurate, incomplete, "
    "or misleading. Always consult a licensed medical professional for any health-related concerns. "
    "The authors accept no liability for any use of this tool beyond its intended academic purpose.</p>"
    "</div>"
)

EXAMPLES = [
    ["A 42 year-old male developed a severe migraine and elevated blood pressure "
     "shortly after taking Aspirin. He was admitted for observation."],
    ["A 28 year-old female reported mild nausea and minor discomfort in the upper "
     "abdomen after taking Acetaminophen 500mg. Symptoms resolved within two hours."],
    ["A 67 year-old male with a history of renal impairment experienced acute kidney "
     "injury and oliguria following treatment with Ibuprofen for three days."],
    ["A 54 year-old female noted slight dizziness and dry mouth after her first dose "
     "of Metformin. No further symptoms were reported at the follow-up visit."],
]

with gr.Blocks(title="ADR Detector") as demo:
    # ── Header ────────────────────────────────────────────────────────────────
    with gr.Column(elem_classes="card header-card"):
        gr.Markdown(HEADER_HTML)

    # ── Row 1: Input (left) | Classification + Severity (right) ──────────────
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Clinical Input"))
                prob1 = gr.Textbox(
                    label="Clinical Observations",
                    lines=4,
                    placeholder="Example: Patient experienced acute kidney injury after taking Ibuprofen...",
                    elem_id="input-text",
                )
                submit_btn = gr.Button("Run Analysis", variant="primary", elem_classes="run-btn")
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Examples"))
                gr.Examples(examples=EXAMPLES, inputs=[prob1])

        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Classification"))
                label = gr.Label(label="Severity Probability")
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Severity Rating"))
                severity_out = gr.HTML(label="Severity Rating")

    # ── Row 2: Medical Entities (left) | SHAP (right) ────────────────────────
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Medical Entities"))
                htext_out = gr.HTML(label="NER Mapping")
        with gr.Column(scale=1):
            with gr.Column(elem_classes="card"):
                gr.Markdown(_card_label("Model Logic (SHAP)"))
                shap_out = gr.HTML(label="Feature Importance")

    # ── Disclaimer footer ─────────────────────────────────────────────────────
    with gr.Column(elem_classes="card"):
        gr.Markdown(DISCLAIMER_HTML)

    submit_btn.click(
        fn=adr_predict,
        inputs=[prob1],
        outputs=[label, severity_out, shap_out, htext_out],
    )

# Launch only when run as a script, so importing this module (e.g. by tooling or
# Gradio's reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch(theme=gr.themes.Default(), css=custom_css)