Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import shap | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| import os | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| from huggingface_hub import login | |
# ── Model Setup ──────────────────────────────────────────────────────────────
# Hub identifiers and label names used by the rest of the app.
MODEL_ID = "rayshunp/Paula_ADR2026Team5"

# Token-classification model used to pull medical entities (e.g. drugs,
# symptoms, body sites, dosages) out of ADR text.  Swap NER_MODEL_ID for any
# token-classification model on the Hub, for example:
#   "allenai/scibert_scivocab_cased"      – general biomedical
#   "d4data/biomedical-ner-all"           – multi-entity biomedical NER
#   "pruas/BENT-PubMedBERT-NER-Disease"   – disease/symptom focused
NER_MODEL_ID = "d4data/biomedical-ner-all"  # ← replace with your chosen model

# Label map – adjust if your model uses different label names
LABELS = {0: "Non-Severe", 1: "Severe"}

# Authenticate against the Hub first; the token is read from the
# "uvamsba26ADR" environment variable (a KeyError is raised if it is unset).
login(token=os.environ["uvamsba26ADR"])

# Both pipelines share the same placement rule: first GPU if present, else CPU.
_PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1

# --- Severity classifier ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()

# Hugging-Face pipeline used both for prediction and by SHAP.
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # needed so SHAP sees both class probabilities
    device=_PIPELINE_DEVICE,
)

# SHAP explainer (uses the HF pipeline directly)
explainer = shap.Explainer(clf_pipeline, tokenizer)

