dep-dev commited on
Commit
ce4c4ee
Β·
verified Β·
1 Parent(s): 950277f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import json
4
+ import gradio as gr
5
+ from transformers import AutoTokenizer
6
+ from captum.attr import IntegratedGradients
7
+ from torch_geometric.data import Data
8
+ from empath import Empath
9
+ import spacy
10
+
11
+ # -----------------------
12
+ # Devices
13
+ # -----------------------
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # -----------------------
17
+ # Load NLP
18
+ # -----------------------
19
+ try:
20
+ nlp = spacy.load("en_core_web_sm")
21
+ except:
22
+ import os
23
+ os.system("python -m spacy download en_core_web_sm")
24
+ nlp = spacy.load("en_core_web_sm")
25
+
26
+ empath = Empath()
27
+
28
+ # -----------------------
29
+ # Load Artifacts
30
+ # -----------------------
31
+ tokenizer = AutoTokenizer.from_pretrained("UFNLP/gatortron-base-2k")
32
+
33
+ with open("artifacts/union_trigrams.json") as f:
34
+ TRIGRAM_LIST = json.load(f)
35
+
36
+ with open("artifacts/empath_cats.json") as f:
37
+ EMPATH_CATS = json.load(f)
38
+
39
+ with open("artifacts/ip_op_trigram_sets.json") as f:
40
+ sets = json.load(f)
41
+ IP_SET = set(sets["ip"])
42
+ OP_SET = set(sets["op"])
43
+
44
+ # -----------------------
45
+ # Model Definitions (same as training)
46
+ # -----------------------
47
+ from gatortron_gnn_captum import GatorTronEncoder, MetaGNN, GNNWrapper
48
+
49
+ ckpt = torch.load("artifacts/best_model.pt", map_location=DEVICE)
50
+
51
+ gatortron = GatorTronEncoder("UFNLP/gatortron-base-2k").to(DEVICE)
52
+ gatortron.load_state_dict(ckpt["gatortron"])
53
+ gatortron.eval()
54
+
55
+ gnn = MetaGNN(
56
+ in_dim=ckpt["params"]["in_dim"],
57
+ hidden_dim=ckpt["params"]["hidden_dim"],
58
+ out_dim=2
59
+ ).to(DEVICE)
60
+ gnn.load_state_dict(ckpt["gnn"])
61
+ gnn.eval()
62
+
63
+ # -----------------------
64
+ # Helpers
65
+ # -----------------------
66
+ def extract_trigrams(text):
67
+ doc = nlp(text.lower())
68
+ toks = [t.lemma_ for t in doc if t.is_alpha and not t.is_stop]
69
+ return [" ".join(toks[i:i+3]) for i in range(len(toks)-2)]
70
+
71
+ def build_feature_vector(text):
72
+ inp = tokenizer(
73
+ text,
74
+ truncation=True,
75
+ padding="max_length",
76
+ max_length=2000,
77
+ return_tensors="pt"
78
+ ).to(DEVICE)
79
+
80
+ with torch.no_grad():
81
+ gt = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]
82
+
83
+ emp = empath.analyze(text, normalize=True)
84
+ emp_vec = np.array([emp.get(c, 0.0) for c in EMPATH_CATS])
85
+
86
+ trigs = extract_trigrams(text)
87
+ tri_vec = np.array([trigs.count(t) for t in TRIGRAM_LIST])
88
+
89
+ rsn = np.zeros(384) # reasoning placeholder
90
+
91
+ return np.concatenate([gt, emp_vec, tri_vec, rsn])
92
+
93
+ def explain(x_tensor):
94
+ dummy_edge = torch.tensor([[0], [0]]).to(DEVICE)
95
+ wrapper = GNNWrapper(gnn, dummy_edge)
96
+ ig = IntegratedGradients(wrapper)
97
+
98
+ attr = ig.attribute(
99
+ x_tensor,
100
+ baselines=torch.zeros_like(x_tensor),
101
+ target=0,
102
+ internal_batch_size=16
103
+ )
104
+
105
+ return attr.abs().cpu().numpy()[0]
106
+
107
+ # -----------------------
108
+ # Inference Function
109
+ # -----------------------
110
+ def predict(note):
111
+ x = build_feature_vector(note)
112
+ x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(DEVICE)
113
+
114
+ dummy_edge = torch.tensor([[0], [0]]).to(DEVICE)
115
+ data = Data(x=x_tensor, edge_index=dummy_edge)
116
+
117
+ with torch.no_grad():
118
+ out = gnn(data)
119
+ probs = torch.exp(out)[0].cpu().numpy()
120
+
121
+ pred = "IP" if probs[0] > probs[1] else "OP"
122
+
123
+ attr = explain(x_tensor)
124
+
125
+ # ---- Empath ----
126
+ emp_start = len(x) - (len(EMPATH_CATS) + len(TRIGRAM_LIST) + 384)
127
+ emp_attr = attr[emp_start:emp_start+len(EMPATH_CATS)]
128
+
129
+ top_empath = sorted(
130
+ zip(EMPATH_CATS, emp_attr),
131
+ key=lambda x: x[1],
132
+ reverse=True
133
+ )[:5]
134
+
135
+ # ---- Trigrams ----
136
+ tri_start = emp_start + len(EMPATH_CATS)
137
+ tri_attr = attr[tri_start:tri_start+len(TRIGRAM_LIST)]
138
+
139
+ top_trigrams = sorted(
140
+ zip(TRIGRAM_LIST, tri_attr),
141
+ key=lambda x: x[1],
142
+ reverse=True
143
+ )[:10]
144
+
145
+ return (
146
+ pred,
147
+ float(probs[0]),
148
+ float(probs[1]),
149
+ top_empath,
150
+ top_trigrams
151
+ )
152
+
153
+ # -----------------------
154
+ # Gradio UI
155
+ # -----------------------
156
+ demo = gr.Interface(
157
+ fn=predict,
158
+ inputs=gr.Textbox(lines=12, label="Clinical Note"),
159
+ outputs=[
160
+ gr.Label(label="Prediction (IP / OP)"),
161
+ gr.Number(label="IP Probability"),
162
+ gr.Number(label="OP Probability"),
163
+ gr.JSON(label="Top 5 Empath Categories"),
164
+ gr.JSON(label="Top 10 Trigrams"),
165
+ ],
166
+ title="Clinical IP / OP Classifier with Explainability",
167
+ description="GatorTron + GNN + Captum interpretability"
168
+ )
169
+
170
+ demo.launch()