bohraanuj23 committed on
Commit
14f41f0
·
1 Parent(s): f27e986

Implemented changes to the model architecture

Browse files
README.md CHANGED
@@ -10,3 +10,14 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
14
+ # ICD Code Predictor from Lab Reports
15
+
16
+ Upload a medical lab report (PDF). The app extracts lab test names and values using LangChain + GPT-4o, and a dual-encoder model then predicts the top three ICD codes with confidence scores.
17
+
18
+ Built with:
19
+
20
+ - PyTorch Lightning
21
+ - LangChain + GPT-4o
22
+ - FAISS
23
+ - Gradio
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn.functional as F
4
- import pandas as pd
5
  from model import DualEncoderModel
6
  from utils import (
7
  extract_text_from_pdf,
@@ -15,7 +14,6 @@ from langchain.embeddings import OpenAIEmbeddings
15
  from langchain.chains import RetrievalQA
16
  from langchain.chat_models import ChatOpenAI
17
 
18
- # Load model and ICD mapping once
19
  lab_cont_features_list = [
20
  "ALT (SGPT)",
21
  "AST (SGOT)",
@@ -37,6 +35,9 @@ lab_cont_features_list = [
37
  "Hematocrit",
38
  ]
39
 
 
 
 
40
  model = DualEncoderModel(
41
  lab_cont_dim=len(lab_cont_features_list),
42
  lab_cat_dims=[],
@@ -45,16 +46,14 @@ model = DualEncoderModel(
45
  embedding_dim=16,
46
  num_classes=18,
47
  )
48
- model.load_state_dict(
49
- torch.load("dual_encoder_model.pth", map_location=torch.device("cpu"))
50
- )
51
  model.eval()
52
 
53
- icd_mapping = load_icd_mapping("cleaned_lab_data.csv")
54
 
55
 
56
- def predict_from_pdf(pdf_file):
57
- text = extract_text_from_pdf(pdf_file.name)
58
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
59
  docs = splitter.create_documents([text])
60
 
@@ -64,10 +63,9 @@ def predict_from_pdf(pdf_file):
64
  llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
65
  qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="refine")
66
 
67
- lab_tests_response = qa.run(
68
- "List lab test names and values only with units (no suggestions). Format: Test: Value Unit"
69
- )
70
- lab_data = extract_lab_tests_dict(lab_tests_response)
71
 
72
  lab_cont_tensor = prepare_lab_tensor(lab_data, lab_cont_features_list)
73
  lab_cat_tensor = torch.zeros((1, 0), dtype=torch.int64)
@@ -79,24 +77,29 @@ def predict_from_pdf(pdf_file):
79
  lab_cont_tensor, lab_cat_tensor, conv_cont_tensor, conv_cat_tensor
80
  )
81
  probs = F.softmax(logits, dim=1)
82
- top_k_probs = torch.topk(probs, 3, dim=1)
83
 
84
- diagnoses = []
85
- for idx, prob in zip(top_k_probs.indices[0], top_k_probs.values[0]):
 
 
86
  icd_code, icd_label, diagnosis = icd_mapping.get(
87
- idx.item(), ("Unknown", "Unknown", "No Description Available")
88
  )
89
- diagnoses.append(f"{diagnosis} (ICD: {icd_code}) - {prob.item()*100:.2f}%")
90
-
91
- return "\n".join(diagnoses)
 
 
92
 
93
 
94
- interface = gr.Interface(
95
- fn=predict_from_pdf,
96
- inputs=gr.File(label="Upload Lab Report (PDF)"),
97
- outputs=gr.Textbox(label="Top 3 Predicted Diagnoses with ICD Codes"),
98
- title="Medical ICD Code Predictor",
 
99
  )
100
 
101
  if __name__ == "__main__":
102
- interface.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn.functional as F
 
4
  from model import DualEncoderModel
5
  from utils import (
6
  extract_text_from_pdf,
 
14
  from langchain.chains import RetrievalQA
15
  from langchain.chat_models import ChatOpenAI
16
 
 
17
  lab_cont_features_list = [
18
  "ALT (SGPT)",
19
  "AST (SGOT)",
 
35
  "Hematocrit",
36
  ]
37
 
38
+ model_path = "dual_encoder_model.pth"
39
+ icd_csv_path = "augmented_lab_data.csv"
40
+
41
  model = DualEncoderModel(
42
  lab_cont_dim=len(lab_cont_features_list),
43
  lab_cat_dims=[],
 
46
  embedding_dim=16,
47
  num_classes=18,
48
  )
49
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
 
 
50
  model.eval()
51
 
52
+ icd_mapping = load_icd_mapping(icd_csv_path)
53
 
54
 
55
+ def predict_icd(pdf):
56
+ text = extract_text_from_pdf(pdf.name)
57
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
58
  docs = splitter.create_documents([text])
59
 
 
63
  llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
64
  qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="refine")
65
 
66
+ query = "List lab test names and values only with units (no suggestions). Format: Test: Value Unit"
67
+ lab_response = qa.run(query)
68
+ lab_data = extract_lab_tests_dict(lab_response)
 
69
 
70
  lab_cont_tensor = prepare_lab_tensor(lab_data, lab_cont_features_list)
71
  lab_cat_tensor = torch.zeros((1, 0), dtype=torch.int64)
 
77
  lab_cont_tensor, lab_cat_tensor, conv_cont_tensor, conv_cat_tensor
78
  )