# --- Medical Information Extraction Model (NER) ---
ner_pipe = pipeline(
    "ner",
    model=NER_MODEL_ID,
    aggregation_strategy="simple",  # merges subword tokens → whole words
    device=_PIPELINE_DEVICE,
)
| # ── Helper Functions ───────────────────────────────────────────────── | |
| # ── Use SHAP on each word ───────────────────────────────────────────────── | |
| def _merge_shap_to_words(tokens, values, original_text): | |
| """Map SHAP token values back to original whitespace-split words.""" | |
| words = original_text.split() | |
| word_values = [] | |
| token_idx = 0 | |
| for word in words: | |
| word_lower = word.lower().strip(".,!?") | |
| accumulated = "" | |
| word_vals = [] | |
| # Keep consuming tokens until we've rebuilt this word | |
| while token_idx < len(tokens) and len(accumulated) < len(word_lower): | |
| token = tokens[token_idx] | |
| # Skip special tokens | |
| if token in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>"): | |
| token_idx += 1 | |
| continue | |
| # Strip ## and punctuation for matching | |
| clean = token.lstrip("#").lower().strip(".,!?") | |
| accumulated += clean | |
| word_vals.append(values[token_idx]) | |
| token_idx += 1 | |
| avg_val = float(np.mean(word_vals)) if word_vals else 0.0 | |
| word_values.append((word, avg_val)) | |
| words_out, vals_out = zip(*word_values) if word_values else ([], []) | |
| return list(words_out), list(vals_out) | |
| # ── Merge words when using NER ───────────────────────────────────────────────── | |
| def _merge_ner_entities(entities, original_text): | |
| merged = [] | |
| for ent in entities: | |
| word = original_text[ent['start']:ent['end']] | |
| # Extend to full word boundary if token ends mid-word | |
| end = ent['end'] | |
| while end < len(original_text) and original_text[end].isalpha(): | |
| end += 1 | |
| word = original_text[ent['start']:end] | |
| if merged and ent['start'] <= merged[-1]['end'] + 2: | |
| merged[-1]['word'] = original_text[merged[-1]['start']:end] | |
| merged[-1]['end'] = end | |
| merged[-1]['score'] = (merged[-1]['score'] + float(ent['score'])) / 2 | |
| if float(ent['score']) > merged[-1]['score']: | |
| merged[-1]['entity_group'] = ent['entity_group'] | |
| else: | |
| merged.append({ | |
| 'word': word, | |
| 'entity_group': ent['entity_group'], | |
| 'score': float(ent['score']), | |
| 'start': ent['start'], | |
| 'end': end # use extended end | |
| }) | |
| return merged | |
# Maps an NER ``entity_group`` name to ``(display label, hex colour)``.
# _build_ner_html looks entities up here and falls back to
# ``(entity_group, '#888')`` for any group that is not listed.
# Note that several groups intentionally share a label and/or colour.
ENTITY_CATEGORY_MAP = {
    'Sign_symptom': ('Symptom', '#e67e22'),
    'Therapeutic_procedure': ('Medication/Treatment', '#3498db'),
    'Biological_structure': ('Body Site', '#9b59b6'),
    'Chemical': ('Medication Name', '#3498db'),
    'Disease_disorder': ('Reaction', '#e74c3c'),
    'Age': ('Demographics', '#1abc9c'),
    'Sex': ('Demographics', '#1abc9c'),
}
def _build_ner_html(text, entities):
    """Render ``text`` as HTML with detected medical entities highlighted.

    Args:
        text: The user-entered text (treated as untrusted; it is HTML-escaped
            before being placed in the output).
        entities: Raw entity dicts, passed through ``_merge_ner_entities``.

    Returns:
        An HTML string.  When no entities remain after merging, the escaped
        text is returned inside a card with a static colour legend.
        Otherwise each entity is wrapped in a coloured, labelled chip (with
        its confidence in the hover tooltip), the text between entities is
        emitted as plain spans, and a legend of the labels that occurred is
        appended.
    """
    # Local import keeps this block self-contained; the text comes straight
    # from a user-facing textbox and must not be interpreted as markup.
    from html import escape

    merged = _merge_ner_entities(entities, text)

    if not merged:
        # Legend colours follow ENTITY_CATEGORY_MAP (symptom = orange,
        # reaction = red).
        return f"""
<div style="padding:16px; border-radius:10px; border:1px solid #888;
font-family: Georgia, serif; font-size:1.05rem; line-height:2.2;">
<div style="font-size:0.78rem; opacity:0.6; margin-bottom:10px;">
🟠 Symptom | 🔵 Medication |
🟣 Body Site | 🔴 Reaction | 🟢 Demographics
</div>
{escape(text)}
</div>
"""

    # Walk the entities left to right, emitting the plain text between them.
    parts = []
    cursor = 0
    for ent in sorted(merged, key=lambda e: e['start']):
        start, end = ent['start'], ent['end']
        if start < cursor:
            # Overlaps a span that was already rendered – skip it.
            continue
        if start > cursor:
            parts.append(
                f'<span style="display:inline;">{escape(text[cursor:start])}</span>'
            )
        label, color = ENTITY_CATEGORY_MAP.get(
            ent['entity_group'],
            (ent['entity_group'], '#888')
        )
        confidence = f"{ent['score']:.0%}"
        parts.append(
            f'<span style="border:1px solid {color}; border-radius:6px; '
            f'padding:2px 8px; margin:1px; display:inline-block; '
            f'color:{color}; white-space:nowrap;" '
            f'title="Confidence: {confidence}">'
            f'<strong>{escape(text[start:end])}</strong>'
            f'<sup style="font-size:0.6rem; margin-left:3px; opacity:0.8;">{escape(str(label))}</sup>'
            f'</span>'
        )
        # max() guards against a zero-length span stalling the walk.
        cursor = max(cursor, end)
    if cursor < len(text):
        parts.append(
            f'<span style="display:inline;">{escape(text[cursor:])}</span>'
        )

    # Build legend: one chip per distinct label, in order of first appearance.
    seen_labels = set()
    legend_parts = []
    for ent in merged:
        label, color = ENTITY_CATEGORY_MAP.get(
            ent['entity_group'], (ent['entity_group'], '#888')
        )
        if label in seen_labels:
            continue
        seen_labels.add(label)
        legend_parts.append(
            f'<span style="border:1px solid {color}; color:{color}; '
            f'border-radius:4px; padding:1px 8px; margin:2px; '
            f'display:inline-block; font-size:0.78rem;">{escape(str(label))}</span>'
        )

    return f"""
<div style="padding:16px; border-radius:10px; border:1px solid #888;
font-family: Georgia, serif; font-size:1.05rem; line-height:2.5;">
<div style="font-size:0.78rem; opacity:0.6; margin-bottom:10px;">
Detected medical entities (hover for confidence)
</div>
<div style="margin-bottom:12px;">
{''.join(parts)}
</div>
<div style="border-top:1px solid #888; padding-top:8px; margin-top:8px;">
{''.join(legend_parts)}
</div>
</div>
"""
| # ── Core prediction function ───────────────────────────────────────────────── | |
def predict(text: str):
    """Classify an ADR description and build the three HTML panels.

    Args:
        text: Free-text reaction description from the UI.  ``None``, empty
            and whitespace-only input are all treated as "nothing entered".

    Returns:
        ``(verdict_html, word_html, ner_html)``.  For blank input the first
        element is a warning message and the other two are ``None``.

    Steps:
        1. Run the severity classifier and read the per-class scores.
        2. Compute SHAP values for the text.
        3. Collapse token-level SHAP values to whole words.
        4. Render the SHAP word highlights.
        5. Run the NER pipeline (once) and render the entity view.
    """
    if not text or not text.strip():
        return "⚠️ Please enter a reaction description.", None, None

    # ── 1. Severity prediction ──
    raw = clf_pipeline(text)
    # The pipeline may return either a flat list or a list-of-lists.
    results = raw[0] if isinstance(raw[0], list) else raw

    def get_score(idx):
        """Return the score for class ``idx``, matched by label name."""
        label_key = f"LABEL_{idx}"
        for r in results:  # loops through ALL results searching by label name
            if r["label"] == label_key or r["label"] == LABELS[idx]:
                return r["score"]
        # Neither naming scheme matched; fall back to a neutral score.
        return 0.5

    score_non_severe = get_score(0)
    score_severe = get_score(1)
    predicted_idx = int(score_severe > score_non_severe)
    confidence = score_severe if predicted_idx == 1 else score_non_severe

    verdict_html = f"""
<div style="
font-family: 'Courier New', monospace;
border: 2px solid {'#e74c3c' if predicted_idx == 1 else '#2ecc71'};
border-radius: 10px;
padding: 18px 24px;
background: {'rgba(231,76,60,0.15)' if predicted_idx == 1 else 'rgba(46,204,113,0.15)'};
color: {'#e74c3c' if predicted_idx == 1 else '#2ecc71'};
text-align: center;
">
<div style="font-size:2rem; font-weight:900; letter-spacing:2px;">
{'🚨 SEVERE' if predicted_idx == 1 else '✅ NON-SEVERE'}
</div>
<div style="font-size:1rem; margin-top:8px; opacity:0.85;">
Confidence: <strong>{confidence:.1%}</strong>
</div>
<div style="margin-top:12px; display:flex; gap:8px; justify-content:center; flex-wrap:wrap;">
<span style="border: 1px solid currentColor; border-radius:6px; padding:4px 12px; opacity:0.8;">
Non-Severe: {score_non_severe:.1%}
</span>
<span style="border: 1px solid currentColor; border-radius:6px; padding:4px 12px; opacity:0.8;">
Severe: {score_severe:.1%}
</span>
</div>
</div>
"""

    # ── 2. SHAP values ──
    shap_values = explainer([text])

    # ── 3. Map SHAP back to original words ──
    # NOTE(review): output column 1 is assumed to be the "Severe" class,
    # matching LABELS – confirm against the model's label ordering.
    tokens, vals_severe = _merge_shap_to_words(
        shap_values.data[0],
        shap_values.values[0, :, 1],
        text,  # pass original text
    )
    vals_severe = np.array(vals_severe)

    # ── 4. Inline word-highlight HTML ──
    word_html = _build_highlight_html(tokens, vals_severe)

    # ── 5. NER pipeline ──
    # Run once per request; the raw entities are not logged because the
    # input is user-supplied medical text.
    ner_results = ner_pipe(text)
    ner_html = _build_ner_html(text, ner_results)

    return verdict_html, word_html, ner_html
