import csv
import html
import os
import sys

import gradio as gr
import shap
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

HF_TOKEN = os.getenv("hf_token")

# Raise the csv module's field-size limit so very long text fields can be read
csv.field_size_limit(sys.maxsize)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the ADR classification model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    "paragon-analytics/ADRv1", token=HF_TOKEN
).to(device)

# Build a pipeline object for predictions; top_k=None returns scores for every label
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer built on top of the prediction pipeline
explainer = shap.Explainer(pred)

# Biomedical NER pipeline; "simple" aggregation merges sub-word tokens into whole entities
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple", device=device
)

# Highlight color for each NER entity group; unknown groups fall back to light gray
ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}


def adr_predict(x):
    # Ensure the input is treated as a lowercase string
    text_input = str(x).lower()

    # Classify: tokenize, run the model without tracking gradients, softmax the logits
    encoded_input = tokenizer(text_input, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP explanation, rendered as an HTML string
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"  # Fallback

    # NER: wrap each detected entity in a colored <mark>, escaping all text for HTML
    try:
        res = sorted(ner_pipe(text_input), key=lambda ent: ent["start"])
        htext = ""
        prev_end = 0
        for entity in res:
            start, end = entity["start"], entity["end"]
            word = text_input[start:end]  # Original text segment for this entity
            color = ENTITY_COLORS.get(entity["entity_group"], "lightgray")
            # Append the text before the entity, then the highlighted entity itself
            htext += html.escape(text_input[prev_end:start])
            htext += f'<mark style="background-color: {color};">{html.escape(word)}</mark>'
            prev_end = end
        # Append any remaining text after the last entity
        htext += html.escape(text_input[prev_end:])
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"  # Fallback

    # Map class indices to readable labels (index 1 = severe, index 0 = non-severe)
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext


def main(prob1):
    return adr_predict(prob1)


title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes a severe (or non-severe) adverse reaction to medications. Please do NOT use it for medical diagnosis."""

# Build the UI with the Blocks context manager
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    # Input component
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    # Output components, in the same order as the values returned by main()
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label="SHAP Explanation")
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,  # The function to call
        inputs=[prob1],  # The input components
        outputs=[label, local_plot, htext],  # The output components
        api_name="adr",  # Exposes this event as the "/adr" API endpoint
    )

    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")

    # Clicking an example fills the textbox and, with run_on_click=True, runs main().
    # cache_examples=False computes outputs on demand instead of precomputing them at startup.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

# Launch the demo
demo.launch()
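
# --- Example: calling the "/adr" endpoint from another process ---------------
# A minimal sketch, assuming the separate `gradio_client` package is installed
# and this app is running locally on Gradio's default port (7860); adjust the
# URL for a deployed Space. Run it from a different script, because
# demo.launch() blocks this one. With three output components, predict()
# returns a tuple in the same order as `outputs` above.
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860/")
# label, shap_html, ner_html = client.predict(
#     "A 35 year-old male had severe headache after taking Aspirin.",
#     api_name="/adr",
# )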