File size: 4,156 Bytes
6f713b7 34e41e3 6f713b7 fdff5a2 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 34e41e3 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 34e41e3 6f713b7 34e41e3 6f713b7 34e41e3 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 a9d0b27 6f713b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
import sys
import csv
import os
# Environment setup
# NOTE(review): the token is read from the lower-case "hf_token" variable;
# environment-variable names are case-sensitive on Linux, so the secret must
# be registered under exactly that name.
HF_TOKEN = os.getenv("hf_token")

# csv.field_size_limit() takes a C long. sys.maxsize overflows it where long
# is 32-bit (e.g. Windows), so fall back to the largest 32-bit value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the ADR classifier and its tokenizer, authenticating with HF_TOKEN.
tokenizer = AutoTokenizer.from_pretrained("willwim/WillTest", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained("willwim/WillTest", token=HF_TOKEN).to(device)

# Text-classification pipeline that returns scores for every label
# (top_k=None); the SHAP explainer wraps this pipeline.
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer
explainer = shap.Explainer(pred)

# Biomedical NER pipeline. aggregation_strategy="simple" groups sub-word
# tokens into entity spans with "entity_group", "start" and "end" keys.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# Run NER on the same device as the classifier instead of always on CPU.
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
# Highlight colour per NER entity group; unknown groups fall back to lightgray.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _classify(text_input):
    """Return the classifier's softmax probabilities for *text_input* as a NumPy array."""
    # truncation=True: inputs longer than the model's maximum sequence length
    # would otherwise make the forward pass raise.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    return torch.softmax(output.logits, dim=-1)[0].cpu().numpy()


def _shap_html(text_input):
    """Return the SHAP text plot for *text_input* as HTML, or a placeholder on failure."""
    try:
        shap_values = explainer([text_input])
        return shap.plots.text(shap_values[0], display=False)
    except Exception as e:  # best-effort: the explanation is optional
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _ner_html(text_input):
    """Return *text_input* as HTML with NER entities wrapped in coloured <mark> tags.

    All user text is HTML-escaped because the result is rendered by a
    gr.HTML component and must not be able to inject markup.
    """
    import html

    try:
        entities = sorted(ner_pipe(text_input), key=lambda ent: ent['start'])
        parts = []
        prev_end = 0
        for entity in entities:
            # Clip overlapping spans so no character is emitted twice.
            start = max(entity['start'], prev_end)
            end = entity['end']
            if end <= start:
                continue
            color = _ENTITY_COLORS.get(entity['entity_group'], 'lightgray')
            parts.append(html.escape(text_input[prev_end:start]))
            word = html.escape(text_input[start:end])
            parts.append(f"<mark style='background-color:{color};'>{word}</mark>")
            prev_end = end
        parts.append(html.escape(text_input[prev_end:]))
        return "".join(parts)
    except Exception as e:  # best-effort: highlighting is optional
        print(f"NER processing failed: {e}")
        return "<p>NER processing not available.</p>"


def adr_predict(x):
    """Classify *x* for ADR severity and build the SHAP and NER HTML views.

    Args:
        x: user input; converted with ``str`` and lower-cased before use.

    Returns:
        A tuple ``(label_output, local_plot, htext)``:

        - ``label_output``: dict mapping "Severe Reaction" to the softmax
          probability of class index 1 and "Non-severe Reaction" to class
          index 0 (NOTE(review): confirm against the model's id2label).
        - ``local_plot``: SHAP text-plot HTML, or a placeholder paragraph if
          the explanation failed.
        - ``htext``: the HTML-escaped input with entities highlighted, or a
          placeholder paragraph if NER failed.
    """
    text_input = str(x).lower()
    scores = _classify(text_input)
    local_plot = _shap_html(text_input)
    htext = _ner_html(text_input)
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext
def main(prob1):
    """Gradio callback: run :func:`adr_predict` on the textbox value.

    The ``(label, SHAP HTML, NER HTML)`` tuple is passed through unchanged.
    """
    outputs = adr_predict(prob1)
    return outputs
# Gradio Interface
title = "Welcome to **ADR Detector** 🪐"
description1 = "This app predicts severe or non-severe adverse reactions to medications. Do NOT use for medical diagnosis."
# gr.Blocks(title=...) sets the plain-text browser-tab title, so strip the
# Markdown bold markers that are only meaningful in the heading below.
page_title = title.replace("**", "")

with gr.Blocks(title=page_title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    # One text input and the three outputs returned by main(), in order.
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")
    submit_btn = gr.Button("Analyze")

    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        # Predictions run live when an example is clicked; nothing is pre-cached.
        cache_examples=False,
        run_on_click=True,
    )

# Start the web app when the script runs.
demo.launch()