File size: 4,156 Bytes
6f713b7
 
 
 
 
 
 
 
 
 
 
 
 
34e41e3
6f713b7
 
 
 
 
fdff5a2
 
6f713b7
 
 
 
 
 
 
 
 
 
a9d0b27
6f713b7
 
 
a9d0b27
6f713b7
 
a9d0b27
 
34e41e3
6f713b7
 
a9d0b27
6f713b7
 
a9d0b27
6f713b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9d0b27
6f713b7
a9d0b27
 
6f713b7
 
 
a9d0b27
6f713b7
 
 
a9d0b27
6f713b7
 
 
 
 
 
 
34e41e3
6f713b7
34e41e3
6f713b7
 
 
 
34e41e3
a9d0b27
 
 
 
 
6f713b7
a9d0b27
6f713b7
 
a9d0b27
 
 
 
6f713b7
a9d0b27
6f713b7
 
 
 
 
 
 
 
a9d0b27
6f713b7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import csv
import html
import os
import sys

import gradio as gr
import numpy as np
import scipy as sp
import shap
import torch
import transformers
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Environment setup
# Access token for the classifier repository, read from the ``hf_token``
# environment variable (None when unset) and passed as ``token=`` below.
HF_TOKEN = os.getenv("hf_token")

# Lift the csv module's per-field size cap. ``sys.maxsize`` does not fit in a
# C long on some platforms (notably Windows) and raises OverflowError there,
# so fall back to the largest 32-bit signed value in that case.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# Run on the first GPU when available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("willwim/WillTest", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained("willwim/WillTest", token=HF_TOKEN).to(device)

# Build a pipeline object for predictions; top_k=None makes it return the
# score of every label rather than only the best one.
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer wrapping the classification pipeline above.
explainer = shap.Explainer(pred)

# NER pipeline (biomedical entities). "simple" aggregation merges sub-word
# tokens into word-level groups carrying 'entity_group', 'start' and 'end'.
# Placed on the same device as the classifier for consistency.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple", device=device)

# Highlight colour for each biomedical NER entity group; groups not listed
# here are rendered with 'lightgray'.
ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _shap_html(text_input):
    """Return the SHAP text-plot HTML for *text_input*, or a fallback message.

    Failures are best-effort: they are printed and replaced by a placeholder
    paragraph so the rest of the prediction is still returned.
    """
    try:
        shap_values = explainer([text_input])
        return shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _highlight_entities(text_input, entities):
    """Return *text_input* as HTML with each entity wrapped in a coloured <mark>.

    Args:
        text_input: The text the entities were extracted from.
        entities: Iterable of dicts with 'start', 'end' (character offsets
            into *text_input*) and 'entity_group' keys.

    All text is HTML-escaped, because it comes straight from the user and is
    rendered by a gr.HTML component.
    """
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        # Clamp overlapping spans so no character is emitted twice.
        start = max(entity['start'], prev_end)
        end = entity['end']
        if end <= start:
            continue
        color = ENTITY_COLORS.get(entity['entity_group'], 'lightgray')
        parts.append(html.escape(text_input[prev_end:start]))
        parts.append(
            f"<mark style='background-color:{color};'>{html.escape(text_input[start:end])}</mark>"
        )
        prev_end = end
    parts.append(html.escape(text_input[prev_end:]))
    return "".join(parts)


def adr_predict(x):
    """Classify *x* as a severe / non-severe adverse drug reaction.

    Args:
        x: Input text; converted with ``str()`` and lower-cased first.

    Returns:
        A 3-tuple ``(label_output, local_plot, htext)``:
        - label_output: dict mapping "Severe Reaction" (class index 1) and
          "Non-severe Reaction" (class index 0) to softmax probabilities.
        - local_plot: SHAP explanation HTML, or a fallback paragraph.
        - htext: the text as HTML with NER entities highlighted, or a
          fallback paragraph if NER fails.
    """
    text_input = str(x).lower()
    # Truncate to the tokenizer's maximum length so over-long inputs are
    # scored instead of crashing the model forward pass.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(**encoded_input)

    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP Explanation
    local_plot = _shap_html(text_input)

    # NER processing (best-effort, like SHAP)
    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"

    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext

def main(prob1):
    """Gradio callback: run :func:`adr_predict` on the textbox contents.

    The result of ``adr_predict`` is handed back to Gradio unchanged.
    """
    prediction = adr_predict(prob1)
    return prediction

# Gradio Interface
title = "Welcome to **ADR Detector** 🪐"
# Plain-text title for the browser tab: ``title`` is Markdown, and its ``**``
# markers would appear literally there.
page_title = "ADR Detector"
description1 = "This app predicts severe or non-severe adverse reactions to medications. Do NOT use for medical diagnosis."

with gr.Blocks(title=page_title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    # Output components, in the order main() returns its tuple.
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr"
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True
    )

# Launch only when executed as a script, so that importing this module
# (e.g. from tests or other tooling) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()