import heapq
import json
import subprocess
import sys
from collections import Counter

import gradio as gr
import numpy as np
import spacy
import torch
from captum.attr import IntegratedGradients
from empath import Empath
from torch_geometric.data import Data
from transformers import AutoTokenizer
|
|
| |
| |
| |
# Single device used for every tensor and model in this app: CUDA when a
# GPU is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
|
| |
| |
| |
# spaCy English pipeline used by `extract_trigrams` for lemmatisation.
# spacy.load raises OSError when the package is not installed. In that case
# it is downloaded with the *running* interpreter (sys.executable), and then
# loading is retried. check=True makes a failed download raise here instead
# of surfacing later as a confusing second OSError. Other exceptions
# (including KeyboardInterrupt) are no longer swallowed.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# Lexical-category analyser that produces the Empath section of the
# feature vector in `build_feature_vector`.
empath = Empath()
|
|
| |
| |
| |
# Tokenizer for the GatorTron checkpoint; `build_feature_vector` uses it to
# turn a note into input_ids / attention_mask for the encoder.
tokenizer = AutoTokenizer.from_pretrained("UFNLP/gatortron-base-2k")

# Feature vocabularies loaded from artifacts/. Their element order fixes the
# column order of the Empath and trigram sections of the feature vector.
# NOTE(review): presumably these must be the same files the checkpoint was
# trained with; confirm.
# JSON is UTF-8 by specification, so decode it explicitly rather than with
# the platform's locale default.
with open("artifacts/union_trigrams.json", encoding="utf-8") as f:
    TRIGRAM_LIST = json.load(f)

with open("artifacts/empath_cats.json", encoding="utf-8") as f:
    EMPATH_CATS = json.load(f)

# Trigram collections stored under the keys "ip" and "op", turned into sets
# for membership tests.
with open("artifacts/ip_op_trigram_sets.json", encoding="utf-8") as f:
    sets = json.load(f)
IP_SET = set(sets["ip"])
OP_SET = set(sets["op"])
|
|
| |
| |
| |
from gatortron_gnn_captum import GatorTronEncoder, MetaGNN, GNNWrapper

# Training checkpoint. Tensors are remapped onto DEVICE while loading.
# NOTE(review): torch.load unpickles the file, so it should only ever point
# at a trusted, locally produced checkpoint.
ckpt = torch.load("artifacts/best_model.pt", map_location=DEVICE)


def _restore(module, state_key):
    """Move *module* to DEVICE, load ``ckpt[state_key]`` into it, set eval mode.

    Returns the module, so construction and restoration read as one
    expression.
    """
    module = module.to(DEVICE)
    module.load_state_dict(ckpt[state_key])
    module.eval()
    return module


# Text encoder restored from the checkpoint's "gatortron" entry.
gatortron = _restore(GatorTronEncoder("UFNLP/gatortron-base-2k"), "gatortron")

# Graph classifier. Its widths come from the checkpoint's "params" entry, and
# it has two outputs (read by `predict` as IP / OP).
gnn = _restore(
    MetaGNN(
        in_dim=ckpt["params"]["in_dim"],
        hidden_dim=ckpt["params"]["hidden_dim"],
        out_dim=2,
    ),
    "gnn",
)
|
|
| |
| |
| |
def extract_trigrams(text):
    """Return the lemma trigrams of *text* as space-joined strings.

    The text is lower-cased and parsed with the spaCy pipeline. Only
    alphabetic tokens that are not stop words are kept, as their lemmas.
    Each run of three consecutive kept lemmas yields one trigram, in
    document order. Fewer than three kept lemmas gives an empty list.
    """
    lemmas = [
        token.lemma_
        for token in nlp(text.lower())
        if token.is_alpha and not token.is_stop
    ]
    # Three staggered views of the same list line up each sliding window.
    return [" ".join(window) for window in zip(lemmas, lemmas[1:], lemmas[2:])]
|
|
def build_feature_vector(text):
    """Build the flat feature vector the GNN classifier consumes for one note.

    The result is the concatenation, in this order, of:
      1. the first row of the GatorTron encoder output for the note
         (tokenised, truncated/padded to 2000 tokens);
      2. one normalised Empath score per category in EMPATH_CATS
         (0.0 when Empath returns no score for it);
      3. one occurrence count per trigram in TRIGRAM_LIST;
      4. a block of 384 zeros.
    `predict` relies on this order and on the 384-wide tail to find the
    Empath and trigram slices.

    Args:
        text: raw note text.

    Returns:
        1-D numpy array holding the concatenated features.
    """
    inp = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=2000,
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: no autograd graph is needed for the encoder.
    with torch.no_grad():
        gt = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]

    emp = empath.analyze(text, normalize=True)
    emp_vec = np.array([emp.get(c, 0.0) for c in EMPATH_CATS])

    # Count every trigram of the note once, then look up each vocabulary
    # entry: O(len(trigrams) + len(TRIGRAM_LIST)) instead of rescanning the
    # note's trigram list once per vocabulary entry. Missing keys count as 0.
    trig_counts = Counter(extract_trigrams(text))
    tri_vec = np.array([trig_counts[t] for t in TRIGRAM_LIST])

    # NOTE(review): constant zero block named "rsn"; presumably a placeholder
    # for a feature group not computed at inference time. Confirm.
    rsn = np.zeros(384)

    return np.concatenate([gt, emp_vec, tri_vec, rsn])
