# Team3Mod4 / app.py
# willwim's picture
# Upload 2 files
# 8af3270 verified
# Raw
# History Blame Contribute Delete
# 5.21 kB
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import tensorflow as tf
import transformers
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
import matplotlib.pyplot as plt
import sys
import csv
import os
# Hugging Face access token, read from the `hf_token` environment variable
# (None when unset); it is forwarded to from_pretrained for the ADR model.
HF_TOKEN = os.getenv("hf_token")

# Raise the csv module's per-field size limit to sys.maxsize.
csv.field_size_limit(sys.maxsize)

# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the ADR classifier and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1", token=HF_TOKEN).to(device)

# Text-classification pipeline around the classifier; top_k=None asks for the
# score of every label. This is the callable handed to SHAP below.
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer built on top of the classification pipeline.
explainer = shap.Explainer(pred)

# Biomedical NER pipeline. It is placed on the same device as the classifier
# so that it also benefits from the GPU when one is available.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
# Highlight colour per entity group returned by the NER pipeline; groups that
# are not listed fall back to _DEFAULT_ENTITY_COLOR.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}
_DEFAULT_ENTITY_COLOR = 'lightgray'


def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in a coloured <mark>.

    Args:
        text: The string the entity offsets refer to.
        entities: Iterable of dicts carrying 'start', 'end' and 'entity_group'.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped because it
        originates from the user and is rendered by an HTML component.
    """
    import html  # local import keeps this block self-contained

    pieces = []
    cursor = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        # Clamp to the cursor so an overlapping span never emits text twice.
        start = max(entity['start'], cursor)
        end = entity['end']
        if end <= start:
            continue
        color = _ENTITY_COLORS.get(entity['entity_group'], _DEFAULT_ENTITY_COLOR)
        # Plain text between the previous entity and this one.
        pieces.append(html.escape(text[cursor:start]))
        # The highlighted entity itself.
        pieces.append(
            f"<mark style='background-color:{color};'>{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    # Whatever follows the last entity.
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(x):
    """Score a text for adverse-reaction severity and build two HTML explanations.

    Args:
        x: The text to analyse. It is converted with str() and lower-cased;
            None is treated as an empty string.

    Returns:
        A 3-tuple ``(label_output, local_plot, htext)``:
        * label_output -- dict mapping "Severe Reaction" / "Non-severe Reaction"
          to the softmax probability of class index 1 / 0 respectively.
        * local_plot -- the value returned by ``shap.plots.text(..., display=False)``,
          or a fallback HTML paragraph when the SHAP step raises.
        * htext -- the input text with recognised entities highlighted, or a
          fallback HTML paragraph when the NER step raises.
    """
    text_input = "" if x is None else str(x).lower()

    # Truncate so an over-long input cannot exceed the model's maximum length.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP explanation is best-effort: a failure must not break the prediction.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"  # Provide a fallback

    # NER highlighting is best-effort for the same reason.
    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"  # Provide a fallback

    # NOTE(review): index 1 is reported as "severe" and index 0 as "non-severe";
    # confirm this ordering against the model's label mapping.
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext
def main(prob1):
    """Gradio callback: analyse the submitted text.

    Delegates to adr_predict and hands back its result unchanged, i.e. the
    (label scores, SHAP HTML, NER HTML) triple bound to the output components.
    """
    analysis = adr_predict(prob1)
    return analysis
# Page heading (Markdown) and the short description shown underneath it.
title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes a severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."""
# Build the UI with the Blocks context-manager API.
# `title` carries Markdown emphasis for the on-page heading; the browser-tab
# title is plain text, so the asterisks are stripped there instead of being
# displayed literally.
with gr.Blocks(title=title.replace("*", "")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    # Input component
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    # Output components, in the order of the values returned by main()
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,  # The function to call
        inputs=[prob1],  # The input components
        outputs=[label, local_plot, htext],  # The output components
        api_name="adr",  # Name under which the endpoint is exposed through the API
    )

    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Example outputs are not cached (cache_examples=False); instead,
    # run_on_click=True runs main() when an example is clicked.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

# Launch the demo
demo.launch()