"""Gradio app that classifies a clinical note as Inpatient (IP) or Outpatient (OP).

Pipeline, as implemented below:
  1. Encode the note with a GatorTron encoder (first-token embedding).
  2. Append Empath category scores, trigram counts and a zeroed "reasoning"
     placeholder vector.
  3. Feed the concatenated feature vector, as a single-node graph with one
     self-loop, to a small GraphSAGE classifier restored from a checkpoint.

All heavy assets (spaCy model, checkpoint, encoder, GNN) are loaded eagerly at
import time so the first request does not pay the loading cost.
"""
import os
from collections import Counter

import gradio as gr
import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
from empath import Empath
from huggingface_hub import hf_hub_download
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from transformers import AutoModel, AutoTokenizer

# --- Device Config ---
DEVICE = torch.device("cpu")  # Free Space uses CPU
torch.set_num_threads(2)

# --- Constants ---
ENCODER_ID = "UFNLP/gatortron-base-2k"
# The encoder name suggests a 2k-token context; 512 is what this app uses.
# NOTE(review): presumably chosen to keep CPU latency acceptable - confirm.
MAX_TOKENS = 512
# Width of the zero-filled "reasoning" placeholder slice.
REASONING_DIM = 384
# Index 0 / 1 of the GNN output map to these labels.
LABELS = ["Inpatient (IP)", "Outpatient (OP)"]
EMPTY_INPUT_MESSAGE = "Please input a valid clinical note text."


