# Spaces: Runtime error  (page status banner captured with the source; not program code)
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from model import DualEncoderModel | |
| from utils import ( | |
| extract_text_from_pdf, | |
| extract_lab_tests_dict, | |
| prepare_lab_tensor, | |
| load_icd_mapping, | |
| ) | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import OpenAIEmbeddings | |
| from langchain.chains import RetrievalQA | |
| from langchain_community.chat_models import ChatOpenAI | |
# Names of the continuous lab features. The list is handed to
# prepare_lab_tensor() and its length is used as the model's lab_cont_dim.
# NOTE(review): presumably the order here fixes the feature column order the
# trained weights expect -- confirm against utils.prepare_lab_tensor.
lab_cont_features_list = [
    "ALT (SGPT)", "AST (SGOT)", "Bilirubin", "Albumin",
    "Platelet Count", "Total Cholesterol",
    "BP Systolic", "BP Diastolic",
    "Troponin", "Ejection Fraction",
    "HbA1c", "Fasting Glucose", "Postprandial Glucose",
    "Triglycerides", "Insulin Level",
    "WBC Count", "Fever", "Hematocrit",
]
# Locations of the trained weights and of the CSV given to load_icd_mapping().
model_path = "model/dual_encoder_model.pth"
icd_csv_path = "model/augmented_lab_data.csv"

# Build the network with the dimensions used below in predict_icd(): one
# continuous input per lab feature, no categorical lab inputs, no continuous
# "conv" inputs and three categorical "conv" inputs.
model = DualEncoderModel(
    lab_cont_dim=len(lab_cont_features_list),
    lab_cat_dims=[],
    conv_cont_dim=0,
    conv_cat_dims=[49, 17, 17],
    embedding_dim=16,
    num_classes=18,
)

# Load the weights onto the CPU, then switch the model to evaluation mode.
_cpu_state = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(_cpu_state)
model.eval()

# Lookup used by predict_icd() to turn a predicted class index into
# (icd_code, icd_label, diagnosis).
icd_mapping = load_icd_mapping(icd_csv_path)
def predict_icd(pdf):
    """Predict the most likely diagnoses for an uploaded PDF lab report.

    Steps: extract the PDF text, index its chunks in an in-memory FAISS
    store, ask the chat model (through a "refine" RetrievalQA chain) to list
    the lab tests, parse that answer into the continuous feature tensor, and
    run the dual-encoder model on it.

    Args:
        pdf: The value Gradio passes for the file input: an object exposing
            the file path as ``.name``, a plain path string, or ``None`` when
            nothing was uploaded.

    Returns:
        A Markdown string with one numbered entry per predicted diagnosis
        (diagnosis, ICD code, confidence band and probability), or a short
        prompt asking for a file when ``pdf`` is ``None``.
    """
    # Guard: submitting the form without a file hands us None.
    if pdf is None:
        return "Please upload a PDF lab report."

    # Accept both file-like objects (path in ``.name``) and bare path strings.
    pdf_path = getattr(pdf, "name", pdf)
    text = extract_text_from_pdf(pdf_path)

    # Chunk the report and build a throw-away vector index over the chunks.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    docs = splitter.create_documents([text])
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")
    vectorstore = FAISS.from_documents(docs, embedding=embedding_model)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    # Have the LLM pull the lab values out of the retrieved chunks.
    llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="refine")
    query = "List lab test names and values only with units (no suggestions). Format: Test: Value Unit"
    lab_response = qa.run(query)
    lab_data = extract_lab_tests_dict(lab_response)

    # Only the continuous lab features carry data; the remaining model inputs
    # are empty or all-zero placeholders of the widths the model was built with.
    lab_cont_tensor = prepare_lab_tensor(lab_data, lab_cont_features_list)
    lab_cat_tensor = torch.zeros((1, 0), dtype=torch.int64)
    conv_cont_tensor = torch.zeros((1, 0), dtype=torch.float32)
    conv_cat_tensor = torch.tensor([[0, 0, 0]], dtype=torch.int64)

    with torch.no_grad():
        logits = model(
            lab_cont_tensor, lab_cat_tensor, conv_cont_tensor, conv_cat_tensor
        )
        probs = F.softmax(logits, dim=1)
        # Never request more entries than there are classes.
        top_probs = torch.topk(probs, min(3, probs.shape[1]), dim=1)

    # Convert once to plain Python numbers for lookup, comparison, formatting.
    class_ids = top_probs.indices[0].tolist()
    class_probs = top_probs.values[0].tolist()

    entries = []
    for rank, (label_idx, prob) in enumerate(zip(class_ids, class_probs), start=1):
        icd_code, _icd_label, diagnosis = icd_mapping.get(
            label_idx, ("Unknown", "Unknown", "No Description Available")
        )
        confidence = (
            "🔵 High" if prob > 0.6 else "🟡 Medium" if prob > 0.3 else "🔴 Low"
        )
        entries.append(
            f"{rank}. **{diagnosis}**\nICD Code: {icd_code}\nConfidence: {confidence} ({prob:.2%})"
        )
    return "\n\n".join(entries).strip()
# Gradio UI: a single file upload wired to predict_icd(), with the returned
# Markdown string rendered as the result.
_pdf_upload = gr.File(label="Upload PDF Lab Report")
_diagnosis_view = gr.Markdown(label="Predicted Diagnoses (ICD Codes)")

iface = gr.Interface(
    title="ICD Code Predictor from Lab Report",
    description="Upload a lab report PDF to predict possible diagnoses with ICD codes.",
    fn=predict_icd,
    inputs=_pdf_upload,
    outputs=_diagnosis_view,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()