"""Gradio app: severe / non-severe adverse-drug-reaction text classifier.

For each submitted text the app shows three things: the classifier's class
probabilities, a SHAP text explanation of the prediction, and the input text
with named entities highlighted.
"""
import csv
import html
import os
import sys

import gradio as gr
import numpy as np
import scipy as sp
import shap
import torch
import transformers
from transformers import AutoModelForSequenceClassification
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline

# Environment setup
HF_TOKEN = os.getenv("hf_token")
csv.field_size_limit(sys.maxsize)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("willwim/WillTest", token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    "willwim/WillTest", token=HF_TOKEN
).to(device)

# Build a pipeline object for predictions (this is what SHAP explains)
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=device,
)

# SHAP explainer
explainer = shap.Explainer(pred)

# NER pipeline
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)

# Background colour per NER entity group; groups not listed here fall back
# to DEFAULT_ENTITY_COLOR.
ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}
DEFAULT_ENTITY_COLOR = 'lightgray'

# Fallback markup shown in the gr.HTML outputs when a best-effort step fails.
SHAP_FALLBACK_HTML = "<p>SHAP explanation not available.</p>"
NER_FALLBACK_HTML = "<p>NER processing not available.</p>"


def _highlight_entities(text, entities):
    """Return *text* as HTML with each entity span wrapped in a coloured <mark>.

    Args:
        text: The exact string the entity offsets refer to.
        entities: Iterable of dicts with 'start', 'end' and 'entity_group'
            keys, as read from the NER pipeline output.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped, so
        user-supplied markup is displayed literally instead of rendered.
    """
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        start = entity['start']
        end = entity['end']
        if start < prev_end:
            # Overlaps a span that was already emitted; skipping it avoids
            # repeating characters in the output.
            continue
        color = ENTITY_COLORS.get(entity['entity_group'], DEFAULT_ENTITY_COLOR)
        # Offsets index the raw text, so escape each slice individually
        # rather than escaping the whole string up front.
        parts.append(html.escape(text[prev_end:start]))
        parts.append(
            f'<mark style="background-color: {color};">'
            f'{html.escape(text[start:end])}</mark>'
        )
        prev_end = end
    parts.append(html.escape(text[prev_end:]))
    return "".join(parts)


def adr_predict(x):
    """Classify *x* and build the SHAP and NER visualisations for it.

    Args:
        x: Input text; it is converted with ``str()`` and lower-cased before
            being passed to the models.

    Returns:
        A ``(label_output, local_plot, htext)`` tuple:
        ``label_output`` maps "Severe Reaction" / "Non-severe Reaction" to
        float probabilities, ``local_plot`` is the value returned by
        ``shap.plots.text(..., display=False)`` or a fallback paragraph, and
        ``htext`` is the entity-highlighted HTML or a fallback paragraph.
    """
    text_input = str(x).lower()

    encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP explanation is best-effort: a failure must not block the prediction.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = SHAP_FALLBACK_HTML

    # NER highlighting is best-effort as well.
    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = NER_FALLBACK_HTML

    # Logit index 1 is reported as the severe class, index 0 as non-severe.
    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_output, local_plot, htext


def main(prob1):
    """Gradio callback: forward the textbox value to ``adr_predict``."""
    return adr_predict(prob1)


# Gradio Interface
title = "Welcome to **ADR Detector** 🪐"
description1 = "This app predicts severe or non-severe adverse reactions to medications. Do NOT use for medical diagnosis."

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

demo.launch()