File size: 5,112 Bytes
819d638
829d721
377bf3c
 
 
829d721
819d638
eb56a28
af523a6
ebcc2bf
9fc789c
377bf3c
 
 
9fc789c
377bf3c
 
 
9fc789c
377bf3c
 
 
 
 
 
 
 
9fc789c
377bf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fc789c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377bf3c
 
9fc789c
377bf3c
 
 
 
 
9fc789c
377bf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fc789c
 
377bf3c
38dde64
9fc789c
377bf3c
9fc789c
 
 
377bf3c
38dde64
9fc789c
 
377bf3c
 
 
 
 
 
 
9fc789c
071eda8
8febc26
377bf3c
 
 
 
1f15c83
377bf3c
 
 
 
 
 
 
9fc789c
377bf3c
1f15c83
ec59fd3
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import re

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import shap
import torch
from shap.explainers import Permutation
from shap.maskers import Text
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# Device configuration — everything runs on CPU (device=-1 below matches this).
device = torch.device("cpu")
print(f"✅ Running on device: {device}")

# Load model and tokenizer
# Fine-tuned RoBERTa sequence classifier for ADR severity (binary:
# index 0 = not severe, index 1 = severe — see output_names below).
# eval() disables dropout so predictions are deterministic.
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model").to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

# NER pipeline
# Biomedical NER model; aggregation_strategy="simple" merges subword pieces
# into word-level entities with character start/end offsets.
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
    device=-1
)

# SHAP setup
# Wraps the same severity model as a text-classification pipeline for SHAP.
# top_k=None returns scores for ALL labels (as a list of {label, score}
# dicts per input), not just the argmax.
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1
)

def shap_predict(texts):
    """Batch prediction function handed to the SHAP Permutation explainer.

    Args:
        texts: Iterable of (possibly masked) input strings from SHAP.

    Returns:
        np.ndarray of shape (len(texts), 2) with rows
        [P(Not Severe), P(Severe)] matching ``output_names`` on the explainer.
    """
    texts = [str(t) for t in texts]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)
    scores = []
    for res in results:
        if isinstance(res, dict):
            # Defensive single-prediction shape: the score belongs to the
            # PREDICTED label, which may be either class.
            # NOTE(review): assumes class-1 labels end in "1" (e.g. "LABEL_1")
            # — confirm against the model's id2label config.
            p_severe = res["score"] if res["label"].endswith("1") else 1.0 - res["score"]
            scores.append([1.0 - p_severe, p_severe])
        else:
            # With top_k=None the pipeline returns all labels sorted by
            # DESCENDING score, so positional indexing is wrong whenever
            # class 1 ranks first. Sort by label name to get a stable
            # [class 0, class 1] order ("LABEL_0" < "LABEL_1", and
            # "Not Severe" < "Severe", lexicographically).
            ordered = sorted(res, key=lambda entry: entry["label"])
            scores.append([entry["score"] for entry in ordered])
    return np.array(scores)

# SHAP text masker + model-agnostic Permutation explainer; output_names maps
# score columns (index 0, index 1) to human-readable class names.
masker = Text(tokenizer)
explainer = Permutation(shap_predict, masker, output_names=["Not Severe", "Severe"])

# Lower-cased NER entity-group names bucketed into the three UI categories.
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS     = {"medication", "administration", "therapeutic_procedure"}

def dedupe_and_filter(tokens):
    """Clean an entity-token list for display.

    Strips each token, drops anything shorter than 3 characters, and removes
    case-insensitive duplicates while preserving first-seen order and the
    original casing of the first occurrence.
    """
    kept = []
    seen_lower = set()
    for raw in tokens:
        candidate = raw.strip()
        if len(candidate) < 3:
            continue
        folded = candidate.lower()
        if folded in seen_lower:
            continue
        seen_lower.add(folded)
        kept.append(candidate)
    return kept

def classify_adr(text, explain=False):
    """Classify ADR severity and extract biomedical entities from free text.

    Args:
        text: Free-text adverse-drug-reaction description.
        explain: If True, also render a SHAP bar chart (slow).

    Returns:
        6-tuple matching the Gradio outputs: (probability summary string,
        symptoms, diseases, medications — newline-joined strings or
        "None detected" — interpretation comment, SHAP image path or None).
    """
    # Normalize input: remove literal "nan" tokens (pandas-export artifact)
    # as WHOLE WORDS only — the previous blanket str.replace("nan", "")
    # corrupted real words containing the substring (e.g. "pregnancy").
    # Then collapse one level of doubled spaces and cap at 512 chars to
    # bound NER/SHAP cost.
    clean = re.sub(r"\bnan\b", "", text.strip()).replace("  ", " ")[:512]

    # Predict severity with the fine-tuned classifier (no gradients needed).
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # NER: merge adjacent/overlapping spans of the same entity group into one
    # span, keeping the max confidence (assumes pipeline output is ordered by
    # start offset, which aggregation_strategy="simple" produces).
    ents = ner(clean)
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"] + 1:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})

    # Medication spans often stop mid-word (subword tokenization); extend
    # each one to the end of its alphabetic run so drug names stay whole.
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

    # Keep only confident entities.
    spans = [s for s in spans if s["score"] >= 0.6]

    tokens = []
    for s in spans:
        chunk = clean[s["start"]:s["end"]].strip()
        if len(chunk) >= 3:
            tokens.append((chunk, s["group"]))

    symptoms    = dedupe_and_filter([t for t, g in tokens if g in SYMPTOM_TAGS])
    diseases    = dedupe_and_filter([t for t, g in tokens if g in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, g in tokens if g in MED_TAGS])

    # Human-readable interpretation of P(severe) = probs[1].
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."

    # Optional SHAP explanation, saved as an image for the Gradio component.
    # max_evals is budgeted by input length so short texts stay fast.
    shap_path = None
    if explain:
        try:
            shap_values = explainer([clean], max_evals=min(400, len(clean.split()) * 5))
            plt.figure()
            shap.plots.bar(shap_values[0], show=False)
            shap_path = "/tmp/shap_expl.png"
            plt.savefig(shap_path, bbox_inches="tight")
            plt.close()
        except Exception as e:
            # Best-effort: a SHAP failure must not break the classification
            # outputs, so log and return no image.
            print(f"[SHAP Error] {e}")
            shap_path = None

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
        shap_path
    )

# Gradio UI: one text input + opt-in SHAP checkbox, six outputs matching the
# 6-tuple returned by classify_adr (in order).
# NOTE(review): allow_flagging was deprecated in favor of flagging_mode in
# newer Gradio releases — confirm the pinned Gradio version accepts it.
demo = gr.Interface(
    fn=classify_adr,
    inputs=[
        gr.Textbox(lines=5, label="ADR Description"),
        gr.Checkbox(label="Generate SHAP Explanation (VERY slow)", value=False)
    ],
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
        gr.Image(label="SHAP Explanation")
    ],
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description to classify severity, extract symptoms, diseases, medications, and visualize SHAP explanations.",
    allow_flagging="never"
)

# Launch the server when run as a script: bind to all interfaces and honor
# the PORT env var (set by hosting platforms, e.g. Hugging Face Spaces).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))