"""Gradio demo: classify a clinical note as "IP" or "OP" with explanations.

Each note is turned into one flat feature vector:
    GatorTron embedding | Empath category scores | trigram counts | zero
    "reasoning" placeholder
which is scored by a MetaGNN on a single-node graph. Captum Integrated
Gradients attributions over that vector are used to rank the most influential
Empath categories and trigrams.
"""
import json
import subprocess
import sys
from collections import Counter
from operator import itemgetter

import gradio as gr
import numpy as np
import spacy
import torch
from captum.attr import IntegratedGradients
from empath import Empath
from torch_geometric.data import Data
from transformers import AutoTokenizer

# Model definitions (same as training)
from gatortron_gnn_captum import GatorTronEncoder, MetaGNN, GNNWrapper

# -----------------------
# Devices
# -----------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------
# Constants
# -----------------------
GATORTRON_NAME = "UFNLP/gatortron-base-2k"
MAX_LENGTH = 2000  # tokenizer truncation/padding length applied to every note
REASONING_DIM = 384  # width of the all-zero "reasoning" placeholder block
TOP_EMPATH = 5
TOP_TRIGRAMS = 10
CLASS_NAMES = ("IP", "OP")  # output index -> label, as used by predict()


# -----------------------
# Load NLP
# -----------------------
def _load_spacy(name="en_core_web_sm"):
    """Load a spaCy pipeline, downloading it first if it is not installed."""
    try:
        return spacy.load(name)
    except OSError:
        # spaCy raises OSError when the model package is missing. Install it
        # with the *current* interpreter (a bare "python" on PATH may belong
        # to another environment) and fail loudly if the download fails.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", name], check=True
        )
        return spacy.load(name)


nlp = _load_spacy()
empath = Empath()

# -----------------------
# Load Artifacts
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(GATORTRON_NAME)