79
  probs = F.softmax(logits, dim=1)
80
+ top_probs = torch.topk(probs, 3, dim=1)
81
 
82
+ output = ""
83
+ for i, (label_idx, prob) in enumerate(
84
+ zip(top_probs.indices[0], top_probs.values[0])
85
+ ):
86
  icd_code, icd_label, diagnosis = icd_mapping.get(
87
+ label_idx.item(), ("Unknown", "Unknown", "No Description Available")
88
  )
89
+ confidence = (
90
+ "🔵 High" if prob > 0.6 else "🟡 Medium" if prob > 0.3 else "🔴 Low"
91
+ )
92
+ output += f"{i+1}. **{diagnosis}**\nICD Code: {icd_code}\nConfidence: {confidence} ({prob:.2%})\n\n"
93
+ return output.strip()
94
 
95
 
96
+ iface = gr.Interface(
97
+ fn=predict_icd,
98
+ inputs=gr.File(label="Upload PDF Lab Report"),
99
+ outputs=gr.Markdown(label="Predicted Diagnoses (ICD Codes)"),
100
+ title="ICD Code Predictor from Lab Report",
101
+ description="Upload a lab report PDF to predict possible diagnoses with ICD codes.",
102
  )
103
 
104
  if __name__ == "__main__":
105
+ iface.launch()
model.py CHANGED
@@ -1,20 +1,28 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
 
5
 
6
  class ResidualBlock(nn.Module):
7
- def __init__(self, input_dim, hidden_dim):
8
  super().__init__()
9
- self.fc1 = nn.Linear(input_dim, hidden_dim)
10
  self.relu = nn.ReLU()
11
- self.fc2 = nn.Linear(hidden_dim, input_dim)
 
12
 
13
  def forward(self, x):
14
- return x + self.fc2(self.relu(self.fc1(x)))
 
 
 
 
 
 
 
15
 
16
 
17
- class DualEncoderModel(nn.Module):
18
  def __init__(
19
  self,
20
  lab_cont_dim,
@@ -23,72 +31,67 @@ class DualEncoderModel(nn.Module):
23
  conv_cat_dims,
24
  embedding_dim,
25
  num_classes,
 
26
  ):
27
  super().__init__()
 
28
 
29
- # Lab encoder
30
- self.lab_cont_dim = lab_cont_dim
31
- self.lab_cat_dims = lab_cat_dims
32
- self.lab_cat_embeds = nn.ModuleList(
33
- [nn.Embedding(cat_dim, embedding_dim) for cat_dim in lab_cat_dims]
34
  )
