# CNC / app.py
# dep-dev's picture
# Create app.py
# ce4c4ee verified
import torch
import numpy as np
import json
import gradio as gr
from transformers import AutoTokenizer
from captum.attr import IntegratedGradients
from torch_geometric.data import Data
from empath import Empath
import spacy
# -----------------------
# Devices
# -----------------------
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
# -----------------------
# Load NLP
# -----------------------
# Load the small English spaCy pipeline, installing it on first use.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spaCy reports a missing model package as OSError. Catching only that
    # keeps unrelated failures (and Ctrl-C) from being silently swallowed.
    import subprocess
    import sys

    # Use the interpreter that is running this app rather than whatever
    # "python" resolves to on PATH, and raise if the download fails instead
    # of continuing to a confusing second load error.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# Empath lexicon client; queried per note in build_feature_vector.
empath = Empath()
# -----------------------
# Load Artifacts
# -----------------------
# Tokenizer for the "UFNLP/gatortron-base-2k" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("UFNLP/gatortron-base-2k")

# JSON is UTF-8 by specification; pin the encoding so that loading does not
# depend on the platform's default locale.

# Ordered trigram vocabulary: one count feature per entry.
with open("artifacts/union_trigrams.json", encoding="utf-8") as f:
    TRIGRAM_LIST = json.load(f)

# Ordered Empath category names: one score feature per entry.
with open("artifacts/empath_cats.json", encoding="utf-8") as f:
    EMPATH_CATS = json.load(f)

# Trigram sets stored under the "ip" and "op" keys.
with open("artifacts/ip_op_trigram_sets.json", encoding="utf-8") as f:
    sets = json.load(f)

IP_SET = set(sets["ip"])
OP_SET = set(sets["op"])
# -----------------------
# Model Definitions (same as training)
# -----------------------
from gatortron_gnn_captum import GNNWrapper, GatorTronEncoder, MetaGNN

# The checkpoint holds the encoder weights ("gatortron"), the GNN weights
# ("gnn") and the GNN constructor sizes ("params").
ckpt = torch.load("artifacts/best_model.pt", map_location=DEVICE)

# Text encoder: build, move to the target device, restore weights, and
# switch to inference mode.
gatortron = GatorTronEncoder("UFNLP/gatortron-base-2k")
gatortron = gatortron.to(DEVICE)
gatortron.load_state_dict(ckpt["gatortron"])
gatortron.eval()

# Two-class GNN head, sized from the values saved in the checkpoint.
_gnn_params = ckpt["params"]
gnn = MetaGNN(
    in_dim=_gnn_params["in_dim"],
    hidden_dim=_gnn_params["hidden_dim"],
    out_dim=2,
)
gnn = gnn.to(DEVICE)
gnn.load_state_dict(ckpt["gnn"])
gnn.eval()
# -----------------------
# Helpers
# -----------------------
def extract_trigrams(text):
    """Return the lemma trigrams of *text* as space-joined strings.

    The text is lower-cased and passed through the module-level spaCy
    pipeline. Only alphabetic, non-stop-word tokens are kept, each
    represented by its lemma. Every run of three consecutive kept lemmas
    yields one trigram, so fewer than three kept tokens give an empty list.
    """
    lemmas = []
    for token in nlp(text.lower()):
        if token.is_alpha and not token.is_stop:
            lemmas.append(token.lemma_)

    # Zipping the list against itself shifted by one and two positions walks
    # a sliding window of width three.
    return [" ".join(window) for window in zip(lemmas, lemmas[1:], lemmas[2:])]
def build_feature_vector(text):
    """Build the concatenated feature vector for a single note.

    The result is laid out as::

        [encoder output | Empath scores | trigram counts | 384 zeros]

    where the Empath segment follows the order of ``EMPATH_CATS`` and the
    trigram segment follows the order of ``TRIGRAM_LIST``.

    Args:
        text: Raw note text.

    Returns:
        The concatenation of the four segments as a NumPy array.
    """
    from collections import Counter

    reasoning_dim = 384  # width of the all-zero reasoning placeholder

    encoded = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=2000,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        encoder_out = gatortron(encoded["input_ids"], encoded["attention_mask"])
    # First (only) row of the batch; presumably a 1-D embedding, since it is
    # concatenated with 1-D vectors below -- confirm against GatorTronEncoder.
    gt = encoder_out.cpu().numpy()[0]

    # With normalize=True Empath returns None (not a dict) when the text has
    # no tokens; treat that as "no category fired" instead of crashing on
    # None.get for an empty note.
    emp = empath.analyze(text, normalize=True) or {}
    emp_vec = np.array([emp.get(c, 0.0) for c in EMPATH_CATS])

    # Count every trigram once up front rather than rescanning the note's
    # trigram list for each vocabulary entry. Assumes vocabulary entries are
    # strings (hashable).
    trigram_counts = Counter(extract_trigrams(text))
    tri_vec = np.array([trigram_counts[t] for t in TRIGRAM_LIST])

    rsn = np.zeros(reasoning_dim)  # reasoning placeholder
    return np.concatenate([gt, emp_vec, tri_vec, rsn])
