CNC-IPvsOP / app.py
dep-dev's picture
app script
f56a590 verified
Raw
History Blame Contribute Delete
6.62 kB
import os
from collections import Counter

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import gradio as gr
import spacy
from empath import Empath
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from huggingface_hub import hf_hub_download
# --- Device Config ---
# The free Space tier has no GPU, so everything runs on the CPU, and torch's
# intra-op parallelism is capped at two threads.
DEVICE = torch.device(type="cpu")
torch.set_num_threads(2)
# --- Restore exact architectures ---
class GatorTronEncoder(nn.Module):
    """Wrap a pretrained Hugging Face encoder and expose its first-token embedding.

    The backbone must stay stored as ``self.model``: the module is later filled
    with ``load_state_dict`` (strict by default), so its parameter names have to
    match the ``model.*`` keys in the saved checkpoint.
    """

    def __init__(self, model_id):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_id)
        self.model = backbone
        # Width of the returned embedding; read from the backbone's config.
        self.hidden_size = backbone.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at sequence position 0 for each batch row."""
        hidden = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden[:, 0, :]
class MetaGNN(nn.Module):
    """A linear projection followed by two GraphSAGE layers, emitting per-node log-probabilities.

    Submodule names (``lin``, ``conv1``, ``conv2``) and their registration order
    must match the saved ``"gnn"`` state dict, which is loaded strictly.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)
        self.conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, data):
        """Run the network on a PyG ``Data`` object; returns log-softmax over dim 1."""
        edges = data.edge_index
        h = self.lin(data.x).relu()
        h = self.conv1(h, edges).relu()
        logits = self.conv2(h, edges)
        return logits.log_softmax(dim=1)
# --- NLP Feature Extractors ---
# Both analyzers are built once at import time and reused for every request.
empath_analyzer = Empath()
nlp = spacy.load("en_core_web_sm")
def preprocess_text(text):
    """Lowercase ``text``, parse it with spaCy, and return the lemmas of content tokens.

    A token is discarded when it is a stop word, punctuation, number-like,
    part of a PERSON entity, or not purely alphabetic.
    """
    # NOTE(review): NER runs on lowercased text, which likely lowers PERSON
    # recall; kept as-is, presumably to match the training preprocessing.
    def _keep(token):
        return not (
            token.is_stop
            or token.is_punct
            or token.like_num
            or token.ent_type_ == "PERSON"
            or not token.is_alpha
        )

    return [token.lemma_ for token in nlp(text.lower()) if _keep(token)]
def extract_trigrams(text):
    """Return each contiguous window of three lemmas from ``preprocess_text(text)``.

    Every window is joined with single spaces; fewer than three lemmas yields [].
    """
    words = preprocess_text(text)
    return [" ".join(window) for window in zip(words, words[1:], words[2:])]
# --- Global asset loading (runs eagerly at import time) ---
print("Downloading model weights from Hugging Face Hub...")
REPO_ID, FILENAME = "dep-dev/CNC-weights", "best_model.pt"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
print("Loading checkpoint into memory...")
# NOTE(review): torch.load unpickles the file. On torch versions where
# weights_only defaults to False this can execute arbitrary code, so only load
# checkpoints from a trusted repo; consider weights_only=True if the checkpoint
# holds only tensors and plain containers.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# --- Reconstruct Missing Metadata On-the-Fly ---
print("Reconstructing feature metadata...")
# Sorted category names fix the column order of the Empath feature block.
empath_cats = sorted(empath_analyzer.cats)


def _read_trigram_set(csv_path):
    """Return the distinct values of the ``Trigram`` column of ``csv_path``."""
    return set(pd.read_csv(csv_path)["Trigram"])


# IP- and OP-specific trigram vocabularies uploaded alongside the Space.
ip_trigrams = _read_trigram_set("ip_specific_trigrams_masked_train_new_bothmasked.csv")
op_trigrams = _read_trigram_set("op_specific_trigrams_masked_train_new_bothmasked.csv")
# The sorted union fixes the column order of the trigram-count feature block.
trigram_list = sorted(ip_trigrams.union(op_trigrams))
# Feature blocks in concatenation order; the original model used all four.
setup_keys = ["gatortron", "empath", "trigrams", "reasoning"]
hidden_dim = checkpoint["params"]["hidden_dim"]
print("Loading GatorTron...")
GATORTRON_ID = "UFNLP/gatortron-base-2k"
tokenizer = AutoTokenizer.from_pretrained(GATORTRON_ID)
gatortron = GatorTronEncoder(GATORTRON_ID).to(DEVICE)
# The checkpoint stores the encoder weights under the legacy key "gatortron"
# rather than "gatortron_state_dict".
gatortron.load_state_dict(checkpoint["gatortron"])
gatortron.eval()
# The GNN input width is read from the checkpoint itself: nn.Linear stores its
# weight as (out_features, in_features), so column count == input width. This
# guarantees the saved weights load, whatever the feature recipe below yields.
total_in_dim = checkpoint["gnn"]["lin.weight"].shape[1]

