import os

import gradio as gr
import torch
import torch.nn.functional as F

from model import DualEncoderModel
from utils import (
    extract_text_from_pdf,
    extract_lab_tests_dict,
    prepare_lab_tensor,
    load_icd_mapping,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.chat_models import ChatOpenAI

# Continuous lab features, in the order passed to ``prepare_lab_tensor``.
# Its length sets the model's ``lab_cont_dim``. The order presumably has to
# match the column order used at training time — confirm before editing.
lab_cont_features_list = [
    "ALT (SGPT)",
    "AST (SGOT)",
    "Bilirubin",
    "Albumin",
    "Platelet Count",
    "Total Cholesterol",
    "BP Systolic",
    "BP Diastolic",
    "Troponin",
    "Ejection Fraction",
    "HbA1c",
    "Fasting Glucose",
    "Postprandial Glucose",
    "Triglycerides",
    "Insulin Level",
    "WBC Count",
    "Fever",
    "Hematocrit",
]

model_path = "model/dual_encoder_model.pth"
icd_csv_path = "model/augmented_lab_data.csv"

# Number of diagnoses shown to the user (clamped to the number of classes).
TOP_K = 3
# Probability thresholds for the confidence buckets (strictly greater-than).
HIGH_CONFIDENCE_THRESHOLD = 0.6
MEDIUM_CONFIDENCE_THRESHOLD = 0.3

# Prompt sent to the retrieval chain to pull lab results out of the report.
LAB_QUERY = (
    "List lab test names and values only with units (no suggestions).\n"
    "Format: Test: Value Unit"
)

model = DualEncoderModel(
    lab_cont_dim=len(lab_cont_features_list),
    lab_cat_dims=[],
    conv_cont_dim=0,
    # Three categorical "conversation" inputs. This app never supplies them
    # and feeds zeros instead (see _run_model).
    conv_cat_dims=[49, 17, 17],
    embedding_dim=16,
    num_classes=18,
)
# NOTE(review): torch.load unpickles the file. That is acceptable for a
# checkpoint shipped with the app, but never point this at an untrusted path.
# On torch >= 1.13, consider passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

icd_mapping = load_icd_mapping(icd_csv_path)


def _resolve_upload_path(pdf):
    """Return the filesystem path of a Gradio file upload.

    Depending on the Gradio version and the ``gr.File`` ``type`` setting, the
    upload arrives either as a path (``str`` / ``os.PathLike``) or as a
    file-like wrapper exposing the temp-file path as ``.name``.
    """
    if isinstance(pdf, (str, os.PathLike)):
        return os.fspath(pdf)
    return pdf.name


def _extract_lab_response(text):
    """Ask an LLM, via retrieval over *text*, to list the report's lab results.

    Returns the chain's raw text answer, which is later parsed by
    ``extract_lab_tests_dict``.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    docs = splitter.create_documents([text])
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")
    vectorstore = FAISS.from_documents(docs, embedding=embedding_model)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
    llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="refine")
    return qa.run(LAB_QUERY)


def _run_model(lab_data):
    """Run the dual-encoder on parsed lab values and return softmax probabilities.

    The result has shape ``(1, num_classes)``. The batch dimension of 1 comes
    from the hand-built auxiliary tensors below.
    """
    lab_cont_tensor = prepare_lab_tensor(lab_data, lab_cont_features_list)
    # The model also expects lab-categorical, conversation-continuous and
    # conversation-categorical inputs. This app has none of them, so it feeds
    # zero-width tensors plus one all-zero categorical row (one zero per entry
    # in conv_cat_dims). Presumably index 0 is a neutral/unknown category —
    # confirm against training.
    lab_cat_tensor = torch.zeros((1, 0), dtype=torch.int64)
    conv_cont_tensor = torch.zeros((1, 0), dtype=torch.float32)
    conv_cat_tensor = torch.tensor([[0, 0, 0]], dtype=torch.int64)

    with torch.no_grad():
        logits = model(
            lab_cont_tensor, lab_cat_tensor, conv_cont_tensor, conv_cat_tensor
        )
    return F.softmax(logits, dim=1)


def _confidence_label(prob):
    """Map a probability to a High / Medium / Low display bucket."""
    if prob > HIGH_CONFIDENCE_THRESHOLD:
        return "🔵 High"
    if prob > MEDIUM_CONFIDENCE_THRESHOLD:
        return "🟡 Medium"
    return "🔴 Low"


def _format_predictions(probs, top_k=TOP_K):
    """Render the top-*top_k* classes of *probs* as a Markdown list.

    Class indices missing from ``icd_mapping`` are shown as "Unknown".
    """
    k = min(top_k, probs.shape[1])
    top = torch.topk(probs, k, dim=1)
    entries = []
    for rank, (label_idx, prob) in enumerate(
        zip(top.indices[0].tolist(), top.values[0].tolist()), start=1
    ):
        icd_code, _icd_label, diagnosis = icd_mapping.get(
            label_idx, ("Unknown", "Unknown", "No Description Available")
        )
        entries.append(
            f"{rank}. **{diagnosis}**\n"
            f"ICD Code: {icd_code}\n"
            f"Confidence: {_confidence_label(prob)} ({prob:.2%})"
        )
    return "\n\n".join(entries)


def predict_icd(pdf):
    """Predict the most likely ICD diagnoses for an uploaded PDF lab report.

    Pipeline: extract the PDF text, have an LLM list the lab results, parse
    them into the model's continuous-feature tensor, run the dual-encoder, and
    format the top predictions as Markdown.

    Args:
        pdf: The Gradio upload, either a path or a file-like object with
            ``.name``. ``None`` if nothing was uploaded.

    Returns:
        A Markdown string with the ranked diagnoses, or a short message when
        there is no upload or no extractable text.
    """
    if pdf is None:
        return "Please upload a PDF lab report."

    text = extract_text_from_pdf(_resolve_upload_path(pdf))
    # An empty document list would make FAISS.from_documents fail.
    if not text or not text.strip():
        return "No extractable text was found in the uploaded PDF."

    lab_response = _extract_lab_response(text)
    lab_data = extract_lab_tests_dict(lab_response)
    probs = _run_model(lab_data)
    return _format_predictions(probs)


iface = gr.Interface(
    fn=predict_icd,
    inputs=gr.File(label="Upload PDF Lab Report"),
    outputs=gr.Markdown(label="Predicted Diagnoses (ICD Codes)"),
    title="ICD Code Predictor from Lab Report",
    description="Upload a lab report PDF to predict possible diagnoses with ICD codes.",
)

if __name__ == "__main__":
    iface.launch()