35
- lab_cat_total_dim = embedding_dim * len(lab_cat_dims)
36
- lab_total_input_dim = lab_cont_dim + lab_cat_total_dim
37
- self.lab_encoder = nn.Sequential(
38
- nn.Linear(lab_total_input_dim, 64),
39
- nn.ReLU(),
40
- ResidualBlock(64, 32),
41
- nn.ReLU(),
42
  )
43
 
44
- # Conversation encoder
45
- self.conv_cont_dim = conv_cont_dim
46
- self.conv_cat_dims = conv_cat_dims
47
- self.conv_cat_embeds = nn.ModuleList(
48
- [nn.Embedding(cat_dim, embedding_dim) for cat_dim in conv_cat_dims]
49
  )
50
- conv_cat_total_dim = embedding_dim * len(conv_cat_dims)
51
- conv_total_input_dim = conv_cont_dim + conv_cat_total_dim
52
- self.conv_encoder = nn.Sequential(
53
- nn.Linear(conv_total_input_dim, 64),
54
- nn.ReLU(),
55
- ResidualBlock(64, 32),
56
- nn.ReLU(),
57
  )
58
 
59
- # Fusion + Classifier
 
 
 
 
 
 
 
 
 
60
  self.classifier = nn.Sequential(
61
- nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, num_classes)
 
 
 
62
  )
63
 
64
  def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
65
- # Process lab categorical features
66
- if self.lab_cat_embeds:
67
- lab_cat_embeds = [
68
- embed(lab_cat[:, i]) for i, embed in enumerate(self.lab_cat_embeds)
 
 
 
69
  ]
70
- lab_cat_encoded = torch.cat(lab_cat_embeds, dim=1)
71
- else:
72
- lab_cat_encoded = torch.empty((lab_cont.size(0), 0), device=lab_cont.device)
73
-
74
- lab_input = torch.cat([lab_cont, lab_cat_encoded], dim=1)
75
- lab_repr = self.lab_encoder(lab_input)
76
 
77
- # Process conversation categorical features
78
- if self.conv_cat_embeds:
79
- conv_cat_embeds = [
80
- embed(conv_cat[:, i]) for i, embed in enumerate(self.conv_cat_embeds)
 
 
81
  ]
82
- conv_cat_encoded = torch.cat(conv_cat_embeds, dim=1)
83
- else:
84
- conv_cat_encoded = torch.empty(
85
- (conv_cont.size(0), 0), device=conv_cont.device
86
- )
87
-
88
- conv_input = torch.cat([conv_cont, conv_cat_encoded], dim=1)
89
- conv_repr = self.conv_encoder(conv_input)
90
 
91
- # Concatenate and classify
92
- fused = torch.cat([lab_repr, conv_repr], dim=1)
93
- output = self.classifier(fused)
94
- return output
 
1
  import torch
2
  import torch.nn as nn
3
+ import pytorch_lightning as pl
4
 
5
 
6
  class ResidualBlock(nn.Module):
7
+ def __init__(self, in_features, out_features, dropout=0.2):
8
  super().__init__()
9
+ self.fc1 = nn.Linear(in_features, out_features)
10
  self.relu = nn.ReLU()
11
+ self.dropout = nn.Dropout(dropout)
12
+ self.fc2 = nn.Linear(out_features, out_features)
13
 
14
  def forward(self, x):
15
+ residual = x
16
+ out = self.fc1(x)
17
+ out = self.relu(out)
18
+ out = self.dropout(out)
19
+ out = self.fc2(out)
20
+ if residual.shape == out.shape:
21
+ out += residual
22
+ return out
23
 
24
 
25
+ class DualEncoderModel(pl.LightningModule):
26
  def __init__(
27
  self,
28
  lab_cont_dim,
 
31
  conv_cat_dims,
32
  embedding_dim,
33
  num_classes,
34
+ lr=1e-3,
35
  ):
36
  super().__init__()
37
+ self.save_hyperparameters()
38
 
39
+ self.lab_cont_encoder = (
40
+ nn.Sequential(ResidualBlock(lab_cont_dim, 64), ResidualBlock(64, 64))
41
+ if lab_cont_dim > 0
42
+ else None
 
43
  )