|
|
def explain(x_tensor, target=0):
    """Return Integrated-Gradients attribution magnitudes for one feature row.

    The GNN is wrapped (with a single self-loop edge 0 -> 0) so Captum can
    treat it as a plain function of the node features. The attributions use
    an all-zero baseline.

    Args:
        x_tensor: feature tensor on DEVICE; `predict` passes shape
            (1, n_features).
        target: model output index to attribute. Defaults to 0, the column
            `predict` reads as the IP probability.

    Returns:
        numpy array of absolute attributions for the first row of x_tensor.
    """
    # One-node graph: the only edge is a self-loop on node 0.
    dummy_edge = torch.tensor([[0], [0]], device=DEVICE)
    wrapper = GNNWrapper(gnn, dummy_edge)
    ig = IntegratedGradients(wrapper)

    attr = ig.attribute(
        x_tensor,
        baselines=torch.zeros_like(x_tensor),
        target=target,
        internal_batch_size=16,
    )

    # Only magnitudes are reported. detach() guarantees .numpy() is legal
    # even if the returned tensor is still attached to an autograd graph.
    return attr.detach().abs().cpu().numpy()[0]
|
|
| |
| |
| |
def _top_k(names, scores, k):
    """Return the k (name, score) pairs with the largest scores, best first.

    Ties keep their original order, as with ``sorted(..., reverse=True)``.
    Scores are converted to built-in floats so the result is plain-JSON
    serialisable (numpy float32 is not).
    """
    best = heapq.nlargest(k, zip(names, scores), key=lambda pair: pair[1])
    return [(name, float(score)) for name, score in best]


def predict(note):
    """Classify a clinical note as IP or OP and explain the decision.

    Args:
        note: raw note text from the Gradio textbox.

    Returns:
        A 5-tuple ``(label, p_ip, p_op, top_empath, top_trigrams)``:
        - label is "IP" or "OP";
        - p_ip and p_op are Python floats;
        - top_empath holds the 5 Empath categories with the largest
          attribution magnitudes;
        - top_trigrams holds the 10 trigrams with the largest attribution
          magnitudes.
        Both lists contain (name, float) pairs.
    """
    x = build_feature_vector(note)
    x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    # The classifier runs on a one-node graph with a single self-loop edge.
    dummy_edge = torch.tensor([[0], [0]]).to(DEVICE)
    data = Data(x=x_tensor, edge_index=dummy_edge)

    with torch.no_grad():
        out = gnn(data)
    # exp() of the output is used as probabilities, i.e. the GNN output is
    # treated as log-probabilities. Column 0 is IP and column 1 is OP.
    probs = torch.exp(out)[0].cpu().numpy()

    pred = "IP" if probs[0] > probs[1] else "OP"  # an exact tie yields "OP"

    attr = explain(x_tensor)

    # Feature layout produced by build_feature_vector:
    #   [encoder output | Empath scores | trigram counts | 384 zeros]
    # The encoder width is whatever is left at the front.
    n_emp = len(EMPATH_CATS)
    n_tri = len(TRIGRAM_LIST)
    emp_start = len(x) - (n_emp + n_tri + 384)
    tri_start = emp_start + n_emp

    top_empath = _top_k(EMPATH_CATS, attr[emp_start:tri_start], 5)
    top_trigrams = _top_k(TRIGRAM_LIST, attr[tri_start:tri_start + n_tri], 10)

    return pred, float(probs[0]), float(probs[1]), top_empath, top_trigrams
|
|
| |
| |
| |
# Gradio UI: one textbox in, and the five values returned by `predict` out
# (label, IP probability, OP probability, top Empath categories, top trigrams).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=12, label="Clinical Note"),
    outputs=[
        gr.Label(label="Prediction (IP / OP)"),
        gr.Number(label="IP Probability"),
        gr.Number(label="OP Probability"),
        gr.JSON(label="Top 5 Empath Categories"),
        gr.JSON(label="Top 10 Trigrams"),
    ],
    title="Clinical IP / OP Classifier with Explainability",
    description="GatorTron + GNN + Captum interpretability",
)

# Start the server only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|