def _load_json(path):
    """Read and parse a UTF-8 JSON artifact file."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


TRIGRAM_LIST = _load_json("artifacts/union_trigrams.json")
EMPATH_CATS = _load_json("artifacts/empath_cats.json")

sets = _load_json("artifacts/ip_op_trigram_sets.json")
IP_SET = set(sets["ip"])
OP_SET = set(sets["op"])

# -----------------------
# Models (weights from the training checkpoint)
# -----------------------
ckpt = torch.load("artifacts/best_model.pt", map_location=DEVICE)

gatortron = GatorTronEncoder(GATORTRON_NAME).to(DEVICE)
gatortron.load_state_dict(ckpt["gatortron"])
gatortron.eval()

gnn = MetaGNN(
    in_dim=ckpt["params"]["in_dim"],
    hidden_dim=ckpt["params"]["hidden_dim"],
    out_dim=2,
).to(DEVICE)
gnn.load_state_dict(ckpt["gnn"])
gnn.eval()


# -----------------------
# Helpers
# -----------------------
def _self_loop_edge():
    """Edge index for a one-node graph: a single 0 -> 0 self-loop."""
    return torch.tensor([[0], [0]], device=DEVICE)


def extract_trigrams(text):
    """Return consecutive lemma trigrams of *text*.

    The text is lower-cased and parsed with spaCy. Only alphabetic,
    non-stop-word tokens are kept, as lemmas. Each window of three consecutive
    kept lemmas is joined with single spaces. Fewer than three kept lemmas
    yields an empty list.
    """
    doc = nlp(text.lower())
    toks = [t.lemma_ for t in doc if t.is_alpha and not t.is_stop]
    return [" ".join(toks[i:i + 3]) for i in range(len(toks) - 2)]


def build_feature_vector(text):
    """Build the flat feature vector fed to the GNN for one note.

    Layout, in order:
        GatorTron embedding
        | one Empath score per entry of EMPATH_CATS
        | one trigram count per entry of TRIGRAM_LIST
        | REASONING_DIM zeros (placeholder)

    Returns:
        1-D numpy array.
    """
    inp = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        # Row 0 is this note's embedding (batch of one); the encoder's exact
        # output shape is defined in gatortron_gnn_captum — confirm there.
        gt = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]

    # Empath.analyze(..., normalize=True) returns None for text with no
    # tokens; treat that as all-zero scores instead of crashing on .get().
    emp = empath.analyze(text, normalize=True) or {}
    emp_vec = np.array([emp.get(c, 0.0) for c in EMPATH_CATS])

    # Count every trigram in one pass instead of list.count() per vocabulary
    # entry (which rescans the whole note once per trigram).
    counts = Counter(extract_trigrams(text))
    tri_vec = np.array([counts[t] for t in TRIGRAM_LIST])

    rsn = np.zeros(REASONING_DIM)  # reasoning placeholder
    return np.concatenate([gt, emp_vec, tri_vec, rsn])


def explain(x_tensor, target=0):
    """Return Integrated-Gradients attribution magnitudes for one note.

    Args:
        x_tensor: feature tensor of shape (1, n_features) on DEVICE, as
            built by predict().
        target: output index to attribute (0 -> "IP", 1 -> "OP"). Defaults
            to 0 for backward compatibility.

    Returns:
        1-D numpy array of |attribution| per feature, computed against an
        all-zero baseline.
    """
    wrapper = GNNWrapper(gnn, _self_loop_edge())
    ig = IntegratedGradients(wrapper)
    attr = ig.attribute(
        x_tensor,
        baselines=torch.zeros_like(x_tensor),
        target=target,
        internal_batch_size=16,
    )
    return attr.abs().cpu().numpy()[0]


def _top_k(names, scores, k):
    """Return the *k* (name, score) pairs with the largest scores, descending.

    Scores are converted to plain Python floats so the result is
    JSON-serializable.
    """
    ranked = sorted(zip(names, scores), key=itemgetter(1), reverse=True)
    return [(name, float(score)) for name, score in ranked[:k]]


# -----------------------
# Inference Function
# -----------------------
def predict(note):
    """Classify *note* and explain the decision.

    Returns:
        (label, p_ip, p_op, top_empath, top_trigrams), where:
            label is "IP" or "OP";
            p_ip and p_op are the class probabilities for indices 0 and 1;
            top_empath is the TOP_EMPATH (category, |attribution|) pairs;
            top_trigrams is the TOP_TRIGRAMS (trigram, |attribution|) pairs.

    Raises:
        gr.Error: if *note* is empty or whitespace only.
    """
    if not note or not note.strip():
        raise gr.Error("Please enter a clinical note.")

    x = build_feature_vector(note)
    x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    data = Data(x=x_tensor, edge_index=_self_loop_edge())

    with torch.no_grad():
        out = gnn(data)
    # Assumes MetaGNN emits log-probabilities (e.g. log_softmax), hence
    # exp() — confirm against gatortron_gnn_captum.
    probs = torch.exp(out)[0].cpu().numpy()
    # Ties resolve to "OP".
    pred_idx = 0 if probs[0] > probs[1] else 1
    pred = CLASS_NAMES[pred_idx]

    # Attribute the predicted class, so the listed features explain this
    # particular decision rather than always the "IP" output.
    attr = explain(x_tensor, target=pred_idx)

    # Locate the Empath and trigram segments from the end of the vector, so
    # the GatorTron embedding width does not need to be known here.
    emp_start = len(x) - (len(EMPATH_CATS) + len(TRIGRAM_LIST) + REASONING_DIM)
    tri_start = emp_start + len(EMPATH_CATS)
    tri_end = tri_start + len(TRIGRAM_LIST)

    top_empath = _top_k(EMPATH_CATS, attr[emp_start:tri_start], TOP_EMPATH)
    top_trigrams = _top_k(TRIGRAM_LIST, attr[tri_start:tri_end], TOP_TRIGRAMS)

    return pred, float(probs[0]), float(probs[1]), top_empath, top_trigrams


# -----------------------
# Gradio UI
# -----------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=12, label="Clinical Note"),
    outputs=[
        gr.Label(label="Prediction (IP / OP)"),
        gr.Number(label="IP Probability"),
        gr.Number(label="OP Probability"),
        gr.JSON(label="Top 5 Empath Categories"),
        gr.JSON(label="Top 10 Trigrams"),
    ],
    title="Clinical IP / OP Classifier with Explainability",
    description="GatorTron + GNN + Captum interpretability",
)

if __name__ == "__main__":
    demo.launch()