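# app.py — text below was copied from the Hugging Face Hub file viewer.
# Viewer header (not part of the program): uploader "rayshunp";
# commit "Update app.py" (94591c2, verified); file size 20.1 kB.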
import gradio as gr
import shap
import numpy as np
import torch
import matplotlib
import os
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from huggingface_hub import login
# ── Model Setup ──────────────────────────────────────────────────────────────
# Log in to the Hugging Face Hub with the token held in the "uvamsba26ADR"
# environment variable (a missing variable raises KeyError at import time).
login(token=os.environ["uvamsba26ADR"])

MODEL_ID = "rayshunp/Paula_ADR2026Team5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()

# Device index handed to both pipelines: 0 when CUDA is available, else -1.
_PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1

# Hugging-Face pipeline used by SHAP
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # needed so SHAP sees both class probabilities
    device=_PIPELINE_DEVICE,
)

# Label map – adjust if your model uses different label names
LABELS = {0: "Non-Severe", 1: "Severe"}

# SHAP explainer (uses the HF pipeline directly)
explainer = shap.Explainer(clf_pipeline, tokenizer)

# --- Medical Information Extraction Model (NER) ---
# Token classification model for extracting medical entities
# (e.g. drugs, symptoms, body sites, dosages) from ADR text.
# Swap NER_MODEL_ID for any token-classification model on the Hub,
# for example:
#   "allenai/scibert_scivocab_cased"     – general biomedical
#   "d4data/biomedical-ner-all"          – multi-entity biomedical NER
#   "pruas/BENT-PubMedBERT-NER-Disease"  – disease/symptom focused
NER_MODEL_ID = "d4data/biomedical-ner-all"  # ← replace with your chosen model
ner_pipe = pipeline(
    "ner",
    model=NER_MODEL_ID,
    aggregation_strategy="simple",  # merges subword tokens → whole words
    device=_PIPELINE_DEVICE,
)
# ── Helper Functions ─────────────────────────────────────────────────
# ── Use SHAP on each word ─────────────────────────────────────────────────
def _merge_shap_to_words(tokens, values, original_text):
"""Map SHAP token values back to original whitespace-split words."""
words = original_text.split()
word_values = []
token_idx = 0
for word in words:
word_lower = word.lower().strip(".,!?")
accumulated = ""
word_vals = []
# Keep consuming tokens until we've rebuilt this word
while token_idx < len(tokens) and len(accumulated) < len(word_lower):
token = tokens[token_idx]
# Skip special tokens
if token in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>"):
token_idx += 1
continue
# Strip ## and punctuation for matching
clean = token.lstrip("#").lower().strip(".,!?")
accumulated += clean
word_vals.append(values[token_idx])
token_idx += 1
avg_val = float(np.mean(word_vals)) if word_vals else 0.0
word_values.append((word, avg_val))
words_out, vals_out = zip(*word_values) if word_values else ([], [])
return list(words_out), list(vals_out)
# ── Merge words when using NER ─────────────────────────────────────────────────
def _merge_ner_entities(entities, original_text):
merged = []
for ent in entities:
word = original_text[ent['start']:ent['end']]
# Extend to full word boundary if token ends mid-word
end = ent['end']
while end < len(original_text) and original_text[end].isalpha():
end += 1
word = original_text[ent['start']:end]
if merged and ent['start'] <= merged[-1]['end'] + 2:
merged[-1]['word'] = original_text[merged[-1]['start']:end]
merged[-1]['end'] = end
merged[-1]['score'] = (merged[-1]['score'] + float(ent['score'])) / 2
if float(ent['score']) > merged[-1]['score']:
merged[-1]['entity_group'] = ent['entity_group']
else:
merged.append({
'word': word,
'entity_group': ent['entity_group'],
'score': float(ent['score']),
'start': ent['start'],
'end': end # use extended end
})
return merged
# Highlight colours shared by the entity categories below.
_NER_ORANGE = '#e67e22'
_NER_BLUE = '#3498db'
_NER_PURPLE = '#9b59b6'
_NER_RED = '#e74c3c'
_NER_TEAL = '#1abc9c'

# entity_group name → (display label, CSS colour) used when rendering entities.
ENTITY_CATEGORY_MAP = {
    'Sign_symptom': ('Symptom', _NER_ORANGE),
    'Therapeutic_procedure': ('Medication/Treatment', _NER_BLUE),
    'Biological_structure': ('Body Site', _NER_PURPLE),
    'Chemical': ('Medication Name', _NER_BLUE),
    'Disease_disorder': ('Reaction', _NER_RED),
    'Age': ('Demographics', _NER_TEAL),
    'Sex': ('Demographics', _NER_TEAL),
}
def _build_ner_html(text, entities):
    """Render ``text`` as an HTML card with recognised entities highlighted.

    The raw entities are first merged with ``_merge_ner_entities``.  Each
    merged entity becomes a bordered chip showing the covered text, its
    category label (from ``ENTITY_CATEGORY_MAP``, falling back to the raw
    group name in grey) and its confidence as a hover tooltip.  Text between
    entities is emitted as plain spans.  All text taken from the input is
    HTML-escaped before it is interpolated.

    Args:
        text: The original input text.
        entities: Entity dicts with 'start', 'end', 'score' and
            'entity_group' keys, as accepted by ``_merge_ner_entities``.

    Returns:
        An HTML string.  When no entities are found, the escaped text is
        returned under a static colour legend instead.
    """
    from html import escape  # local import keeps this block self-contained

    merged = _merge_ner_entities(entities, text)
    if not merged:
        # Legend colours mirror ENTITY_CATEGORY_MAP (orange = symptom,
        # red = reaction).
        return f"""
        <div style="padding:16px; border-radius:10px; border:1px solid #888;
        font-family: Georgia, serif; font-size:1.05rem; line-height:2.2;">
        <div style="font-size:0.78rem; opacity:0.6; margin-bottom:10px;">
        🟠 Symptom &nbsp;|&nbsp; 🔵 Medication &nbsp;|&nbsp;
        🟣 Body Site &nbsp;|&nbsp; 🔴 Reaction &nbsp;|&nbsp; 🟢 Demographics
        </div>
        {escape(text)}
        </div>
        """

    def plain_span(chunk):
        return f'<span style="display:inline;">{escape(chunk)}</span>'

    # Single left-to-right pass over the entities, emitting the plain text
    # that lies between them.
    parts = []
    cursor = 0
    for ent in sorted(merged, key=lambda e: e['start']):
        start, end = ent['start'], ent['end']
        # Skip spans that overlap text already emitted, and empty spans
        # (which would otherwise never advance the cursor).
        if start < cursor or end <= start:
            continue
        if start > cursor:
            parts.append(plain_span(text[cursor:start]))
        label, color = ENTITY_CATEGORY_MAP.get(
            ent['entity_group'],
            (ent['entity_group'], '#888'),
        )
        confidence = float(ent['score'])
        parts.append(
            f'<span title="Confidence: {confidence:.1%}" '
            f'style="border:1px solid {color}; border-radius:6px; '
            f'padding:2px 8px; margin:1px; display:inline-block; '
            f'color:{color}; white-space:nowrap;">'
            f'<strong>{escape(text[start:end])}</strong>'
            f'<sup style="font-size:0.6rem; margin-left:3px; opacity:0.8;">{escape(str(label))}</sup>'
            f'</span>'
        )
        cursor = end
    if cursor < len(text):
        parts.append(plain_span(text[cursor:]))

    # Build legend: one chip per distinct label, in order of first appearance.
    seen_labels = set()
    legend_parts = []
    for ent in merged:
        label, color = ENTITY_CATEGORY_MAP.get(
            ent['entity_group'],
            (ent['entity_group'], '#888'),
        )
        if label in seen_labels:
            continue
        seen_labels.add(label)
        legend_parts.append(
            f'<span style="border:1px solid {color}; color:{color}; '
            f'border-radius:4px; padding:1px 8px; margin:2px; '
            f'display:inline-block; font-size:0.78rem;">{escape(str(label))}</span>'
        )

    return f"""
    <div style="padding:16px; border-radius:10px; border:1px solid #888;
    font-family: Georgia, serif; font-size:1.05rem; line-height:2.5;">
    <div style="font-size:0.78rem; opacity:0.6; margin-bottom:10px;">
    Detected medical entities (hover for confidence)
    </div>
    <div style="margin-bottom:12px;">
    {''.join(parts)}
    </div>
    <div style="border-top:1px solid #888; padding-top:8px; margin-top:8px;">
    {''.join(legend_parts)}
    </div>
    </div>
    """
# ── Core prediction function ─────────────────────────────────────────────────
def predict(text: str):
    """Classify a reaction description and build the three HTML panels.

    Steps: run the severity classifier, render the verdict card, compute SHAP
    values for the text and merge them to word level, render the word
    highlights, then run the NER pipeline once and render its entities.

    Args:
        text: Free-text reaction description from the UI.  ``None``, empty and
            whitespace-only input are all treated as "nothing entered".

    Returns:
        ``(verdict_html, word_html, ner_html)``.  For missing input the tuple
        is ``(warning message, None, None)`` and no model is invoked.
    """
    if not text or not text.strip():
        return "⚠️ Please enter a reaction description.", None, None

    # ── 1. Severity prediction ──
    raw = clf_pipeline(text)
    # The pipeline output may be nested one level; unwrap to a flat list of
    # {"label", "score"} dicts.
    results = raw[0] if isinstance(raw[0], list) else raw

    def get_score(idx):
        # Accept either the generic "LABEL_<idx>" name or the readable label.
        label_key = f"LABEL_{idx}"
        for r in results:
            if r["label"] == label_key or r["label"] == LABELS[idx]:
                return r["score"]
        return 0.5  # label not present in the output

    score_non_severe = get_score(0)
    score_severe = get_score(1)
    predicted_idx = int(score_severe > score_non_severe)
    is_severe = predicted_idx == 1
    confidence = score_severe if is_severe else score_non_severe

    accent = '#e74c3c' if is_severe else '#2ecc71'
    background = 'rgba(231,76,60,0.15)' if is_severe else 'rgba(46,204,113,0.15)'
    headline = '🚨 SEVERE' if is_severe else '✅ NON-SEVERE'
    verdict_html = f"""
    <div style="
    font-family: 'Courier New', monospace;
    border: 2px solid {accent};
    border-radius: 10px;
    padding: 18px 24px;
    background: {background};
    color: {accent};
    text-align: center;
    ">
    <div style="font-size:2rem; font-weight:900; letter-spacing:2px;">
    {headline}
    </div>
    <div style="font-size:1rem; margin-top:8px; opacity:0.85;">
    Confidence: <strong>{confidence:.1%}</strong>
    </div>
    <div style="margin-top:12px; display:flex; gap:8px; justify-content:center; flex-wrap:wrap;">
    <span style="border: 1px solid currentColor; border-radius:6px; padding:4px 12px; opacity:0.8;">
    Non-Severe: {score_non_severe:.1%}
    </span>
    <span style="border: 1px solid currentColor; border-radius:6px; padding:4px 12px; opacity:0.8;">
    Severe: {score_severe:.1%}
    </span>
    </div>
    </div>
    """

    # ── 2. SHAP values ──
    shap_values = explainer([text])

    # ── 3. Map SHAP back to original words ──
    tokens, vals_severe = _merge_shap_to_words(
        shap_values.data[0],
        shap_values.values[0, :, 1],  # index 1 on the last axis (treated as the Severe class)
        text,  # pass original text
    )
    vals_severe = np.array(vals_severe)

    # ── 4. Inline word-highlight HTML ──
    word_html = _build_highlight_html(tokens, vals_severe)

    # ── 5. NER Pipeline ── (run once; its output is not logged)
    ner_results = ner_pipe(text)
    ner_html = _build_ner_html(text, ner_results)

    return verdict_html, word_html, ner_html
# ── HTML word highlights ──────────────────────────────────────────────────────
def _build_highlight_html(tokens, values):
max_abs = max(abs(values).max(), 1e-8)
parts = []
for token, val in zip(tokens, values):
# Skip special tokens
if token in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>"):
continue
intensity = abs(val) / max_abs # 0-1
alpha = 0.15 + 0.75 * intensity # 0.15-0.90
if val > 0: # pushes toward SEVERE → red
bg = f"rgba(231, 76, 60, {alpha:.2f})"
fg = "#fff" if intensity > 0.4 else "#111"
else: # pushes toward NON-SEVERE → green
bg = f"rgba(46, 204, 113, {alpha:.2f})"
fg = "#fff" if intensity > 0.4 else "#111"
# Strip leading ## from subword tokens for readability
display = token.lstrip("#")
parts.append(
f'<span style="background:{bg}; color:{fg}; '
f'padding:3px 5px; border-radius:4px; margin:2px; '
f'font-weight:{"700" if intensity > 0.5 else "400"}; '
f'display:inline-block;" '
f'title="SHAP: {val:+.4f}">{display}</span>'
)
html = f"""
<div style="
font-family: Georgia, serif;
font-size: 1.05rem;
line-height: 2.2;
padding: 16px;
background: transparent;
border-radius: 10px;
border: 1px solid #888;
">
<div style="font-size:0.78rem; color:#aaa; margin-bottom:10px;">
🔴 Red = pushes toward <strong>Severe</strong> &nbsp;|&nbsp;
🟢 Green = pushes toward <strong>Non-Severe</strong> &nbsp;|&nbsp;
Darker = higher importance
</div>
{''.join(parts)}
</div>
"""
return html
# ── SHAP bar plot ─────────────────────────────────────────────────────────────
def _build_shap_plot(tokens, values, text):
# Filter special tokens
pairs = [
(t.lstrip("#"), v)
for t, v in zip(tokens, values)
if t not in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>")
]
if not pairs:
return None
labels_plot, vals = zip(*pairs)
# Sort by absolute value descending, take top 20
order = np.argsort(np.abs(vals))[::-1][:20]
labels_plot = [labels_plot[i] for i in order]
vals = [vals[i] for i in order]
# Re-sort for display: positive first, then negative
combined = sorted(zip(vals, labels_plot), key=lambda x: x[0])
vals, labels_plot = zip(*combined)
colors = ["#e74c3c" if v > 0 else "#2ecc71" for v in vals]
fig, ax = plt.subplots(figsize=(8, max(4, len(vals) * 0.38)))
fig.patch.set_facecolor("none") # transparent
ax.set_facecolor("none") # transparent
ax.set_xlabel("SHAP value (impact on Severe class)", fontsize=10)
ax.set_title(...)
ax.tick_params(colors="gray")
for spine in ax.spines.values():
spine.set_edgecolor("gray")
bars = ax.barh(labels_plot, vals, color=colors, edgecolor="none", height=0.65)
ax.axvline(0, color="#555", linewidth=1)
ax.set_xlabel("SHAP value (impact on Severe class)", color="#ccc", fontsize=10)
ax.set_title(
f"Token SHAP Breakdown\n\"{text[:60]}{'…' if len(text)>60 else ''}\"",
color="#444", fontsize=11, pad=12
)
ax.tick_params(colors="#444")
ax.yaxis.label.set_color("#444")
for spine in ax.spines.values():
spine.set_edgecolor("#bbb")
# Value labels on bars
for bar, val in zip(bars, vals):
ax.text(
val + (0.002 if val >= 0 else -0.002),
bar.get_y() + bar.get_height() / 2,
f"{val:+.3f}",
va="center",
ha="left" if val >= 0 else "right",
color="#eee",
fontsize=8,
)
plt.tight_layout()
return fig
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Sample reaction reports shown as one-click example buttons.
_EXAMPLE_TEXTS = (
    "I took ibuprofen for my back pain this morning. About an hour later, I noticed mild stomach discomfort and felt slightly nauseous. The symptoms went away after I ate something.",
    "I was prescribed amoxicillin for a sinus infection last week. I developed a mild rash on my arms after the second dose. My doctor said it was a minor allergic reaction and switched me to a different antibiotic.",
    "After receiving the contrast dye injection for my CT scan, I immediately felt my throat tightening and had difficulty breathing. I broke out in hives across my chest and face. The medical team administered epinephrine and I was kept under observation for several hours.",
    "I am a 67 year old male who started taking metformin for type 2 diabetes last month. I have been experiencing persistent nausea and occasional vomiting since starting the medication. My blood sugar levels have improved but the gastrointestinal side effects are making it difficult to continue.",
    "My 8 year old daughter was given penicillin for strep throat. Within 30 minutes she complained of dizziness and her lips began to swell. We rushed her to the emergency room where she was treated for anaphylaxis.",
)
# Each entry is a single-element list holding one report text.
EXAMPLES = [[report] for report in _EXAMPLE_TEXTS]
# Custom stylesheet passed to gr.Blocks(css=...). It styles the elements that
# carry the ids "title", "subtitle", "input-card", "verdict-card", "shap-card"
# and "ner-card", the primary button class, and the "example-btn" class given
# to the example buttons.
CSS = """
#title {
text-align: center;
font-family: 'Courier New', monospace;
color: #7eb8f7;
letter-spacing: 3px;
text-transform: uppercase;
margin-bottom: 4px;
}
#subtitle {
text-align: center;
color: #888;
font-size: 0.9rem;
margin-bottom: 20px;
font-family: Georgia, serif;
}
.gr-button-primary {
background: #3a5fc8 !important;
border: none !important;
font-weight: 700 !important;
letter-spacing: 1px !important;
}
#input-card, #verdict-card, #shap-card, #ner-card {
border: 1.5px solid #3a5fc8 !important;
border-radius: 12px !important;
box-shadow: 0 4px 12px rgba(58,95,200,0.15) !important;
padding: 16px !important;
}
.example-btn {
border: 1.5px solid #3a5fc8 !important;
border-radius: 8px !important;
background: white !important;
font-size: 0.82rem !important;
font-weight: 600 !important;
white-space: normal !important;
height: auto !important;
text-align: center !important;
line-height: 1.4 !important;
padding: 12px !important;
color: #222 !important;
}
.example-btn:hover {
background: #eef2ff !important;
cursor: pointer !important;
}
"""
# Page layout: title, input + verdict row, example buttons, SHAP + NER row.
with gr.Blocks(css=CSS, theme=gr.themes.Base()) as demo:
    gr.HTML('<h1 id="title">⚕ ADR Severity Analyzer</h1>')
    gr.HTML('<p id="subtitle">This demo uses a DeBERTA transformer to classify severity reactions. Please note this demo is for academic purposes and should NOT be used for medical advice.</p>')

    # ── Input + Verdict ──
    with gr.Row():
        with gr.Column(scale=1, elem_id="input-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>📝 Please Enter Your Drug Reaction</h3>")
            text_input = gr.Textbox(
                label="",
                placeholder="e.g. I took ibuprofen for my back pain this morning.",
                lines=5,
            )
            analyze_btn = gr.Button("🔍 Analyze Reaction", variant="primary")
        with gr.Column(scale=1, elem_id="verdict-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>⚕ Severity Classification</h3>")
            verdict_out = gr.HTML()

    # ── Examples ──
    gr.HTML("<h3 style='text-align:center;'>💡 Try An Example Reaction</h3>")
    with gr.Row():
        # One button per example, captioned with the first 80 characters of
        # the report and paired with the full text it should insert.
        ex_btns = [
            (
                gr.Button(
                    f'"{sample[0][:80]}..."',
                    size="sm",
                    elem_classes="example-btn",
                ),
                sample[0],
            )
            for sample in EXAMPLES
        ]
    for example_button, example_text in ex_btns:
        # The default argument binds each button to its own text.
        example_button.click(
            fn=lambda t=example_text: t,
            inputs=[],
            outputs=text_input,
        )

    # ── SHAP + NER ──
    with gr.Row():
        with gr.Column(scale=1, elem_id="shap-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>🎨 Keywords Contributing To Severity Classification </h3>")
            highlight_out = gr.HTML()
        with gr.Column(scale=1, elem_id="ner-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>🏥 Medical Word Recognition</h3>")
            ner_out = gr.HTML()

    # Clicking "Analyze" fills the verdict, highlight and NER panels.
    analyze_btn.click(
        fn=predict,
        inputs=text_input,
        outputs=[verdict_out, highlight_out, ner_out],
    )

if __name__ == "__main__":
    demo.launch()