| import gradio as gr |
| import shap |
| import numpy as np |
| import scipy as sp |
| import torch |
| import transformers |
| from transformers import pipeline |
| from transformers import AutoModelForSequenceClassification |
| from transformers import AutoTokenizer, AutoModelForTokenClassification |
| import sys |
| import csv |
| import os |
|
|
| |
# Access token for the Hugging Face Hub, read from the "hf_token" environment
# variable; None when the variable is unset.
HF_TOKEN = os.getenv("hf_token")


def _set_max_csv_field_size():
    """Raise the csv module's field size limit as high as the platform allows.

    ``csv.field_size_limit(sys.maxsize)`` raises OverflowError where the
    underlying C long is 32 bits (e.g. Windows), so the limit is divided by 10
    until the csv module accepts it. On platforms that accept sys.maxsize the
    result is identical to calling it directly.

    Returns:
        The limit that was installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 10


_set_max_csv_field_size()

# Torch device string used for the classifier: first CUDA GPU if present, else CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
| |
# Sequence-classification checkpoint behind the ADR severity prediction.
# HF_TOKEN is forwarded so the Hub download can authenticate when it is set.
_CLASSIFIER_REPO = "willwim/WillTest"

tokenizer = AutoTokenizer.from_pretrained(_CLASSIFIER_REPO, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(_CLASSIFIER_REPO, token=HF_TOKEN)
# Place the classifier on the selected torch device (GPU when available).
model = model.to(device)
|
|
| |
# Text-classification pipeline around the classifier; top_k=None makes it
# return a score for every label rather than only the best one.
pred = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=device,
)

# SHAP explainer over the pipeline; adr_predict uses it to render per-token
# attributions with shap.plots.text.
explainer = shap.Explainer(pred)
|
|
| |
# Biomedical named-entity recognition model. aggregation_strategy="simple"
# groups sub-word predictions into entity spans; adr_predict reads their
# 'start', 'end' and 'entity_group' fields.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# Run on the same device as the classification pipeline instead of always
# defaulting to CPU.
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
|
|
# Highlight colour per NER entity group; groups not listed use 'lightgray'.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _highlight_entities(text, entities, colors=None):
    """Return *text* as HTML with each entity span wrapped in a coloured <mark>.

    Args:
        text: The string the entities were extracted from.
        entities: Iterable of dicts with integer 'start'/'end' character
            offsets into *text* and an 'entity_group' label.
        colors: Mapping from entity group to CSS colour; defaults to
            _ENTITY_COLORS.

    Every slice of *text* is HTML-escaped because the result is rendered
    as raw HTML and *text* is user input.
    """
    import html

    colors = _ENTITY_COLORS if colors is None else colors
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda e: e['start']):
        start, end = entity['start'], entity['end']
        # A span starting inside one already emitted would duplicate text.
        if start < prev_end:
            continue
        color = colors.get(entity['entity_group'], 'lightgray')
        word = html.escape(text[start:end])
        parts.append(html.escape(text[prev_end:start]))
        parts.append(f"<mark style='background-color:{color};'>{word}</mark>")
        prev_end = end
    parts.append(html.escape(text[prev_end:]))
    return "".join(parts)


def adr_predict(x):
    """Classify *x* as a severe / non-severe adverse reaction and explain it.

    Args:
        x: Input text; converted with str() and lower-cased before being
            passed to the classifier, the SHAP explainer and the NER model.

    Returns:
        A tuple ``(label_output, local_plot, htext)``:
        label_output -- dict mapping "Severe Reaction" / "Non-severe Reaction"
            to softmax probabilities;
        local_plot -- HTML from shap.plots.text, or a placeholder paragraph
            if the SHAP explanation raised;
        htext -- HTML of the input with NER entities highlighted, or a
            placeholder paragraph if NER raised.
    """
    text_input = str(x).lower()

    # truncation=True cuts input to the tokenizer's model_max_length instead
    # of letting an over-long sequence fail inside the forward pass.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP and NER are best-effort: a failure is printed and replaced by a
    # placeholder so the prediction is still returned.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"

    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"

    # NOTE(review): assumes logit index 1 is "severe" and index 0 is
    # "non-severe"; confirm against model.config.id2label.
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext
|
|
def main(prob1):
    """Gradio callback: forward the textbox contents to adr_predict.

    Returns adr_predict's (label dict, SHAP HTML, NER HTML) tuple unchanged.
    """
    outputs = adr_predict(prob1)
    return outputs
|
|
| |
# Heading text (Markdown) and the disclaimer shown beneath it.
title = "Welcome to **ADR Detector** 🪐"
description1 = (
    "This app predicts severe or non-severe adverse reactions to medications. "
    "Do NOT use for medical diagnosis."
)
|
|
# Gradio UI: one free-text input feeding three outputs (class probabilities,
# SHAP attributions, highlighted NER entities). The browser-tab title is plain
# text, so the Markdown bold markers in `title` are stripped there; the
# on-page heading still renders them.
with gr.Blocks(title=title.replace("**", "")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")
    # Same order as the tuple returned by main / adr_predict.
    adr_outputs = [label, local_plot, htext]

    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=adr_outputs,
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=adr_outputs,
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

demo.launch()