def explain(x_tensor, target=0):
    """Return absolute Integrated Gradients attributions for one input row.

    Args:
        x_tensor: Feature tensor with a leading batch dimension; only the
            first row of the attribution result is returned.
        target: Output index to attribute. Defaults to 0, the index that
            was previously hard-coded.

    Returns:
        NumPy array of absolute attribution values, one per input feature
        of the first row.
    """
    # Single self-loop edge on node 0, built directly on the model's device.
    dummy_edge = torch.tensor([[0], [0]], device=DEVICE)
    wrapper = GNNWrapper(gnn, dummy_edge)
    ig = IntegratedGradients(wrapper)
    attr = ig.attribute(
        x_tensor,
        baselines=torch.zeros_like(x_tensor),  # all-zero reference input
        target=target,
        internal_batch_size=16,
    )
    # Only magnitudes are reported. detach() guards the NumPy conversion in
    # case the attribution tensor still tracks gradients.
    return attr.detach().abs().cpu().numpy()[0]
# -----------------------
# Inference Function
# -----------------------
def _top_attributions(names, scores, k):
    """Return the *k* highest-scoring ``(name, score)`` pairs, best first.

    Scores are converted to plain Python floats so the result does not carry
    NumPy scalar types into the UI layer.
    """
    ranked = sorted(
        ((name, float(score)) for name, score in zip(names, scores)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]


def predict(note):
    """Classify a note as IP or OP and report the most influential features.

    Args:
        note: Note text entered in the UI.

    Returns:
        A 5-tuple ``(label, ip_prob, op_prob, top_empath, top_trigrams)``:
        ``label`` is ``"IP"`` when the first probability is strictly larger
        than the second, otherwise ``"OP"``; the probabilities are Python
        floats; ``top_empath`` holds the 5 and ``top_trigrams`` the 10
        highest ``(name, attribution)`` pairs in descending order.
    """
    reasoning_dim = 384  # width of the trailing placeholder segment

    x = build_feature_vector(note)
    x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    dummy_edge = torch.tensor([[0], [0]]).to(DEVICE)
    data = Data(x=x_tensor, edge_index=dummy_edge)
    with torch.no_grad():
        out = gnn(data)
    # NOTE(review): exp() presumes the GNN emits log-probabilities
    # (e.g. log_softmax) -- confirm against MetaGNN.
    probs = torch.exp(out)[0].cpu().numpy()
    pred = "IP" if probs[0] > probs[1] else "OP"

    attr = explain(x_tensor)

    # The vector ends with [Empath | trigrams | placeholder], so the segment
    # offsets are derived from the tail; whatever precedes them is the
    # encoder output.
    n_emp = len(EMPATH_CATS)
    n_tri = len(TRIGRAM_LIST)

    # ---- Empath ----
    emp_start = len(x) - (n_emp + n_tri + reasoning_dim)
    emp_attr = attr[emp_start:emp_start + n_emp]
    top_empath = _top_attributions(EMPATH_CATS, emp_attr, 5)

    # ---- Trigrams ----
    tri_start = emp_start + n_emp
    tri_attr = attr[tri_start:tri_start + n_tri]
    top_trigrams = _top_attributions(TRIGRAM_LIST, tri_attr, 10)

    return (
        pred,
        float(probs[0]),
        float(probs[1]),
        top_empath,
        top_trigrams,
    )
# -----------------------
# Gradio UI
# -----------------------
# Single multi-line text box for the note to classify.
_note_input = gr.Textbox(lines=12, label="Clinical Note")

# One output component per element of the tuple returned by predict().
_result_outputs = [
    gr.Label(label="Prediction (IP / OP)"),
    gr.Number(label="IP Probability"),
    gr.Number(label="OP Probability"),
    gr.JSON(label="Top 5 Empath Categories"),
    gr.JSON(label="Top 10 Trigrams"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_note_input,
    outputs=_result_outputs,
    title="Clinical IP / OP Classifier with Explainability",
    description="GatorTron + GNN + Captum interpretability",
)

# Start the web server as soon as the module is executed.
demo.launch()