File size: 5,163 Bytes
b58a2ed
cc49eb6
 
eac4c73
 
 
 
 
b58a2ed
eac4c73
b58a2ed
d7dd4c6
 
 
 
 
 
 
eb7971a
d7dd4c6
 
 
 
 
 
 
 
 
28c3514
d7dd4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a997108
d7dd4c6
 
a997108
 
 
eb7971a
a997108
 
d7dd4c6
8435e7b
d7dd4c6
 
a997108
d7dd4c6
 
a997108
d7dd4c6
 
 
 
 
 
 
 
 
 
a997108
eb7971a
 
 
a997108
d7dd4c6
a997108
d7dd4c6
a997108
 
d7dd4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

import json
import os
from typing import List, Dict
from json_repair import repair_json
from crewai import Agent, Task, Crew, Process
from crewai_tools import tool, SerperDevTool
from langchain_openai import ChatOpenAI

from embedding_manager import DirectoryEmbeddingManager

class PatientInfoExtractionEngine:
    """
    Builds the Agent, Task and Crew that extract patient demographics from
    the embeddings managed by DirectoryEmbeddingManager.

    Attributes set by ``__init__``:
        embed_manager: DirectoryEmbeddingManager built from ``pdf_dir``.
        llm: ChatOpenAI client; model comes from the ``OPENAI_MODEL_NAME``
            environment variable (default "gpt-4o"), temperature 0.
        patient_info_agent: crewai Agent that owns the chart-search tool.
        demographics_task: crewai Task describing the extraction.
        crew: crewai Crew running the single task sequentially.
    """

    # Number of chunks requested from the vector store per tool call.
    SEARCH_TOP_K = 10

    # Prompt for the demographics task. Every rule bullet ends with "\n" so
    # the bullets do not run together when the literals are concatenated.
    _DEMOGRAPHICS_DESCRIPTION = (
        "Extract patient demographics from the patient chart. "
        "All patient demographic information is usually located together in one section "
        "(for example, at the beginning or in a 'Patient Information' block). "
        "Patient demographics may appear in different nearby sections, such as "
        "'Patient Information', 'Reason for Visit', or 'History of Present Illness'. "
        "You may extract all demographics from these closely related sections "
        "from multiple parts of the chart.\n\n"
        "Fields to extract:\n"
        "- name\n"
        "- dob (date of birth in mm-dd-yyyy format)\n"
        "- age\n"
        "- gender\n"
        "- address (only patient’s residential address; exclude any hospital or medical centre names)\n"
        "- phone\n"
        "- patient_identifier (may appear as patient ID, MRN, or insurance ID in the chart)\n\n"
        "Return the result strictly as a JSON object in the format:\n"
        "{\n"
        "  'name': '',\n"
        "  'dob': '',\n"
        "  'age': '',\n"
        "  'gender': '',\n"
        "  'address': '',\n"
        "  'phone': '',\n"
        "  'patient_identifier': ''\n"
        "}\n\n"
        "Rules:\n"
        "- Extract data only from the section where all patient demographic information is grouped together.\n"
        "- You may extract values like gender, age, or DOB even if they appear within sentences (e.g., 'This 66-year-old male...').\n"
        "- If both a name and a gender word appear, always prioritize the gender word in text, not inferred gender from name.\n"
        "- If the chart uses gender-descriptive words like 'gentleman' or 'lady', interpret them as 'male' or 'female' respectively.\n"
        "- Only include values explicitly mentioned in the chart.\n"
        "- If a value is missing or not mentioned, leave it as an empty string.\n"
        "- Do not hallucinate, infer, or guess missing details."
    )

    def __init__(self, pdf_dir: str):
        """
        Create the embedding manager, LLM client, search tool, agent, task
        and crew.

        Args:
            pdf_dir: Directory handed to DirectoryEmbeddingManager.
        """
        self.embed_manager = DirectoryEmbeddingManager(pdf_dir)
        self.llm = ChatOpenAI(model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o"), temperature=0)

        # The docstring below doubles as the tool description given to the
        # agent, so it deliberately avoids a hard-coded result count that
        # could drift from SEARCH_TOP_K.
        @tool("patient_chart_search")
        def patient_chart_search(query: str) -> str:
            """
            Search the patient chart embeddings and return the top matching
            chunks as a single string, with chunks separated by '---' lines.
            """
            print(f"\n[TOOL LOG] Searching patient chart for: '{query}'")
            vectordb = self.embed_manager.get_or_create_embeddings()
            results = vectordb.similarity_search(query, k=self.SEARCH_TOP_K)
            return "\n---\n".join(res.page_content for res in results)

        # Unified Agent. It uses self.llm so that the OPENAI_MODEL_NAME
        # setting actually takes effect (previously a second, default-model
        # client was created here and self.llm went unused).
        self.patient_info_agent = Agent(
            role="Patient Information Extractor",
            goal="Extract patient demographics information from the patient chart.",
            llm=self.llm,
            verbose=True,
            memory=False,
            tools=[patient_chart_search],
            backstory=(
                "You are a medical assistant tasked with extracting patient demographics information. "
                "You must only return factual information found explicitly in the chart. "
                "If a particular field is not found, leave it as an empty string (''). "
                "Do not guess, infer, or fill with placeholder values."
            ),
        )

        # Task 1: Demographics Extraction
        self.demographics_task = Task(
            description=self._DEMOGRAPHICS_DESCRIPTION,
            expected_output="A JSON object containing the extracted patient demographics.",
            agent=self.patient_info_agent,
        )

        # Combine agents and tasks in Crew
        self.crew = Crew(
            agents=[self.patient_info_agent],
            tasks=[self.demographics_task],
            process=Process.sequential,
            verbose=True,
        )

    @staticmethod
    def _result_text(result) -> str:
        """
        Return the textual payload of a crew result.

        A plain string is returned unchanged. For any other object, a string
        ``raw`` attribute is used when present; otherwise ``str(result)``.
        """
        if isinstance(result, str):
            return result
        raw = getattr(result, "raw", None)
        return raw if isinstance(raw, str) else str(result)

    def run(self):
        """
        Run the crew and parse its answer as JSON.

        Returns:
            The object decoded from the (repaired) JSON answer. If decoding
            fails, an error line is printed and the unmodified value
            returned by ``crew.kickoff()`` is returned instead.
        """
        result = self.crew.kickoff()
        try:
            # repair_json works on text, so normalise the result first.
            return json.loads(repair_json(self._result_text(result)))
        except (json.JSONDecodeError, TypeError):
            print("[ERROR] Failed to decode the demographic information")
            return result