# Author: bohraanuj23
# Commit b233c8d: "Modified some of the imports"
import gradio as gr
import torch
import torch.nn.functional as F
from model import DualEncoderModel
from utils import (
extract_text_from_pdf,
extract_lab_tests_dict,
prepare_lab_tensor,
load_icd_mapping,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.chat_models import ChatOpenAI
# Continuous lab features fed to the model, in the column order passed to
# prepare_lab_tensor. Presumably this must match the feature order used when
# the checkpoint was trained — TODO confirm against the training pipeline.
lab_cont_features_list = [
    "ALT (SGPT)",
    "AST (SGOT)",
    "Bilirubin",
    "Albumin",
    "Platelet Count",
    "Total Cholesterol",
    "BP Systolic",
    "BP Diastolic",
    "Troponin",
    "Ejection Fraction",
    "HbA1c",
    "Fasting Glucose",
    "Postprandial Glucose",
    "Triglycerides",
    "Insulin Level",
    "WBC Count",
    "Fever",
    "Hematocrit",
]
model_path = "model/dual_encoder_model.pth"
icd_csv_path = "model/augmented_lab_data.csv"
# These constructor arguments must describe the same architecture the
# checkpoint was saved from; load_state_dict (strict by default) raises on
# missing/unexpected keys or shape mismatches.
# NOTE(review): conv_cat_dims looks like per-feature category counts for three
# categorical "conversation" inputs — confirm against DualEncoderModel.
model = DualEncoderModel(
    lab_cont_dim=len(lab_cont_features_list),
    lab_cat_dims=[],
    conv_cont_dim=0,
    conv_cat_dims=[49, 17, 17],
    embedding_dim=16,
    num_classes=18,
)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs; a full pickle load can execute arbitrary code
# embedded in the checkpoint file.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
# Switch layers such as dropout/batch-norm (if the model has any) to inference
# behavior.
model.eval()
# Class-index -> ICD info lookup used when formatting predictions.
icd_mapping = load_icd_mapping(icd_csv_path)
def predict_icd(pdf):
text = extract_text_from_pdf(pdf.name)
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = splitter.create_documents([text])
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")
vectorstore = FAISS.from_documents(docs, embedding=embedding_model)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="refine")
query = "List lab test names and values only with units (no suggestions). Format: Test: Value Unit"
lab_response = qa.run(query)
lab_data = extract_lab_tests_dict(lab_response)
lab_cont_tensor = prepare_lab_tensor(lab_data, lab_cont_features_list)
lab_cat_tensor = torch.zeros((1, 0), dtype=torch.int64)
conv_cont_tensor = torch.zeros((1, 0), dtype=torch.float32)
conv_cat_tensor = torch.tensor([[0, 0, 0]], dtype=torch.int64)
with torch.no_grad():
logits = model(
lab_cont_tensor, lab_cat_tensor, conv_cont_tensor, conv_cat_tensor
)
probs = F.softmax(logits, dim=1)
top_probs = torch.topk(probs, 3, dim=1)
output = ""
for i, (label_idx, prob) in enumerate(
zip(top_probs.indices[0], top_probs.values[0])
):
icd_code, icd_label, diagnosis = icd_mapping.get(
label_idx.item(), ("Unknown", "Unknown", "No Description Available")
)
confidence = (
"🔵 High" if prob > 0.6 else "🟡 Medium" if prob > 0.3 else "🔴 Low"
)
output += f"{i+1}. **{diagnosis}**\nICD Code: {icd_code}\nConfidence: {confidence} ({prob:.2%})\n\n"
return output.strip()
iface = gr.Interface(
fn=predict_icd,
inputs=gr.File(label="Upload PDF Lab Report"),
outputs=gr.Markdown(label="Predicted Diagnoses (ICD Codes)"),
title="ICD Code Predictor from Lab Report",
description="Upload a lab report PDF to predict possible diagnoses with ICD codes.",
)
if __name__ == "__main__":
iface.launch()