# Risk-Adjustment-Version1 / PatientInfoExtractionEngine.py
# sujataprakashdatycs's picture
# Update PatientInfoExtractionEngine.py
# 8435e7b verified
import json
import os
from typing import List, Dict
from json_repair import repair_json
from crewai import Agent, Task, Crew, Process
from crewai_tools import tool, SerperDevTool
from langchain_openai import ChatOpenAI
from embedding_manager import DirectoryEmbeddingManager
class PatientInfoExtractionEngine:
    """
    Builds the Agent, Task and Crew used to extract patient demographics
    from the embeddings created by DirectoryEmbeddingManager.

    Call ``run()`` to execute the crew and obtain the parsed demographics.
    """

    def __init__(self, pdf_dir: str, top_k: int = 10):
        """
        Wire up the embedding manager, the LLM, the search tool, the agent,
        the demographics task and the crew.

        Args:
            pdf_dir: Directory handed to DirectoryEmbeddingManager.
            top_k: Number of chunks the chart-search tool retrieves per query.
                Defaults to 10, the value that was previously hard-coded.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        # Validate first so a bad argument fails before any embedding work starts.
        if top_k < 1:
            raise ValueError(f"top_k must be at least 1, got {top_k}")
        self.top_k = top_k

        self.embed_manager = DirectoryEmbeddingManager(pdf_dir)
        self.llm = ChatOpenAI(model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o"), temperature=0)

        # NOTE(review): the tool decorator presumably exposes the docstring below
        # to the LLM as the tool description, so keep it short and accurate.
        @tool("patient_chart_search")
        def patient_chart_search(query: str) -> str:
            """
            Search the patient chart embeddings and return the top matching results as a single string.
            Each result is preserved individually and then combined at the end.
            """
            print(f"\n[TOOL LOG] Searching patient chart for: '{query}'")
            vectordb = self.embed_manager.get_or_create_embeddings()
            results = vectordb.similarity_search(query, k=self.top_k)
            # Results are joined with a visible separator so chunk boundaries survive.
            return "\n---\n".join(res.page_content for res in results)

        # Unified Agent
        self.patient_info_agent = Agent(
            role="Patient Information Extractor",
            goal="Extract patient demographics information from the patient chart.",
            # Use the LLM configured above so OPENAI_MODEL_NAME is actually honoured;
            # previously a second, default-model ChatOpenAI was created here.
            llm=self.llm,
            verbose=True,
            memory=False,
            tools=[patient_chart_search],
            backstory=(
                "You are a medical assistant tasked with extracting patient demographics information. "
                "You must only return factual information found explicitly in the chart. "
                "If a particular field is not found, leave it as an empty string (''). "
                "Do not guess, infer, or fill with placeholder values."
            ),
        )

        # Task 1: Demographics Extraction
        self.demographics_task = Task(
            description=(
                "Extract patient demographics from the patient chart. "
                "All patient demographic information is usually located together in one section "
                "(for example, at the beginning or in a 'Patient Information' block). "
                "Patient demographics may appear in different nearby sections, such as 'Patient Information', 'Reason for Visit', or 'History of Present Illness'. You may extract all demographics from these closely related sections. "
                # NOTE(review): the fragment below is a leftover of an earlier sentence;
                # only the missing separator was added. Confirm the intended wording.
                "from multiple parts of the chart.\n\n"
                "Fields to extract:\n"
                "- name\n"
                "- dob (date of birth in mm-dd-yyyy format)\n"
                "- age\n"
                "- gender\n"
                "- address (only patient’s residential address; exclude any hospital or medical centre names)\n"
                "- phone\n"
                "- patient_identifier (may appear as patient ID, MRN, or insurance ID in the chart)\n\n"
                "Return the result strictly as a JSON object in the format:\n"
                "{\n"
                " 'name': '',\n"
                " 'dob': '',\n"
                " 'age': '',\n"
                " 'gender': '',\n"
                " 'address': '',\n"
                " 'phone': '',\n"
                " 'patient_identifier': ''\n"
                "}\n\n"
                "Rules:\n"
                "- Extract data only from the section where all patient demographic information is grouped together.\n"
                # Each rule ends with a newline so the bullets do not run together.
                "- You may extract values like gender, age, or DOB even if they appear within sentences (e.g., 'This 66-year-old male...').\n"
                "- If both a name and a gender word appear, always prioritize the gender word in text, not inferred gender from name.\n"
                "- If the chart uses gender-descriptive words like 'gentleman' or 'lady', interpret them as 'male' or 'female' respectively.\n"
                "- Only include values explicitly mentioned in the chart.\n"
                "- If a value is missing or not mentioned, leave it as an empty string.\n"
                "- Do not hallucinate, infer, or guess missing details."
            ),
            expected_output="A JSON object containing the extracted patient demographics.",
            agent=self.patient_info_agent,
        )

        # Combine agents and tasks in Crew
        self.crew = Crew(
            agents=[self.patient_info_agent],
            tasks=[self.demographics_task],
            process=Process.sequential,
            verbose=True,
        )

    @staticmethod
    def _output_text(result) -> str:
        """
        Return the textual payload of a crew result.

        Plain strings are returned unchanged. Objects exposing a string ``raw``
        attribute yield that attribute (presumably the shape of crewAI's
        CrewOutput in newer releases; verify against the installed version).
        ``None`` becomes an empty string and anything else is passed to ``str()``.
        """
        if isinstance(result, str):
            return result
        if result is None:
            return ""
        raw = getattr(result, "raw", None)
        if isinstance(raw, str):
            return raw
        return str(result)

    def run(self):
        """
        Execute the crew and parse its answer as JSON.

        Returns:
            The parsed JSON value when the crew output can be repaired and
            decoded; otherwise the unmodified value returned by
            ``crew.kickoff()`` after an error message is printed.
        """
        result = self.crew.kickoff()
        try:
            # repair_json expects text, so normalise the result before repairing.
            return json.loads(repair_json(self._output_text(result)))
        except (json.JSONDecodeError, TypeError):
            print("[ERROR] Failed to decode the demographic information")
            return result