File size: 5,209 Bytes
8af3270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import csv
import html
import os
import sys

import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import tensorflow as tf
import transformers
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
import matplotlib.pyplot as plt

# Access token for the Hugging Face Hub, read from the "hf_token" environment
# variable; None when the variable is unset.
HF_TOKEN = os.getenv("hf_token")

# Raise the csv module's field size limit as far as the platform allows.
# csv.field_size_limit takes a C long; where that is 32 bits (e.g. Windows),
# sys.maxsize does not fit and OverflowError is raised, so fall back to the
# largest signed 32-bit value.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the ADR classifier and its tokenizer; the model is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1", token=HF_TOKEN).to(device)

# Text-classification pipeline over the same model/tokenizer. top_k=None makes
# the pipeline return scores for every label, which the SHAP explainer needs.
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer wrapping the classification pipeline.
explainer = shap.Explainer(pred)

# Biomedical NER pipeline. aggregation_strategy="simple" groups sub-word tokens
# into whole entities (each with 'entity_group', 'start', 'end'). It runs on the
# same device as the classifier.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple", device=device)

# Background colour used to highlight each NER entity group in the HTML output.
# Groups not listed here fall back to 'lightgray'.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _severity_scores(text):
    """Return the classifier's softmax probabilities for *text* (first batch row)."""
    # truncation=True keeps over-long inputs within the model's maximum sequence
    # length instead of failing inside the forward pass.
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True).to(device)
    # Inference only: no_grad avoids building an autograd graph for the pass.
    with torch.no_grad():
        logits = model(**encoded_input).logits
    return torch.softmax(logits, dim=-1)[0].cpu().numpy()


def _shap_html(text):
    """Return SHAP's HTML text plot for *text*, or a fallback paragraph on failure."""
    try:
        shap_values = explainer([text])
        return shap.plots.text(shap_values[0], display=False)
    except Exception as e:  # best-effort: a failed explanation must not break prediction
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in a coloured <mark>.

    *entities* are dicts with 'start', 'end' (character offsets into *text*)
    and 'entity_group'. All text is HTML-escaped because it comes straight from
    user input and is rendered as raw HTML.
    """
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        # Clamp to prev_end so an entity overlapping the previous one does not
        # repeat text that has already been emitted.
        start = max(entity['start'], prev_end)
        end = entity['end']
        if end <= start:
            continue
        color = _ENTITY_COLORS.get(entity['entity_group'], 'lightgray')
        parts.append(html.escape(text[prev_end:start]))  # plain text before the entity
        parts.append(f"<mark style='background-color:{color};'>{html.escape(text[start:end])}</mark>")
        prev_end = end
    parts.append(html.escape(text[prev_end:]))  # any text after the last entity
    return "".join(parts)


def _ner_html(text):
    """Return *text* with NER entities highlighted, or a fallback paragraph on failure."""
    try:
        return _highlight_entities(text, ner_pipe(text))
    except Exception as e:  # best-effort: NER failure must not break prediction
        print(f"NER processing failed: {e}")
        return "<p>NER processing not available.</p>"


def adr_predict(x):
    """Score *x* for severe vs. non-severe adverse reaction and explain the result.

    Args:
        x: Input text. It is converted with str() and lower-cased before use.

    Returns:
        A 3-tuple ``(label_output, local_plot, htext)``:
        - label_output: dict mapping "Severe Reaction" to the probability of
          class index 1 and "Non-severe Reaction" to that of class index 0.
        - local_plot: HTML from ``shap.plots.text``, or a fallback <p> message
          if the SHAP explanation fails.
        - htext: the (lower-cased, HTML-escaped) input with NER entities
          highlighted, or a fallback <p> message if NER fails.

    Raises:
        Any exception from tokenization or the classifier forward pass
        propagates; SHAP and NER failures are caught and logged.
    """
    text_input = str(x).lower()

    scores = _severity_scores(text_input)
    local_plot = _shap_html(text_input)
    htext = _ner_html(text_input)

    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}

    return label_output, local_plot, htext

def main(prob1):
    """Gradio event handler: analyse the text typed into the input box.

    Delegates entirely to adr_predict and hands its 3-tuple
    (label dict, SHAP HTML, NER HTML) straight back to the UI.
    """
    outputs = adr_predict(prob1)
    return outputs

# Heading shown at the top of the app and used as the browser tab title.
title = "Welcome to **ADR Detector** 🪐"
# Introductory text displayed under the heading (rendered as Markdown).
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reactions to medications. Please do NOT use for medical diagnosis."""

# Build the UI with the Blocks context manager; components appear on the page
# in the order they are created inside this `with` block.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")  # Markdown horizontal rule separating header from form

    # Free-text input; its value is passed to main() as the single argument.
    prob1 = gr.Textbox(label="Enter Your Text Here:",lines=2, placeholder="Type it here ...")

    # Output components, listed in the same order as the
    # (label_output, local_plot, htext) tuple returned by main().
    label = gr.Label(label = "Predicted Label")
    local_plot = gr.HTML(label = 'Shap Explanation') 
    htext = gr.HTML(label="Named Entity Recognition") 

    submit_btn = gr.Button("Analyze")

    submit_btn.click(
        fn=main, # Callback run when the button is clicked
        inputs=[prob1], # Components whose values are passed to fn
        outputs=[label, local_plot, htext], # Components updated, in order, from fn's return tuple
        api_name="adr" # Name under which this event is exposed in the app's API
    )

    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Clicking an example fills the textbox. Results are not precomputed
    # (cache_examples=False); run_on_click=True makes fn run live on the
    # selected example and update the outputs listed below.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main, # Function run when an example is clicked (required with run_on_click)
        cache_examples=False,
        run_on_click=True
    )

# Start the Gradio server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module to reuse `demo` or
# adr_predict does not start a blocking server.
if __name__ == "__main__":
    demo.launch()