# --- Restore exact architectures ---
class GatorTronEncoder(nn.Module):
    """Wrap a Hugging Face encoder and expose its first-token embedding."""

    def __init__(self, model_id):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_id)
        self.hidden_size = self.model.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the last hidden state at sequence position 0 for each item."""
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state[:, 0, :]


class MetaGNN(nn.Module):
    """Linear projection followed by two SAGEConv layers and log-softmax."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)
        self.conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, data):
        """Return per-node log-probabilities for ``data.x`` / ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        x = torch.relu(self.lin(x))
        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return torch.log_softmax(x, dim=1)


# --- NLP Feature Extractors ---
nlp = spacy.load("en_core_web_sm")
empath_analyzer = Empath()


def preprocess_text(text):
    """Return lemmas of the lower-cased text, keeping only content words.

    Stop words, punctuation, number-like tokens, tokens tagged as PERSON
    entities and non-alphabetic tokens are dropped.
    """
    doc = nlp(text.lower())
    return [
        tok.lemma_
        for tok in doc
        if not (
            tok.is_stop
            or tok.is_punct
            or tok.like_num
            or tok.ent_type_ == "PERSON"
            or not tok.is_alpha
        )
    ]


def extract_trigrams(text):
    """Return space-joined trigrams of consecutive preprocessed tokens."""
    toks = preprocess_text(text)
    # zip stops at the shortest slice, so fewer than 3 tokens yields [].
    return [" ".join(gram) for gram in zip(toks, toks[1:], toks[2:])]


def _load_trigrams(csv_path):
    """Read the "Trigram" column of a CSV into a set, skipping missing cells."""
    # NaN cells would make the later sorted() call fail on str/float comparison.
    return set(pd.read_csv(csv_path)["Trigram"].dropna())


def _fit_to_dim(vec, target_dim):
    """Zero-pad or truncate a 1-D vector so that it has exactly target_dim entries."""
    current_dim = vec.shape[0]
    if current_dim < target_dim:
        return np.pad(vec, (0, target_dim - current_dim), "constant")
    return vec[:target_dim]


# --- Global Asset Loaders (run eagerly at import time) ---
print("Downloading model weights from Hugging Face Hub...")
REPO_ID = "dep-dev/CNC-weights"
FILENAME = "best_model.pt"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

print("Loading checkpoint into memory...")
# NOTE(review): torch.load unpickles the file; only point REPO_ID at a
# repository you trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

# --- Reconstruct Missing Metadata On-the-Fly ---
print("Reconstructing feature metadata...")
empath_cats = sorted(empath_analyzer.cats)

# Load the CSVs uploaded to the Space
ip_trigrams = _load_trigrams("ip_specific_trigrams_masked_train_new_bothmasked.csv")
op_trigrams = _load_trigrams("op_specific_trigrams_masked_train_new_bothmasked.csv")
trigram_list = sorted(ip_trigrams | op_trigrams)

hidden_dim = checkpoint["params"]["hidden_dim"]
# The original model used all 4 feature sets; order defines the slice layout.
setup_keys = ["gatortron", "empath", "trigrams", "reasoning"]

print("Loading GatorTron...")
tokenizer = AutoTokenizer.from_pretrained(ENCODER_ID)
gatortron = GatorTronEncoder(ENCODER_ID).to(DEVICE)
# Uses the old key "gatortron" instead of "gatortron_state_dict"
gatortron.load_state_dict(checkpoint["gatortron"])
gatortron.eval()

# The GNN input width is read from the saved weights rather than recomputed
# from the feature sizes, so the layer always matches the checkpoint.
total_in_dim = checkpoint["gnn"]["lin.weight"].shape[1]

# Surface a mismatch between the rebuilt feature layout and the checkpoint:
# inference still runs (features are padded/truncated), but the slices would
# not line up with what the GNN weights expect.
_expected_dim = (
    gatortron.hidden_size + len(empath_cats) + len(trigram_list) + REASONING_DIM
)
if _expected_dim != total_in_dim:
    print(
        f"WARNING: reconstructed feature width {_expected_dim} differs from "
        f"checkpoint input width {total_in_dim}; features will be padded/truncated."
    )

print("Loading GNN...")
gnn = MetaGNN(in_dim=total_in_dim, hidden_dim=hidden_dim, out_dim=2).to(DEVICE)
# Uses the old key "gnn" instead of "gnn_state_dict"
gnn.load_state_dict(checkpoint["gnn"])
gnn.eval()


# --- Inference Logic Pipeline ---
def predict_clinical_note(note_text):
    """Classify one clinical note and return a JSON-serialisable result.

    Args:
        note_text: Raw note text from the Gradio textbox.

    Returns:
        A dict with the keys "Prediction Decision",
        "Inpatient Probability (IP)" and "Outpatient Probability (OP)".
        For empty/blank input the decision field carries a prompt message
        and both probabilities are reported as "0.00%".
    """
    if not note_text or not note_text.strip():
        # Same keys as the success path so the single JSON output always
        # receives one consistently shaped object.
        return {
            "Prediction Decision": EMPTY_INPUT_MESSAGE,
            "Inpatient Probability (IP)": "0.00%",
            "Outpatient Probability (OP)": "0.00%",
        }

    with torch.no_grad():
        # 1. Text Encoding Component
        inp = tokenizer(
            [note_text],
            truncation=True,
            padding="max_length",
            max_length=MAX_TOKENS,
            return_tensors="pt",
        ).to(DEVICE)
        gt_emb = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]

        # 2. Empath Component (analyze may return a falsy result; treat as all zeros)
        emp = empath_analyzer.analyze(note_text, normalize=True) or {}
        emp_vec = np.array([emp.get(c, 0.0) for c in empath_cats], dtype=np.float32)

        # 3. Trigram Counts Component: count once, then look up each vocab entry.
        trigram_counts = Counter(extract_trigrams(note_text))
        tri_vec = np.array([trigram_counts[t] for t in trigram_list], dtype=np.float32)

        # 4. Reason Placeholder
        rsn_vec = np.zeros(REASONING_DIM, dtype=np.float32)

        # Concatenate the slices in the order given by setup_keys.
        feature_slices = {
            "gatortron": gt_emb,
            "empath": emp_vec,
            "trigrams": tri_vec,
            "reasoning": rsn_vec,
        }
        flat_x = np.concatenate([feature_slices[key] for key in setup_keys])

        # Size-check the 1-D vector first, then add the batch axis exactly once
        # so the GNN receives shape (1, total_in_dim).
        final_x = _fit_to_dim(flat_x, total_in_dim)[np.newaxis, :]

        # 5. Single-node graph: one self-loop so SAGEConv has an edge to aggregate.
        edge_index = torch.tensor([[0], [0]], dtype=torch.long).to(DEVICE)
        pyg_data = Data(
            x=torch.tensor(final_x, dtype=torch.float32).to(DEVICE),
            edge_index=edge_index,
        )

        out = gnn(pyg_data)
        # The GNN emits log-probabilities; exponentiate to get probabilities.
        probs = torch.exp(out).cpu().numpy()[0]

    prediction = LABELS[int(np.argmax(probs))]

    return {
        "Prediction Decision": prediction,
        "Inpatient Probability (IP)": f"{probs[0] * 100:.2f}%",
        "Outpatient Probability (OP)": f"{probs[1] * 100:.2f}%",
    }


# --- Interface Setup ---
interface = gr.Interface(
    fn=predict_clinical_note,
    inputs=gr.Textbox(
        lines=8,
        placeholder="Enter anonymous clinical or progress notes here...",
        label="Clinical Patient Document",
    ),
    outputs=gr.JSON(label="Prediction System Distribution Outcome Metrics"),
    title="Clinical Document Target Assignment Classifier",
    description="An active meta-optimization evaluation system determining optimization processing routes utilizing structural-semantic patterns across textual records.",
)

if __name__ == "__main__":
    interface.launch()