# Width implied by the feature recipe in setup_keys (GatorTron first-token
# embedding, Empath categories, trigram counts, 384-wide reasoning placeholder).
_feature_widths = {
    "gatortron": gatortron.hidden_size,
    "empath": len(empath_cats),
    "trigrams": len(trigram_list),
    "reasoning": 384,
}
_expected_in_dim = sum(_feature_widths[k] for k in setup_keys if k in _feature_widths)
if _expected_in_dim != total_in_dim:
    # Inference pads/truncates features to total_in_dim, so a mismatch would
    # silently misalign columns; make it visible in the startup logs.
    print(
        f"WARNING: feature recipe yields {_expected_in_dim} dims but the checkpoint "
        f"GNN expects {total_in_dim}; features will be padded/truncated."
    )
print("Loading GNN...")
# Two output classes: Inpatient (IP) and Outpatient (OP).
gnn = MetaGNN(total_in_dim, hidden_dim, 2).to(DEVICE)
# The checkpoint stores the GNN weights under the legacy key "gnn"
# rather than "gnn_state_dict".
gnn.load_state_dict(checkpoint["gnn"])
gnn.eval()
# --- Inference Logic Pipeline ---
def _fit_to_dim(vec, dim):
    """Zero-pad or truncate a 1-D feature vector so it has exactly ``dim`` entries."""
    vec = np.asarray(vec, dtype=np.float32).ravel()
    if vec.shape[0] < dim:
        return np.pad(vec, (0, dim - vec.shape[0]), mode="constant")
    return vec[:dim]


def _build_feature_vector(note_text):
    """Concatenate the per-note feature blocks in ``setup_keys`` order (1-D array).

    Must be called under ``torch.no_grad()`` since it runs the GatorTron encoder.
    """
    # 1. Text encoding: first-token embedding from GatorTron.
    # NOTE(review): max_length=512 although the model id says "2k"; presumably
    # matches training -- confirm before changing.
    inp = tokenizer(
        [note_text],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    ).to(DEVICE)
    gt_emb = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]

    # 2. Empath category scores (all zeros when analyze() returns nothing).
    emp = empath_analyzer.analyze(note_text, normalize=True) or {}
    emp_vec = np.array([emp.get(c, 0.0) for c in empath_cats], dtype=np.float32)

    # 3. Trigram counts: one Counter pass instead of list.count per vocab entry.
    trig_counts = Counter(extract_trigrams(note_text))
    tri_vec = np.array([trig_counts[t] for t in trigram_list], dtype=np.float32)

    # 4. Reasoning features are not available at inference; zero placeholder.
    rsn_vec = np.zeros(384, dtype=np.float32)

    blocks = {
        "gatortron": gt_emb,
        "empath": emp_vec,
        "trigrams": tri_vec,
        "reasoning": rsn_vec,
    }
    # Unknown keys are skipped, as in the original if/elif dispatch.
    return np.concatenate([blocks[k] for k in setup_keys if k in blocks])


def predict_clinical_note(note_text):
    """Classify one clinical note as Inpatient (IP) or Outpatient (OP).

    Args:
        note_text: Raw note text from the Gradio textbox (may be empty or None).

    Returns:
        A dict with keys "Prediction Decision", "Inpatient Probability (IP)" and
        "Outpatient Probability (OP)". Probabilities are percentage strings, or
        None when the input is empty and no prediction was made.
    """
    if not note_text or not note_text.strip():
        # Same dict shape as a real prediction so the single gr.JSON output
        # always renders consistently.
        return {
            "Prediction Decision": "Please input a valid clinical note text.",
            "Inpatient Probability (IP)": None,
            "Outpatient Probability (OP)": None,
        }

    with torch.no_grad():
        # Fit the 1-D vector to the checkpoint's input width *before* adding
        # the node axis, giving x the shape (1, total_in_dim).
        features = _fit_to_dim(_build_feature_vector(note_text), total_in_dim)
        x = torch.tensor(features[np.newaxis, :], dtype=torch.float32).to(DEVICE)
        # A single node with a self-loop gives SAGEConv an edge to aggregate over.
        edge_index = torch.tensor([[0], [0]], dtype=torch.long).to(DEVICE)
        out = gnn(Data(x=x, edge_index=edge_index))
        probs = torch.exp(out).cpu().numpy()[0]

    labels = ["Inpatient (IP)", "Outpatient (OP)"]
    prediction = labels[int(np.argmax(probs))]
    return {
        "Prediction Decision": prediction,
        "Inpatient Probability (IP)": f"{probs[0] * 100:.2f}%",
        "Outpatient Probability (OP)": f"{probs[1] * 100:.2f}%",
    }
# --- Interface Setup ---
# Single text input -> single JSON output holding the dict returned by
# predict_clinical_note.
interface = gr.Interface(
    fn=predict_clinical_note,
    inputs=gr.Textbox(
        lines=8,
        placeholder="Enter anonymous clinical or progress notes here...",
        label="Clinical Patient Document",
    ),
    outputs=gr.JSON(label="Prediction and Class Probabilities"),
    title="Clinical Document Target Assignment Classifier",
    description=(
        "Classifies an anonymized clinical note as Inpatient (IP) or Outpatient (OP) "
        "using GatorTron text embeddings, Empath lexical categories and IP/OP-specific "
        "trigram counts, combined by a graph neural network."
    ),
)

if __name__ == "__main__":
    interface.launch()