| # ── HTML word highlights ────────────────────────────────────────────────────── | |
| def _build_highlight_html(tokens, values): | |
| max_abs = max(abs(values).max(), 1e-8) | |
| parts = [] | |
| for token, val in zip(tokens, values): | |
| # Skip special tokens | |
| if token in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>"): | |
| continue | |
| intensity = abs(val) / max_abs # 0-1 | |
| alpha = 0.15 + 0.75 * intensity # 0.15-0.90 | |
| if val > 0: # pushes toward SEVERE → red | |
| bg = f"rgba(231, 76, 60, {alpha:.2f})" | |
| fg = "#fff" if intensity > 0.4 else "#111" | |
| else: # pushes toward NON-SEVERE → green | |
| bg = f"rgba(46, 204, 113, {alpha:.2f})" | |
| fg = "#fff" if intensity > 0.4 else "#111" | |
| # Strip leading ## from subword tokens for readability | |
| display = token.lstrip("#") | |
| parts.append( | |
| f'<span style="background:{bg}; color:{fg}; ' | |
| f'padding:3px 5px; border-radius:4px; margin:2px; ' | |
| f'font-weight:{"700" if intensity > 0.5 else "400"}; ' | |
| f'display:inline-block;" ' | |
| f'title="SHAP: {val:+.4f}">{display}</span>' | |
| ) | |
| html = f""" | |
| <div style=" | |
| font-family: Georgia, serif; | |
| font-size: 1.05rem; | |
| line-height: 2.2; | |
| padding: 16px; | |
| background: transparent; | |
| border-radius: 10px; | |
| border: 1px solid #888; | |
| "> | |
| <div style="font-size:0.78rem; color:#aaa; margin-bottom:10px;"> | |
| 🔴 Red = pushes toward <strong>Severe</strong> | | |
| 🟢 Green = pushes toward <strong>Non-Severe</strong> | | |
| Darker = higher importance | |
| </div> | |
| {''.join(parts)} | |
| </div> | |
| """ | |
| return html | |
| # ── SHAP bar plot ───────────────────────────────────────────────────────────── | |
| def _build_shap_plot(tokens, values, text): | |
| # Filter special tokens | |
| pairs = [ | |
| (t.lstrip("#"), v) | |
| for t, v in zip(tokens, values) | |
| if t not in ("[CLS]", "[SEP]", "<s>", "</s>", "<pad>") | |
| ] | |
| if not pairs: | |
| return None | |
| labels_plot, vals = zip(*pairs) | |
| # Sort by absolute value descending, take top 20 | |
| order = np.argsort(np.abs(vals))[::-1][:20] | |
| labels_plot = [labels_plot[i] for i in order] | |
| vals = [vals[i] for i in order] | |
| # Re-sort for display: positive first, then negative | |
| combined = sorted(zip(vals, labels_plot), key=lambda x: x[0]) | |
| vals, labels_plot = zip(*combined) | |
| colors = ["#e74c3c" if v > 0 else "#2ecc71" for v in vals] | |
| fig, ax = plt.subplots(figsize=(8, max(4, len(vals) * 0.38))) | |
| fig.patch.set_facecolor("none") # transparent | |
| ax.set_facecolor("none") # transparent | |
| ax.set_xlabel("SHAP value (impact on Severe class)", fontsize=10) | |
| ax.set_title(...) | |
| ax.tick_params(colors="gray") | |
| for spine in ax.spines.values(): | |
| spine.set_edgecolor("gray") | |
| bars = ax.barh(labels_plot, vals, color=colors, edgecolor="none", height=0.65) | |
| ax.axvline(0, color="#555", linewidth=1) | |
| ax.set_xlabel("SHAP value (impact on Severe class)", color="#ccc", fontsize=10) | |
| ax.set_title( | |
| f"Token SHAP Breakdown\n\"{text[:60]}{'…' if len(text)>60 else ''}\"", | |
| color="#444", fontsize=11, pad=12 | |
| ) | |
| ax.tick_params(colors="#444") | |
| ax.yaxis.label.set_color("#444") | |
| for spine in ax.spines.values(): | |
| spine.set_edgecolor("#bbb") | |
| # Value labels on bars | |
| for bar, val in zip(bars, vals): | |
| ax.text( | |
| val + (0.002 if val >= 0 else -0.002), | |
| bar.get_y() + bar.get_height() / 2, | |
| f"{val:+.3f}", | |
| va="center", | |
| ha="left" if val >= 0 else "right", | |
| color="#eee", | |
| fontsize=8, | |
| ) | |
| plt.tight_layout() | |
| return fig | |
| # ── Gradio UI ───────────────────────────────────────────────────────────────── | |
# Sample reaction reports offered as one-click examples in the UI.
# Each entry is a single-element list; the UI reads the text as ``example[0]``
# and shows its first 80 characters on the button.
EXAMPLES = [
    ["I took ibuprofen for my back pain this morning. About an hour later, I noticed mild stomach discomfort and felt slightly nauseous. The symptoms went away after I ate something."],
    ["I was prescribed amoxicillin for a sinus infection last week. I developed a mild rash on my arms after the second dose. My doctor said it was a minor allergic reaction and switched me to a different antibiotic."],
    ["After receiving the contrast dye injection for my CT scan, I immediately felt my throat tightening and had difficulty breathing. I broke out in hives across my chest and face. The medical team administered epinephrine and I was kept under observation for several hours."],
    ["I am a 67 year old male who started taking metformin for type 2 diabetes last month. I have been experiencing persistent nausea and occasional vomiting since starting the medication. My blood sugar levels have improved but the gastrointestinal side effects are making it difficult to continue."],
    ["My 8 year old daughter was given penicillin for strep throat. Within 30 minutes she complained of dizziness and her lips began to swell. We rushed her to the emergency room where she was treated for anaphylaxis."],
]
# Custom stylesheet handed to ``gr.Blocks(css=CSS)``.
# The ``#…`` selectors match the ``id`` / ``elem_id`` values used in the
# layout (title, subtitle and the four cards); ``.example-btn`` matches the
# ``elem_classes`` given to the example buttons.
CSS = """
#title {
text-align: center;
font-family: 'Courier New', monospace;
color: #7eb8f7;
letter-spacing: 3px;
text-transform: uppercase;
margin-bottom: 4px;
}
#subtitle {
text-align: center;
color: #888;
font-size: 0.9rem;
margin-bottom: 20px;
font-family: Georgia, serif;
}
.gr-button-primary {
background: #3a5fc8 !important;
border: none !important;
font-weight: 700 !important;
letter-spacing: 1px !important;
}
#input-card, #verdict-card, #shap-card, #ner-card {
border: 1.5px solid #3a5fc8 !important;
border-radius: 12px !important;
box-shadow: 0 4px 12px rgba(58,95,200,0.15) !important;
padding: 16px !important;
}
.example-btn {
border: 1.5px solid #3a5fc8 !important;
border-radius: 8px !important;
background: white !important;
font-size: 0.82rem !important;
font-weight: 600 !important;
white-space: normal !important;
height: auto !important;
text-align: center !important;
line-height: 1.4 !important;
padding: 12px !important;
color: #222 !important;
}
.example-btn:hover {
background: #eef2ff !important;
cursor: pointer !important;
}
"""
# Page layout: header, input/verdict row, example buttons, SHAP/NER row.
with gr.Blocks(css=CSS, theme=gr.themes.Base()) as demo:
    gr.HTML('<h1 id="title">⚕ ADR Severity Analyzer</h1>')
    gr.HTML('<p id="subtitle">This demo uses a DeBERTA transformer to classify severity reactions. Please note this demo is for academic purposes and should NOT be used for medical advice.</p>')

    # ── Input + Verdict ──
    with gr.Row():
        with gr.Column(scale=1, elem_id="input-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>📝 Please Enter Your Drug Reaction</h3>")
            text_input = gr.Textbox(
                label="",
                placeholder="e.g. I took ibuprofen for my back pain this morning.",
                lines=5,
            )
            analyze_btn = gr.Button("🔍 Analyze Reaction", variant="primary")

        with gr.Column(scale=1, elem_id="verdict-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>⚕ Severity Classification</h3>")
            verdict_out = gr.HTML()

    # ── Examples ──
    gr.HTML("<h3 style='text-align:center;'>💡 Try An Example Reaction</h3>")
    with gr.Row():
        # One small button per sample, paired with the full text it inserts.
        ex_btns = [
            (
                gr.Button(
                    f'"{sample[0][:80]}..."',
                    size="sm",
                    elem_classes="example-btn",
                ),
                sample[0],
            )
            for sample in EXAMPLES
        ]

    # Clicking an example copies its full text into the input box.  The
    # default argument pins each button to its own text.
    for example_button, example_text in ex_btns:
        example_button.click(
            fn=lambda t=example_text: t,
            inputs=[],
            outputs=text_input,
        )

    # ── SHAP + NER ──
    with gr.Row():
        with gr.Column(scale=1, elem_id="shap-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>🎨 Keywords Contributing To Severity Classification </h3>")
            highlight_out = gr.HTML()

        with gr.Column(scale=1, elem_id="ner-card"):
            gr.HTML("<h3 style='text-align:center; margin-top:0;'>🏥 Medical Word Recognition</h3>")
            ner_out = gr.HTML()

    # Wire the main action: one click fills all three output panels.
    analyze_btn.click(
        fn=predict,
        inputs=text_input,
        outputs=[verdict_out, highlight_out, ner_out],
    )

if __name__ == "__main__":
    demo.launch()