# File size: 5,112 Bytes
import os
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap
from shap.maskers import Text
from shap.explainers import Permutation
# All inference in this script is pinned to the CPU.
device = torch.device("cpu")
print(f"✅ Running on device: {device}")

# ADR severity classifier and its tokenizer, both loaded from the same
# checkpoint. The model is moved to `device` and switched to eval mode.
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
model = model.to(device)
model = model.eval()
# Biomedical named-entity recognition. The "simple" aggregation strategy makes
# the pipeline return grouped entities (with an `entity_group` key) instead of
# individual word pieces. device=-1 keeps the pipeline on the CPU.
ner = pipeline(
    task="ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    device=-1,
    aggregation_strategy="simple",
)

# Text-classification wrapper around the severity model, used only by the SHAP
# explainer. top_k=None asks for a score for every label, not just the best one.
clf_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    top_k=None,
)
def shap_predict(texts):
    """Score a batch of texts for SHAP, one row per text, one column per class.

    Args:
        texts: Iterable of inputs (SHAP may pass non-``str`` objects; each is
            coerced with ``str``).

    Returns:
        ``numpy.ndarray`` of shape ``(len(texts), n_classes)`` whose column
        ``k`` is the score of class index ``k``.
    """
    texts = [str(t) for t in texts]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)

    # The pipeline names its outputs via config.id2label, so inverting that
    # mapping recovers the class index of every returned label.
    class_index = {
        label: int(idx)
        for idx, label in clf_pipeline.model.config.id2label.items()
    }

    scores = []
    for result in results:
        if isinstance(result, dict):
            # Only the winning label was returned. For the two-class model the
            # other class gets the complementary probability.
            top = result['score']
            if class_index[result['label']] == 0:
                scores.append([top, 1 - top])
            else:
                scores.append([1 - top, top])
        else:
            # With top_k=None the pipeline sorts entries by descending score,
            # not by class. Re-order them by class index so the columns stay
            # aligned with the explainer's output names.
            ordered = sorted(result, key=lambda entry: class_index[entry['label']])
            scores.append([entry['score'] for entry in ordered])
    return np.array(scores)
# SHAP text masker built on the classifier's tokenizer.
masker = Text(tokenizer)
# Permutation explainer over shap_predict. The two output names are meant to
# label the score columns in class-index order.
# NOTE(review): confirm that Permutation actually honours `output_names`.
explainer = Permutation(shap_predict, masker, output_names=["Not Severe", "Severe"])
# Lower-cased NER entity groups, bucketed into the three categories that
# classify_adr reports.
SYMPTOM_TAGS = {
    "sign_symptom",
    "symptom",
}
DISEASE_TAGS = {
    "disease_disorder",
}
MED_TAGS = {
    "medication",
    "administration",
    "therapeutic_procedure",
}
def dedupe_and_filter(tokens):
    """Return the stripped tokens of length >= 3, de-duplicated case-insensitively.

    The first spelling seen for each lower-cased word is the one kept, and the
    original order of first appearance is preserved.
    """
    first_spelling = {}
    for raw in tokens:
        word = raw.strip()
        if len(word) >= 3:
            # setdefault keeps the earliest spelling for a given lower-cased key.
            first_spelling.setdefault(word.lower(), word)
    return list(first_spelling.values())
def _normalize_adr_text(text):
    """Drop standalone "nan" tokens, collapse whitespace, cap at 512 characters."""
    # Token-wise filtering removes stringified missing values ("nan") without
    # mangling words that merely contain those letters (e.g. "pregnancy").
    words = [word for word in (text or "").split() if word != "nan"]
    return " ".join(words)[:512]


def classify_adr(text, explain=False):
    """Classify an ADR description and extract its medical entities.

    Args:
        text: Free-text adverse drug reaction description. ``None`` is treated
            as an empty string.
        explain: When true, also compute a SHAP bar plot (slow).

    Returns:
        A 6-tuple of
        ``(probabilities_text, symptoms, diseases, medications, comment, shap_path)``.
        The three entity fields are newline-joined strings, or
        ``"None detected"`` when empty. ``shap_path`` is the path of the saved
        plot, or ``None`` when no explanation was requested or it failed.
    """
    clean = _normalize_adr_text(text)

    # Predict
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index 0 is reported as "Not Severe" and index 1 as "Severe" below.
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # NER: merge consecutive entities of the same group that touch or are
    # separated by a single character, keeping the best score of the pieces.
    spans = []
    for ent in ner(clean):
        grp = ent["entity_group"].lower()
        start, end = ent["start"], ent["end"]
        score = ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"] + 1:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})

    # Extend medication-like spans to the end of the alphabetic run they stop
    # in, so a name cut off mid-word is reported in full.
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

    # Discard low-confidence entities.
    spans = [s for s in spans if s["score"] >= 0.6]

    tokens = []
    for s in spans:
        chunk = clean[s["start"]:s["end"]].strip()
        if len(chunk) >= 3:
            tokens.append((chunk, s["group"]))

    symptoms = dedupe_and_filter([t for t, g in tokens if g in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, g in tokens if g in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, g in tokens if g in MED_TAGS])

    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."

    # SHAP explanation as image
    shap_path = None
    if explain:
        # Best effort: a failed explanation must not break the prediction, so
        # any error is logged and the image output is simply left empty.
        try:
            shap_values = explainer([clean], max_evals=min(400, len(clean.split()) * 5))
            plt.figure()
            try:
                shap.plots.bar(shap_values[0], show=False)
                shap_path = "/tmp/shap_expl.png"
                plt.savefig(shap_path, bbox_inches="tight")
            finally:
                # Always release the figure, even if plotting or saving failed.
                plt.close()
        except Exception as e:
            print(f"[SHAP Error] {e}")
            shap_path = None

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
        shap_path
    )
# Gradio front end: one text box plus an opt-in SHAP checkbox feed
# classify_adr. Its six return values map, in order, onto five text outputs
# and one image output.
demo = gr.Interface(
    fn=classify_adr,
    inputs=[
        gr.Textbox(label="ADR Description", lines=5),
        gr.Checkbox(value=False, label="Generate SHAP Explanation (VERY slow)"),
    ],
    outputs=[
        *(
            gr.Textbox(label=caption)
            for caption in (
                "Predicted Probabilities",
                "Symptoms",
                "Diseases or Conditions",
                "Medications",
                "Interpretation",
            )
        ),
        gr.Image(label="SHAP Explanation"),
    ],
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description to classify severity, extract symptoms, diseases, medications, and visualize SHAP explanations.",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Listen on all interfaces. The port is taken from $PORT when set and
    # falls back to 7860 otherwise.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))