44
+
45
+ self.lab_cat_embeddings = nn.ModuleList(
46
+ [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
 
 
 
 
47
  )
48
 
49
+ self.conv_cont_encoder = (
50
+ nn.Sequential(ResidualBlock(conv_cont_dim, 64), ResidualBlock(64, 64))
51
+ if conv_cont_dim > 0
52
+ else None
 
53
  )
54
+
55
+ self.conv_cat_embeddings = nn.ModuleList(
56
+ [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
 
 
 
 
57
  )
58
 
59
+ total_dim = 0
60
+ if self.lab_cont_encoder:
61
+ total_dim += 64
62
+ if lab_cat_dims:
63
+ total_dim += embedding_dim * len(lab_cat_dims)
64
+ if self.conv_cont_encoder:
65
+ total_dim += 64
66
+ if conv_cat_dims:
67
+ total_dim += embedding_dim * len(conv_cat_dims)
68
+
69
  self.classifier = nn.Sequential(
70
+ nn.Linear(total_dim, 128),
71
+ nn.ReLU(),
72
+ nn.Dropout(0.3),
73
+ nn.Linear(128, num_classes),
74
  )
75
 
76
  def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
77
+ embeddings = []
78
+ if self.lab_cont_encoder and lab_cont.nelement() > 0:
79
+ embeddings.append(self.lab_cont_encoder(lab_cont))
80
+ embeddings.extend(
81
+ [
82
+ emb(torch.clamp(lab_cat[:, i], min=0))
83
+ for i, emb in enumerate(self.lab_cat_embeddings)
84
  ]
85
+ )
 
 
 
 
 
86
 
87
+ if self.conv_cont_encoder and conv_cont.nelement() > 0:
88
+ embeddings.append(self.conv_cont_encoder(conv_cont))
89
+ embeddings.extend(
90
+ [
91
+ emb(torch.clamp(conv_cat[:, i], min=0))
92
+ for i, emb in enumerate(self.conv_cat_embeddings)
93
  ]
94
+ )
 
 
 
 
 
 
 
95
 
96
+ fused = torch.cat(embeddings, dim=1)
97
+ return self.classifier(fused)
 
 
augmented_lab_data.csv → model/augmented_lab_data.csv RENAMED
File without changes
dual_encoder_model.pth → model/dual_encoder_model.pth RENAMED
File without changes
requirements.txt CHANGED
@@ -5,4 +5,6 @@ gradio
5
  langchain
6
  openai
7
  faiss-cpu
8
- langchain-community
 
 
 
5
  langchain
6
  openai
7
  faiss-cpu
8
+ langchain-community>=0.0.3
9
+ pytorch-lightning
10
+
utils.py CHANGED
@@ -1,99 +1,40 @@
1
- import torch
2
- import torch.nn as nn
3
- import pdfplumber
4
  import re
 
5
  import pandas as pd
6
- import torch.nn.functional as F
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain.vectorstores import FAISS
9
- from langchain.embeddings import OpenAIEmbeddings
10
- from langchain.chains import RetrievalQA
11
- from langchain.chat_models import ChatOpenAI
12
 
13
 
14
- # ---- PDF Extraction ----
15
  def extract_text_from_pdf(pdf_path):
16
- """
17
- Extract text from a PDF file.
18
-
19
- Args:
20
- pdf_path (str): Path to the PDF file.
21
-
22
- Returns:
23
- str: Extracted text from the PDF.
24
- """
25
- try:
26
- with pdfplumber.open(pdf_path) as pdf:
27
- return "\n".join(
28
- page.extract_text() for page in pdf.pages if page.extract_text()
29
- )
30
- except Exception as e:
31
- print(f"Error extracting text from PDF: {e}")
32
- return ""
33
-
34
 
35
- def extract_lab_tests_dict(response_text):
36
- """
37
- Extract lab test names and values from the response text.
38
 
39
- Args:
40
- response_text (str): The text containing lab test information.
41
-
42
- Returns:
43
- dict: A dictionary where keys are lab test names and values are their corresponding numeric values.
44
- """
45
- pattern = r"[-•]?\s*([\w\s/()%.-]+?):\s*([\d.,-]+)\s*(\w+/?.*)?"
46
- matches = re.findall(pattern, response_text)
47
  lab_dict = {}
48
  for test, value, unit in matches:
49
  test = test.strip()
50
  try:
51
- lab_dict[test] = float(value.replace(",", "")) # Handle commas in numbers
52
  except ValueError:
53
- continue # Skip invalid values
54
  return lab_dict
55
 
56
 
57
  def prepare_lab_tensor(lab_data, feature_list):
58
- """
59
- Prepare a tensor for the lab data to be fed into the model.
60
-
61
- Args:
62
- lab_data (dict): A dictionary of lab test names and their values.
63
- feature_list (list): A list of expected lab test names.
64
-
65
- Returns:
66
- torch.Tensor: A tensor containing the lab values.
67
- """
68
  values = [lab_data.get(feature, -1) for feature in feature_list]
69
  return torch.tensor([values], dtype=torch.float32)
70
 
71
 
72
- # ---- Updated ICD Mapping Loader ----
73
  def load_icd_mapping(csv_path):
74
- """
75
- Load and process the ICD mapping from a CSV file.
76
-
77
- Args:
78
- csv_path (str): Path to the CSV file containing ICD mappings.
79
-
80
- Returns:
81
- dict: A dictionary mapping ICD labels to (ICD Code, ICD Label, Diagnosis).
82
- """
83
- try:
84
- df = pd.read_csv(csv_path)
85
-
86
- # Defensive check for required columns
87
- if not {"ICD_Label", "ICD Code", "Diagnosis"}.issubset(df.columns):
88
- raise ValueError(
89
- "Required columns missing from CSV: ICD_Label, ICD Code, Diagnosis"
90
- )
91
-
92
- df = df.drop_duplicates(subset="ICD_Label")
93
- return {
94
- int(row["ICD_Label"]): (row["ICD Code"], row["ICD_Label"], row["Diagnosis"])
95
- for _, row in df.iterrows()
96
- }
97
- except Exception as e:
98
- print(f"Error loading ICD mapping: {e}")
99
- return {}
 
 
 
 
1
  import re
2
+ import pdfplumber
3
  import pandas as pd
4
+ import torch
 
 
 
 
 
5
 
6
 
 
7
  def extract_text_from_pdf(pdf_path):
8
+ with pdfplumber.open(pdf_path) as pdf:
9
+ return "\n".join(
10
+ page.extract_text() for page in pdf.pages if page.extract_text()
11
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
13
 
14
+ def extract_lab_tests_dict(text):
15
+ pattern = r"[-•]?\s*([\w\s/()%.-]+?):\s*([\d.]+)\s*(\w+/?.*)?"
16
+ matches = re.findall(pattern, text)
 
 
 
 
 
17
  lab_dict = {}
18
  for test, value, unit in matches:
19
  test = test.strip()
20
  try:
21
+ lab_dict[test] = float(value)
22
  except ValueError:
23
+ continue
24
  return lab_dict
25
 
26
 
27
  def prepare_lab_tensor(lab_data, feature_list):
 
 
 
 
 
 
 
 
 
 
28
  values = [lab_data.get(feature, -1) for feature in feature_list]
29
  return torch.tensor([values], dtype=torch.float32)
30
 
31
 
 
32
  def load_icd_mapping(csv_path):
33
+ df = pd.read_csv(csv_path)
34
+ if not {"ICD_Label", "ICD Code", "Diagnosis"}.issubset(df.columns):
35
+ raise ValueError("CSV must include ICD_Label, ICD Code, Diagnosis columns.")
36
+ df = df.drop_duplicates(subset="ICD_Label")
37
+ return {
38
+ int(row["ICD_Label"]): (row["ICD Code"], row["ICD_Label"], row["Diagnosis"])
39
+ for _, row in df.iterrows()
40
+ }