Test / app.py
willwim's picture
Update app.py
34e41e3 verified
Raw
History Blame Contribute Delete
4.16 kB
import csv
import html
import os
import sys

import gradio as gr
import numpy as np
import scipy as sp
import shap
import torch
import transformers
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Environment setup
# Access token read from the "hf_token" environment variable (None when unset),
# passed to from_pretrained for the classifier repo below.
HF_TOKEN = os.getenv("hf_token")

# Lift the csv module's per-field size cap. sys.maxsize does not fit in a C long
# on platforms where long is 32-bit (e.g. Windows) and raises OverflowError there,
# so fall back to the largest 32-bit signed value.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# Run on the first GPU when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("willwim/WillTest", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained("willwim/WillTest", token=HF_TOKEN).to(device)

# Build a pipeline object for predictions (wraps the same model/tokenizer so SHAP
# can call it on raw strings).
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer
explainer = shap.Explainer(pred)

# NER pipeline, placed on the same device as the classifier.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple", device=device)
# Background colour for each NER entity group; groups not listed here are
# rendered with "lightgray".
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _classify(text_input):
    """Return the classifier's softmax probabilities for *text_input* (first batch row)."""
    # truncation=True keeps over-long inputs from exceeding the model's maximum
    # sequence length, which would otherwise make the forward pass fail.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(**encoded_input)
    return torch.softmax(output.logits, dim=-1)[0].cpu().numpy()


def _shap_html(text_input):
    """Return SHAP's HTML text plot for *text_input*, or a placeholder if SHAP fails."""
    try:
        shap_values = explainer([text_input])
        return shap.plots.text(shap_values[0], display=False)
    except Exception as e:  # best-effort: the prediction is still returned without it
        print(f"SHAP explanation failed: {e}")
        return "<p>SHAP explanation not available.</p>"


def _highlight_entities(text_input, entities):
    """Return *text_input* as HTML with every entity span wrapped in a coloured <mark>.

    Each entity is a dict with 'start', 'end' (character offsets into
    *text_input*) and 'entity_group'. All user text is HTML-escaped because the
    result is rendered as raw HTML.
    """
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        # Clamp so an entity overlapping the previous one never re-emits text.
        start = max(entity['start'], prev_end)
        end = entity['end']
        if end <= start:  # entirely inside an already-emitted span
            continue
        color = _ENTITY_COLORS.get(entity['entity_group'], 'lightgray')
        parts.append(html.escape(text_input[prev_end:start]))
        parts.append(f"<mark style='background-color:{color};'>{html.escape(text_input[start:end])}</mark>")
        prev_end = end
    parts.append(html.escape(text_input[prev_end:]))
    return "".join(parts)


def _ner_html(text_input):
    """Run NER on *text_input* and return highlighted HTML, or a placeholder on failure."""
    try:
        return _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:  # best-effort: the prediction is still returned without it
        print(f"NER processing failed: {e}")
        return "<p>NER processing not available.</p>"


def adr_predict(x):
    """Classify *x* as a severe / non-severe adverse reaction and explain the result.

    Args:
        x: Input text; converted with str() and lower-cased before processing.

    Returns:
        A tuple ``(label_output, local_plot, htext)``:
        - label_output: dict mapping "Severe Reaction" / "Non-severe Reaction"
          to the classifier's softmax probability for that class.
        - local_plot: HTML SHAP text plot, or a placeholder paragraph if SHAP failed.
        - htext: escaped input text with NER entities highlighted, or a
          placeholder paragraph if NER failed.
    """
    text_input = str(x).lower()
    scores = _classify(text_input)
    local_plot = _shap_html(text_input)
    htext = _ner_html(text_input)
    # NOTE(review): assumes logit index 1 is the "severe" class and index 0 the
    # "non-severe" class — confirm against the model config's id2label.
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext
def main(prob1):
    """Gradio callback: pass the textbox value through to adr_predict.

    Returns adr_predict's (label_output, local_plot, htext) tuple unchanged.
    """
    outputs = adr_predict(prob1)
    return outputs
# Gradio Interface
app_name = "ADR Detector"
title = f"Welcome to **{app_name}** 🪐"
description1 = "This app predicts severe or non-severe adverse reactions to medications. Do NOT use for medical diagnosis."

# Blocks' `title` is the browser tab title and is shown as plain text, so it gets
# the bare app name; the Markdown-formatted `title` is used for the page heading.
with gr.Blocks(title=app_name) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    # Input and the three outputs produced by main() / adr_predict().
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Examples are run on click rather than precomputed at startup.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

demo.launch()