Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import spacy | |
| from empath import Empath | |
| from transformers import AutoTokenizer, AutoModel | |
| from torch_geometric.data import Data | |
| from torch_geometric.nn import SAGEConv | |
| from huggingface_hub import hf_hub_download | |
# --- Device Config ---
# The free Space tier runs on CPU only, so pin the device and keep PyTorch's
# intra-op thread pool small.
_NUM_CPU_THREADS = 2
torch.set_num_threads(_NUM_CPU_THREADS)
DEVICE = torch.device("cpu")
# --- Restore exact architectures ---
class GatorTronEncoder(nn.Module):
    """Wrap a Hugging Face encoder and return its position-0 hidden state.

    For BERT-style tokenizers position 0 is the [CLS] token. The backbone is
    kept under the attribute name ``model``, so its state-dict keys are
    prefixed with ``model.``; keep that name or saved checkpoints will no
    longer load.
    """

    def __init__(self, model_id):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_id)
        self.model = backbone
        self.hidden_size = backbone.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return ``last_hidden_state[:, 0, :]`` for the batch (one vector per row)."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state
        return token_states[:, 0, :]
class MetaGNN(nn.Module):
    """Linear projection followed by two GraphSAGE layers, emitting log-probabilities.

    Args:
        in_dim: Width of each node's input feature vector.
        hidden_dim: Width after the input projection and the first conv.
        out_dim: Number of classes.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)
        self.conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, data):
        """Score every node in ``data`` (needs ``x`` and ``edge_index``); returns log-softmax over dim 1."""
        features, edges = data.x, data.edge_index
        h = self.lin(features).relu()
        h = self.conv1(h, edges).relu()
        logits = self.conv2(h, edges)
        return logits.log_softmax(dim=1)
# --- NLP Feature Extractors ---
# Module-level analyzers, loaded once at import time: the spaCy pipeline is
# used by preprocess_text, the Empath lexicon by the inference pipeline.
_SPACY_PIPELINE = "en_core_web_sm"
nlp = spacy.load(_SPACY_PIPELINE)
empath_analyzer = Empath()
def preprocess_text(text):
    """Lowercase and parse ``text`` with spaCy, returning lemmas of the kept tokens.

    A token is discarded when it is a stop word, punctuation, number-like,
    tagged as a PERSON entity, or not purely alphabetic. The resulting lemma
    sequence is what the trigram features are built from.
    """
    parsed = nlp(text.lower())
    return [
        token.lemma_
        for token in parsed
        if not (
            token.is_stop
            or token.is_punct
            or token.like_num
            or token.ent_type_ == "PERSON"
            or not token.is_alpha
        )
    ]
def extract_trigrams(text):
    """Return each run of three consecutive preprocessed lemmas, joined by single spaces.

    Texts with fewer than three surviving lemmas produce an empty list.
    """
    lemmas = preprocess_text(text)
    return [" ".join(triple) for triple in zip(lemmas, lemmas[1:], lemmas[2:])]
# --- Global Asset Loading ---
# Everything in this section runs eagerly, once, when the module is imported
# (before the Gradio interface is built).
print("Downloading model weights from Hugging Face Hub...")
REPO_ID = "dep-dev/CNC-weights"
FILENAME = "best_model.pt"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

print("Loading checkpoint into memory...")
# NOTE(security): torch.load unpickles the downloaded file; only point
# REPO_ID/FILENAME at a repository you trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

# --- Reconstruct Missing Metadata On-the-Fly ---
# Feature vocabularies are rebuilt here rather than read from the checkpoint.
# Their sorted order presumably has to match the order used in training —
# confirm against the training script before changing it.
print("Reconstructing feature metadata...")
empath_cats = sorted(empath_analyzer.cats)
# Trigram vocabularies come from the CSV files shipped alongside this app.
ip_trigrams = set(pd.read_csv("ip_specific_trigrams_masked_train_new_bothmasked.csv")["Trigram"])
op_trigrams = set(pd.read_csv("op_specific_trigrams_masked_train_new_bothmasked.csv")["Trigram"])
trigram_list = sorted(ip_trigrams | op_trigrams)
hidden_dim = checkpoint["params"]["hidden_dim"]
# The original model used all 4 feature sets; this is also the order in which
# they are concatenated at inference time.
setup_keys = ["gatortron", "empath", "trigrams", "reasoning"]

print("Loading GatorTron...")
tokenizer = AutoTokenizer.from_pretrained("UFNLP/gatortron-base-2k")
gatortron = GatorTronEncoder("UFNLP/gatortron-base-2k").to(DEVICE)
# This checkpoint stores the encoder weights under the old key "gatortron"
# (not "gatortron_state_dict").
gatortron.load_state_dict(checkpoint["gatortron"])
gatortron.eval()

# The GNN input width is read from the checkpoint itself: nn.Linear stores its
# weight as (out_features, in_features), so column count == in_dim.
total_in_dim = checkpoint["gnn"]["lin.weight"].shape[1]

# Width the inference pipeline will actually produce: GatorTron embedding +
# one score per Empath category + one count per trigram + the 384-wide
# reasoning placeholder (rsn_vec in predict_clinical_note). If this disagrees
# with the checkpoint, features get zero-padded/truncated at inference time,
# so surface the mismatch once at startup instead of failing silently.
_feature_widths = {
    "gatortron": gatortron.hidden_size,
    "empath": len(empath_cats),
    "trigrams": len(trigram_list),
    "reasoning": 384,
}
expected_in_dim = sum(_feature_widths.get(key, 0) for key in setup_keys)
if expected_in_dim != total_in_dim:
    print(
        f"WARNING: extracted features total {expected_in_dim} dims but the GNN "
        f"checkpoint expects {total_in_dim}; inputs will be zero-padded or truncated."
    )

print("Loading GNN...")
gnn = MetaGNN(in_dim=total_in_dim, hidden_dim=hidden_dim, out_dim=2).to(DEVICE)
# This checkpoint stores the GNN weights under the old key "gnn"
# (not "gnn_state_dict").
gnn.load_state_dict(checkpoint["gnn"])
gnn.eval()
# --- Inference Logic Pipeline ---
def predict_clinical_note(note_text):
    """Classify a clinical note as Inpatient (IP) or Outpatient (OP).

    The note is converted into the feature groups listed in ``setup_keys``:
    the GatorTron position-0 embedding, normalized Empath category scores,
    counts of each vocabulary trigram, and an all-zero 384-wide "reasoning"
    placeholder. They are concatenated in ``setup_keys`` order, zero-padded
    or truncated to the GNN's input width ``total_in_dim``, and scored as a
    single-node graph.

    Args:
        note_text: Raw note text from the Gradio textbox. ``None`` or
            whitespace-only input is rejected without running any model.

    Returns:
        dict: On success, ``"Prediction Decision"`` plus the IP and OP
        probabilities formatted as percentages. For empty input, a dict whose
        ``"Prediction Decision"`` holds the validation message (the interface
        has a single ``gr.JSON`` output, so both paths return a dict).
    """
    from collections import Counter

    if note_text is None or not note_text.strip():
        return {"Prediction Decision": "Please input a valid clinical note text."}

    with torch.no_grad():
        # 1. Text Encoding Component (input truncated to 512 tokens).
        inp = tokenizer([note_text], truncation=True, padding="max_length", max_length=512, return_tensors="pt").to(DEVICE)
        gt_emb = gatortron(inp["input_ids"], inp["attention_mask"]).cpu().numpy()[0]

        # 2. Empath Component; falls back to all zeros if analyze() gives a
        #    falsy result (e.g. None).
        emp = empath_analyzer.analyze(note_text, normalize=True)
        emp_vec = np.array([emp.get(c, 0.0) if emp else 0.0 for c in empath_cats], dtype=np.float32)

        # 3. Trigram Counts Component: count once, then look up each
        #    vocabulary trigram (same values as list.count per trigram).
        trigram_counts = Counter(extract_trigrams(note_text))
        tri_vec = np.array([trigram_counts[t] for t in trigram_list], dtype=np.float32)

        # 4. Reasoning features are not available at inference time: zeros.
        rsn_vec = np.zeros(384, dtype=np.float32)

        # Concatenate in setup_keys order; unknown keys are skipped.
        features = {"gatortron": gt_emb, "empath": emp_vec, "trigrams": tri_vec, "reasoning": rsn_vec}
        flat_x = np.concatenate([features[key] for key in setup_keys if key in features])

        # Fit the 1-D vector to the checkpoint's input width. The length is
        # measured BEFORE adding the batch axis so pad/truncate only touch the
        # feature axis (measuring a (1, D) array gives 1 and pads both axes).
        current_dim = flat_x.shape[0]
        if current_dim < total_in_dim:
            flat_x = np.pad(flat_x, (0, total_in_dim - current_dim), "constant")
        elif current_dim > total_in_dim:
            flat_x = flat_x[:total_in_dim]
        final_x = flat_x[np.newaxis, :]  # shape (1, total_in_dim): a single node

        # 5. A single node with one self-loop edge (0 -> 0) gives the conv
        #    layers a valid edge_index for one input.
        edge_index = torch.tensor([[0], [0]], dtype=torch.long).to(DEVICE)
        pyg_data = Data(x=torch.tensor(final_x, dtype=torch.float32).to(DEVICE), edge_index=edge_index)
        out = gnn(pyg_data)
        probs = torch.exp(out).cpu().numpy()[0]  # log-softmax -> probabilities

    labels = ["Inpatient (IP)", "Outpatient (OP)"]
    prediction = labels[int(np.argmax(probs))]
    return {
        "Prediction Decision": prediction,
        "Inpatient Probability (IP)": f"{probs[0] * 100:.2f}%",
        "Outpatient Probability (OP)": f"{probs[1] * 100:.2f}%"
    }
# --- Interface Setup ---
# Single-textbox in, single JSON panel out; predict_clinical_note does the work.
_note_input = gr.Textbox(
    lines=8,
    placeholder="Enter anonymous clinical or progress notes here...",
    label="Clinical Patient Document",
)
_prediction_output = gr.JSON(label="Prediction System Distribution Outcome Metrics")

interface = gr.Interface(
    fn=predict_clinical_note,
    inputs=_note_input,
    outputs=_prediction_output,
    title="Clinical Document Target Assignment Classifier",
    description="An active meta-optimization evaluation system determining optimization processing routes utilizing structural-semantic patterns across textual records.",
)

if __name__ == "__main__":
    interface.launch()