Emmanuel Chinonye Nnajiofor committed
Commit 4c42be0 · Parent(s): 8fbb0e2
Migrate project from pip to UV package manager
- Remove requirements.txt and old pip-based install commands
- Add pyproject.toml and uv.lock for dependency declaration & locking
- Repoint `make install` from pip to `uv sync`
- Update Makefile targets (lint, test, format, run) to use `uv` commands
- Add chatlib/logger.py to provide consistent logging across the project
- Update README.md to list the new logger.py file
- Clean out stale comments
- Other minor changes
This change improves reproducibility by using UV’s lockfile and
isolated virtual environments, and centralizes all commands under
`uv run` so tools always execute in the correct env.
- .python-version +1 -0
- Makefile +32 -8
- README.md +1 -0
- app.py +6 -11
- chat.py +12 -17
- chatlib/assistant_node.py +13 -16
- chatlib/guidlines_rag_agent_li.py +2 -18
- chatlib/idsr_check.py +8 -34
- chatlib/logger.py +65 -0
- chatlib/patient_all_data.py +0 -8
- chatlib/patient_sql_agent.py +7 -21
- chatlib/phi_filter.py +1 -4
- chatlib/state_types.py +4 -24
- main.py +5 -0
- pyproject.toml +26 -0
- requirements.txt +0 -15
- uv.lock +0 -0
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
Makefile
CHANGED
@@ -1,15 +1,39 @@
-install:
-	pip install --upgrade pip &&\
-	pip install -r requirements.txt
+SHELL := /bin/bash
+
+VENV := .venv
+
+ifeq ($(OS),Windows_NT)
+	VENV_BIN := $(VENV)/Scripts
+	PYTHON := $(VENV_BIN)/python.exe
+	RM := del /s /q
+else
+	VENV_BIN := $(VENV)/bin
+	PYTHON := $(VENV_BIN)/python
+	RM := rm -rf
+endif
+
+.PHONY: venv install install-dev lint test format run clean
+
+venv:
+	uv venv $(VENV)
+
+install: venv
+	$(VENV_BIN)/uv sync
 
-lint:
-	pylint --disable=R,C app.py chatlib
+lint:
+	$(VENV_BIN)/pylint --disable=R,C app.py chatlib
 
 test:
-	PYTHONPATH=. pytest -vv
+	PYTHONPATH=. $(VENV_BIN)/pytest -vv
 
 format:
-	black app.py chatlib
+	$(VENV_BIN)/black app.py chatlib
 
 run:
-	python app.py
+	$(PYTHON) app.py
+
+clean:
+	$(RM) $(VENV)
+	$(RM) .pytest_cache
+	$(RM) __pycache__
+	$(RM) .mypy_cache
README.md
CHANGED
@@ -34,6 +34,7 @@ A conversational assistant designed to help clinicians in Kenya access patient d
 │   ├── assistant_node.py
 │   ├── guidlines_rag_agent_li.py
 │   ├── idsr_check.py
+│   ├── logger.py
 │   ├── patient_all_data.py
 │   ├── patient_sql_agent.py
 │   ├── phi_filter.py
app.py
CHANGED
@@ -8,7 +8,7 @@ from langchain_core.messages import HumanMessage, SystemMessage
 from langgraph.prebuilt import tools_condition, ToolNode
 from langgraph.checkpoint.memory import MemorySaver
 
-
+
 memory = MemorySaver()
 
 if os.path.exists("config.env"):

@@ -29,9 +29,7 @@ def rag_retrieve_tool(query):
     """Retrieve relevant HIV clinical guidelines for the given query."""
     result = rag_retrieve(query, llm=llm)
     return {
-        "rag_result": result.get(
-            "rag_result", ""
-        ),  # adjust based on your rag_retrieve output
+        "rag_result": result.get("rag_result", ""),
         "last_tool": "rag_retrieve",
     }

@@ -89,7 +87,7 @@ builder.add_edge("tools", "assistant")
 react_graph = builder.compile(checkpointer=memory)
 
 
-def chat_with_patient(question: str, thread_id: str = None):
+def chat_with_patient(question: str, thread_id: str = None):  # type: ignore
     # Generate or reuse thread_id for session persistence
     if thread_id is None or thread_id == "":
         thread_id = str(uuid.uuid4())

@@ -97,8 +95,7 @@ def chat_with_patient(question: str, thread_id: str = None):
     # Check input for PHI and redact if necessary
     question = detect_and_redact_phi(question)["redacted_text"]
     print(question)
-
-    # initialize state with patient pk hash
+
     input_state: AppState = {
         "messages": [HumanMessage(content=question)],
         "question": "",

@@ -111,13 +108,11 @@ def chat_with_patient(question: str, thread_id: str = None):
 
     config = {"configurable": {"thread_id": thread_id, "user_id": thread_id}}
 
-
-    output_state = react_graph.invoke(input_state, config)
+    output_state = react_graph.invoke(input_state, config)  # type: ignore
 
     for m in output_state["messages"]:
         m.pretty_print()
 
-    # Extract the last AImessage
     assistant_message = output_state["messages"][-1].content
 
     return assistant_message, thread_id

@@ -125,7 +120,7 @@ def chat_with_patient(question: str, thread_id: str = None):
 
 with gr.Blocks() as app:
     question_input = gr.Textbox(label="Question")
-    thread_id_state = gr.State()
+    thread_id_state = gr.State()
     output_chat = gr.Textbox(label="Assistant Response")
 
     submit_btn = gr.Button("Ask")
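A quick usage sketch of the session behavior above (illustrative only, not part of the commit): because `chat_with_patient` returns the `thread_id` it generated, passing that id back in continues the same MemorySaver-checkpointed conversation.

    # Hypothetical calls; chat_with_patient is defined in app.py above.
    answer, tid = chat_with_patient("Summarize this patient's clinical visits.")
    # Reusing the returned thread_id resumes the same checkpointed session.
    follow_up, _ = chat_with_patient("Any abnormal lab results?", thread_id=tid)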
chat.py
CHANGED
@@ -16,15 +16,13 @@ from chatlib.state_types import AppState
 from chatlib.guidlines_rag_agent_li import rag_retrieve
 from chatlib.patient_all_data import sql_chain
 from chatlib.idsr_check import idsr_check
-
-# from langchain_ollama.chat_models import ChatOllama
-# llm = ChatOllama(model="mistral:latest", temperature=0)
+from chatlib.logger import get_logger
 
 tools = [rag_retrieve, sql_chain, idsr_check]
 llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
 llm_with_tools = llm.bind_tools([rag_retrieve, sql_chain, idsr_check])
 
-
+
 sys_msg = SystemMessage(content="""
 You are a helpful assistant supporting clinicians during patient visits. You have three tools:
 
@@ -57,7 +55,7 @@ Do not include any text outside the JSON response.
 """)
 
 
-
+
 def assistant(state: AppState) -> AppState:
 
     pk_hash = state.get("pk_hash", None)

@@ -68,9 +66,8 @@ def assistant(state: AppState) -> AppState:
     else:
         messages = [sys_msg] + state["messages"]
 
-    # Get the LLM/tool response
     new_message = llm_with_tools.invoke(messages)
-
+
 
     latest_question = ""
     for msg in reversed(messages):

@@ -78,39 +75,37 @@ def assistant(state: AppState) -> AppState:
             latest_question = msg.content
             break
 
-    state['messages'] = state['messages'] + [new_message]
-    state['question'] = latest_question
+    state['messages'] = state['messages'] + [new_message]  # type: ignore
+    state['question'] = latest_question  # type: ignore
     return state
-    # return {**state, "messages": state['messages'] + [new_message], "question": latest_question}
 
 # Graph
 builder = StateGraph(AppState)
 
-
+
 builder.add_node("assistant", assistant)
 builder.add_node("tools", ToolNode(tools))
 
-
+
 builder.add_edge(START, "assistant")
 builder.add_conditional_edges("assistant", tools_condition)
 builder.add_edge("tools", "assistant")
 react_graph = builder.compile(checkpointer=memory)
 
-
+
 config = {"configurable": {"thread_id": "30"}}
 
-
+
 input_state:AppState = {
     "messages": [HumanMessage(content="summarize the patient's clinical visits")],
     "question": "",
     "rag_result": "",
     "answer": "",
-    "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
+    "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"  # type: ignore
 }
 
 
-
-message_output = react_graph.invoke(input_state, config)
+message_output = react_graph.invoke(input_state, config)  # type: ignore
 
 for m in message_output['messages']:
     m.pretty_print()
chatlib/assistant_node.py
CHANGED
@@ -4,7 +4,6 @@ from langchain_core.messages import ToolMessage
 import json
 
 
-# Assistant Node
 def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
 
     if state.get("messages") and isinstance(state["messages"][-1], ToolMessage):

@@ -14,11 +13,10 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
         try:
             tool_content_dict = json.loads(tool_content)
             state.update(tool_content_dict)
-            # print("Merged tool content into state:", tool_content_dict)
         except json.JSONDecodeError:
             print("Failed to parse tool content as JSON")
     elif isinstance(tool_content, dict):
-        state.update(tool_content)
+        state.update(tool_content)  # type: ignore
 
     pk_hash = state.get("pk_hash", None)
 
@@ -30,20 +28,18 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
     else:
         messages = [sys_msg] + state.get("messages", [])
 
-    # Extract latest human question
     latest_question = ""
     for msg in reversed(messages):
         if isinstance(msg, HumanMessage):
             latest_question = msg.content
             break
 
-    # Generate AIMessage only if answer is new
     if "answer" in state and state["answer"]:
        if state.get("last_answer") != state["answer"]:
             last_tool = state.get("last_tool")
-
+
             if last_tool == "idsr_check":
-
+
                 disclaimer_needed = not state.get("idsr_disclaimer_shown", False)
                 print(disclaimer_needed)
                 format_instructions = """

@@ -72,20 +68,22 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
                         "Disclaimer: This is not a diagnosis. This is meant to help\n"
                         "identify possible matches based on priority IDSR diseases for clinician awareness.\n"
                     )
-                    state["idsr_disclaimer_shown"] = True
+                    state["idsr_disclaimer_shown"] = True  # type: ignore
                 else:
                     disclaimer_text = ""
 
-                prompt = format_instructions.format(disclaimer=disclaimer_text) + f"\n\nResponse:\n{state['answer']}"
+                prompt = (
+                    format_instructions.format(disclaimer=disclaimer_text)
+                    + f"\n\nResponse:\n{state['answer']}"
+                )
                 print("Prompt sent to LLM:\n", prompt)
-
+
                 llm_response = llm.invoke(prompt)
                 formatted_answer = llm_response.content.strip()
 
                 ai_message = AIMessage(content=formatted_answer)
 
-
-                state["idsr_disclaimer_shown"] = True
+                state["idsr_disclaimer_shown"] = True  # type: ignore
 
             else:
                 # For other tools, use the raw answer as is

@@ -93,13 +91,12 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
 
             messages = messages + [ai_message]
             state["messages"] = messages
-            state["question"] = latest_question
-            state["last_answer"] = state["answer"]
+            state["question"] = latest_question  # type: ignore
+            state["last_answer"] = state["answer"]
             return state
 
-    # Otherwise, normal LLM with tools invocation
     new_message = llm_with_tools.invoke(messages)
     messages = messages + [new_message]
     state["messages"] = messages
-    state["question"] = latest_question
+    state["question"] = latest_question  # type: ignore
     return state
chatlib/guidlines_rag_agent_li.py
CHANGED
@@ -1,12 +1,11 @@
 from llama_index.core import StorageContext, load_index_from_storage
 from .state_types import AppState
 
-
+
 storage_context = StorageContext.from_defaults(persist_dir="guidance_docs/arv_metadata")
 index = load_index_from_storage(storage_context)
 retriever = index.as_retriever(
     similarity_top_k=3,
-    # Similarity threshold for filtering
     similarity_threshold=0.5,
 )
 
@@ -26,21 +25,6 @@ def rag_retrieve(query: str, llm) -> AppState:
         f"Guideline Text:\n{retrieved_text}"
     )
 
-    # Call your LLM to generate the summary
     summary_response = llm.invoke(summarization_prompt)
 
-    return {"rag_result": summary_response.content, "last_tool": "rag_retrieve"}
-
-
-# if __name__ == "__main__":
-#     # Test the function
-#     test_state = AppState(
-#         messages=[],
-#         question="What are the first-line treatments for HIV in Kenya?",
-#         rag_result="",
-#         query="",
-#         result="",
-#         answer=""
-#     )
-#     updated_state = rag_retrieve(test_state)
-#     print(updated_state["rag_result"])
+    return {"rag_result": summary_response.content, "last_tool": "rag_retrieve"}  # type: ignore
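For orientation, the middle of `rag_retrieve` (elided by this diff) queries the retriever configured at module load. A minimal sketch of how a llama_index retriever of this kind is typically queried, assuming the index loads as shown above; the query string is made up:

    # Illustrative only: fetch the top-matching guideline chunks for a query.
    nodes = retriever.retrieve("first-line ART regimen for adults")
    retrieved_text = "\n\n".join(node.get_content() for node in nodes)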
chatlib/idsr_check.py
CHANGED
@@ -10,26 +10,23 @@ import json
 import math
 from collections import Counter
 
-
-# load keywords from file
+
 with open("./guidance_docs/idsr_keywords.txt", "r", encoding="utf-8") as f:
     keywords = [line.strip() for line in f if line.strip()]
 
-# load vectorstore
 vectorstore = FAISS.load_local(
     "./guidance_docs/disease_vectorstore",
     OpenAIEmbeddings(),
     allow_dangerous_deserialization=True,
 )
 
-
+
 with open("./guidance_docs/tagged_documents.json", "r", encoding="utf-8") as f:
     doc_dicts = json.load(f)
 
 tagged_documents = [Document(**d) for d in doc_dicts]
 
-
-# Count how many documents each keyword appears in
+
 keyword_doc_counts = Counter()
 total_docs = len(tagged_documents)
 
@@ -38,10 +35,9 @@ for doc in tagged_documents:
     for kw in seen:
         keyword_doc_counts[kw] += 1
 
-
+
 keyword_weights = {
-    kw: math.log(total_docs / (1 + count))
-    for kw, count in keyword_doc_counts.items()
+    kw: math.log(total_docs / (1 + count)) for kw, count in keyword_doc_counts.items()
 }
 
 
@@ -51,7 +47,6 @@ def score_doc(doc_to_score, matched_keywords):
     return sum(keyword_weights.get(kw, 0) for kw in overlap)
 
 
-## Define helper functions
 class KeywordsOutput(BaseModel):
     keywords: List[str] = Field(
         description="List of relevant keywords extracted from the query"

@@ -75,7 +70,6 @@ Return the matching keywords as a JSON object with a single key "keywords" whose
     """
     )
 
-    # Compose the chain as a RunnableSequence: prompt -> llm -> parser
     chain = prompt | llm | parser
 
     output = chain.invoke(

@@ -86,24 +80,17 @@ Return the matching keywords as a JSON object with a single key "keywords" whose
         }
     )
 
-    # output is a list of strings, not a KeywordsOutput instance
     return output.keywords
 
 
-# function to perform hybrid search combining semantic search and keyword matching
 def hybrid_search_with_query_keywords(
     query, vstore, documents, keyword_list, llm, top_k=5
 ):
 
-    # Step 1: Semantic search
     semantic_hits = vstore.similarity_search(query, k=top_k)
 
-    # Step 2: Use GPT to extract keywords from the query
     matched_keywords = extract_keywords_with_gpt(query, llm, keyword_list)
 
-    # print("Matched keywords:", matched_keywords)
-
-    # Step 3: Filter docs whose metadata has any of those keywords
     keyword_hits = [
         doc
         for doc in documents

@@ -114,33 +101,22 @@ def hybrid_search_with_query_keywords(
         )
     ]
 
-    # print("Keyword hits:", len(keyword_hits))
-
-    # Step 4: Score keyword-matching documents by keyword rarity
     scored_docs = [
         (
             doc,
             score_doc(doc, matched_keywords),
-        )
+        )
         for doc in keyword_hits
     ]
 
-    # # print doc metadata and scores
-    # for doc, score in scored_docs:
-    #     print(f"Document: {doc.metadata.get('disease_name', 'Unknown')}, Score: {score}")
-    #     print(f"Matched Keywords: {doc.metadata.get('matched_keywords', [])}")
-
-    # Step 5: Rank and select top documents by score
     ranked_docs = sorted(scored_docs, key=lambda x: -x[1])
     top_docs = [doc for doc, score in ranked_docs if score > 0]
     top_3_docs = top_docs[:3]
 
-    # Step 4: Merge by unique content
     merged = {doc.page_content: doc for doc in semantic_hits + top_3_docs}
     return list(merged.values())
 
 
-# Main function to perform the IDSR check
 def idsr_check(query: str, llm) -> AppState:
     """
     Perform hybrid search combining semantic search and keyword matching.

@@ -151,7 +127,7 @@ def idsr_check(query: str, llm) -> AppState:
     Returns:
         AppState: Updated state with search results.
     """
-
+
     results = hybrid_search_with_query_keywords(
         query, vectorstore, tagged_documents, keywords, llm
     )

@@ -163,7 +139,6 @@ def idsr_check(query: str, llm) -> AppState:
         ]
     )
 
-    # Prepare prompt for the LLM
     prompt = """
 You are a medical assistant reviewing a brief clinical case in Kenya to help identify which diseases the patient may plausibly have. You have access to several disease definitions.
 
@@ -202,7 +177,6 @@ def idsr_check(query: str, llm) -> AppState:
         query=query, disease_definitions=disease_definitions
     )
 
-    # Call the LLM to generate the answer, passing the case description and disease definitions
     llm_response = llm.invoke(prompt)
     answer_text = (
         llm_response.content.strip()

@@ -210,4 +184,4 @@ def idsr_check(query: str, llm) -> AppState:
         else "No relevant disease information found."
     )
 
-    return {"answer": answer_text, "last_tool": "idsr_check"}
+    return {"answer": answer_text, "last_tool": "idsr_check"}  # type: ignore
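The `keyword_weights` comprehension above is an inverse-document-frequency weighting: keywords that appear in fewer tagged documents get larger weights, so rare disease terms dominate `score_doc`. A self-contained sketch of the same formula on toy data (the keyword sets below are made up, not from the repo):

    import math
    from collections import Counter

    # Toy stand-in for tagged_documents metadata: keyword tags per document.
    docs_keywords = [{"fever", "rash"}, {"fever", "cough"}, {"fever"}, {"jaundice"}]

    keyword_doc_counts = Counter(kw for seen in docs_keywords for kw in seen)
    total_docs = len(docs_keywords)

    # Same weighting as idsr_check.py: log(N / (1 + doc_count)).
    keyword_weights = {
        kw: math.log(total_docs / (1 + count))
        for kw, count in keyword_doc_counts.items()
    }

    # "fever" appears in 3 of 4 docs -> log(4/4) = 0.0, while "jaundice"
    # appears in 1 -> log(4/2) ~ 0.69, so a "jaundice" match contributes
    # far more to a document's score than a ubiquitous term like "fever".
    print(keyword_weights)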
chatlib/logger.py
ADDED
@@ -0,0 +1,65 @@
+import logging
+import sys
+import os
+from pathlib import Path
+from logging.handlers import RotatingFileHandler
+
+
+LOG_DIR = Path(os.getenv("LOG_DIR", Path(__file__).resolve().parent.parent / "logs"))
+LOG_DIR.mkdir(parents=True, exist_ok=True)
+
+LOG_FILE = LOG_DIR / "app.log"
+
+LOG_CONFIG = {
+    "console_format": "%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s",
+    "file_format": "%(asctime)s | %(levelname)-8s | %(name)-20s | %(funcName)s:%(lineno)d | %(message)s",
+    "date_format": "%Y-%m-%d %H:%M:%S",
+    "max_bytes": 5 * 1024 * 1024,  # 5MB
+    "backup_count": 3,
+}
+
+_logger_cache = {}
+
+
+def get_logger(name: str = "text2sql-app") -> logging.Logger:
+    """Get a configured logger instance with console and file handlers
+
+    Args:
+        name: Logger name (usually __name__ of calling module)
+
+    Returns:
+        Configured Logger instance
+    """
+    if name in _logger_cache:
+        return _logger_cache[name]
+
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+
+    if logger.hasHandlers():
+        _logger_cache[name] = logger
+        return logger
+
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+    console_formatter = logging.Formatter(
+        LOG_CONFIG["console_format"], LOG_CONFIG["date_format"]
+    )
+    console_handler.setFormatter(console_formatter)
+
+    file_handler = RotatingFileHandler(
+        LOG_FILE,
+        maxBytes=LOG_CONFIG["max_bytes"],
+        backupCount=LOG_CONFIG["backup_count"],
+    )
+    file_handler.setLevel(logging.DEBUG)
+    file_formatter = logging.Formatter(
+        LOG_CONFIG["file_format"], LOG_CONFIG["date_format"]
+    )
+    file_handler.setFormatter(file_formatter)
+
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+
+    _logger_cache[name] = logger
+    return logger
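A minimal usage sketch for the new helper (matching the import added in chat.py; the messages are illustrative):

    from chatlib.logger import get_logger

    log = get_logger(__name__)
    log.info("Goes to the console and to logs/app.log")
    # DEBUG records reach only the rotating file handler, since the
    # console handler is set to INFO while the file handler is DEBUG.
    log.debug("File-only detail")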
chatlib/patient_all_data.py
CHANGED
@@ -3,14 +3,12 @@ import pandas as pd
 import os
 
 
-# define helper functions
 def safe(val):
     if pd.isnull(val) or val in ("", "NULL"):
         return "missing"
     return val
 
 
-# function to return only year of date
 def extract_year(date_str):
     if pd.isnull(date_str) or date_str in ("", "NULL"):
         return "missing"

@@ -20,7 +18,6 @@ def extract_year(date_str):
         return "invalid date"
 
 
-# Define the SQL query tool
 def sql_chain(query: str, llm, rag_result: str) -> dict:
     """
     Annotated function that takes a patient identifer (pk_hash) and returns

@@ -44,7 +41,6 @@ def sql_chain(query: str, llm, rag_result: str) -> dict:
     conn = sqlite3.connect("data/patient_demonstration.sqlite")
     cursor = conn.cursor()
 
-    # Write the SQL query using the QuerySQLDatabaseTool
     cursor.execute(
         "SELECT * FROM clinical_visits WHERE PatientPKHash = :pk_hash",
         {"pk_hash": pk_hash},

@@ -172,10 +168,6 @@ def sql_chain(query: str, llm, rag_result: str) -> dict:
 
     demographic_summary = summarize_demographics(demographic_data)
 
-    # cursor.execute("SELECT * FROM data_dictionary")
-    # rows = cursor.fetchall()
-    # data_dictionary = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])
-
     conn.close()
 
     prompt = (
chatlib/patient_sql_agent.py
CHANGED
@@ -11,10 +11,7 @@ from .state_types import AppState
 db = SQLDatabase.from_uri("sqlite:///data/patient_demonstration.sqlite")
 llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
 
-# from langchain_ollama.chat_models import ChatOllama
-# local_llm = ChatOllama(model="mistral:latest", temperature=0)
 
-# setup template for sql query tool
 system_message = """
 Given an input question, create a syntactically correct {dialect} query to
 run to help find the answer. The database contains the following tables and columns:

@@ -109,10 +106,7 @@ def write_query(state: AppState) -> AppState:
     prompt = query_prompt_template.invoke(
         {
             "dialect": db.dialect,
-
-            "table_info": db.run(
-                "SELECT * FROM data_dictionary;"
-            ),  # db.get_table_info(),
+            "table_info": db.run("SELECT * FROM data_dictionary;"),
             "input": state["question"],
             "guidelines": state.get("rag_result", "No guidelines provided."),
             "pk_hash": state.get("pk_hash", ""),

@@ -121,19 +115,16 @@ def write_query(state: AppState) -> AppState:
 
     structured_llm = llm.with_structured_output(QueryOutput)
     result = structured_llm.invoke(prompt)
-
-    state["query"] = result["query"]
+    state["query"] = result["query"]  # type: ignore
     return state
-    # return {**state, "query": result["query"]}
 
 
 def execute_query(state: AppState) -> AppState:
     """Execute SQL query."""
 
     execute_query_tool = QuerySQLDatabaseTool(db=db)
-    state["result"] = execute_query_tool.invoke(state["query"])
+    state["result"] = execute_query_tool.invoke(state["query"])  # type: ignore
     return state
-    # return {**state, "result": execute_query_tool.invoke(state["query"])}
 
 
 def generate_answer(state: AppState) -> AppState:

@@ -153,19 +144,14 @@ def generate_answer(state: AppState) -> AppState:
         "In that case, ignore the SQL query too and generate an answer based only on the context. \n\n"
         f'Question: {state["question"]}\n'
         f'Context: {state.get("rag_result", "No guidelines provided.")}\n'
-        f'SQL Query: {state["query"]}\n'
-        f'SQL Result: {state["result"]}'
-        # f'Question: {state["question"]}\n'
-        # f'SQL Query: {state["query"]}\n'
-        # f'SQL Result: {state["result"]}'
+        f'SQL Query: {state["query"]}\n'  # type: ignore
+        f'SQL Result: {state["result"]}'  # type: ignore
     )
     response = llm.invoke(prompt)
-    state["answer"] = response.content
+    state["answer"] = response.content  # type: ignore
     return state
-    # return {**state, "answer": response.content}
 
 
-# now define a stateful tool that does the same thing
 @tool
 def sql_chain(state: AppState) -> dict:
     """

@@ -178,4 +164,4 @@ def sql_chain(state: AppState) -> dict:
     state = execute_query(state)
     state = generate_answer(state)
 
-    return state
+    return state  # type: ignore
chatlib/phi_filter.py
CHANGED
@@ -4,7 +4,7 @@ import dateparser.search
 from datetime import datetime
 from dateutil.relativedelta import relativedelta
 
-
+
 RELATIVE_INDICATORS = [
     "ago",
     "later",

@@ -28,7 +28,6 @@ def is_relative_date(text_relative):
     return any(word in text_lower for word in RELATIVE_INDICATORS)
 
 
-# Load Kenyan names list (basic txt file, one name per line, all lowercase for comparison)
 def load_kenyan_names(filepath="data/kenyan_names.txt"):
     if not Path(filepath).exists():
         return set()

@@ -78,12 +77,10 @@ def detect_and_redact_phi(text_input):
 
     phi_detected = bool(names_found or dates_found)
 
-    # Redact dates with relative descriptions
     for match, dt in dates_found:
         relative = describe_relative_date(dt)
         text_input = text_input.replace(match, relative)
 
-    # Redact Kenyan names
     for name in names_found:
         pattern = re.compile(rf"\b{name}\b", re.IGNORECASE)
         text_input = pattern.sub("[name]", text_input)
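The name redaction above hinges on the word-boundary regex, so a listed name embedded inside a longer word is left untouched. A self-contained sketch of that one step (the name and sentence are made up; real names come from load_kenyan_names()):

    import re

    names_found = ["wanjiku"]  # illustrative; not from data/kenyan_names.txt
    text_input = "Wanjiku reported fever near the Wanjikuville clinic."

    # Same redaction step as phi_filter.py: \b keeps "Wanjikuville" intact.
    for name in names_found:
        pattern = re.compile(rf"\b{name}\b", re.IGNORECASE)
        text_input = pattern.sub("[name]", text_input)

    # -> "[name] reported fever near the Wanjikuville clinic."
    print(text_input)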
chatlib/state_types.py
CHANGED
@@ -1,33 +1,13 @@
-from typing_extensions import TypedDict, Annotated
-from typing import Optional
+from typing_extensions import TypedDict, Annotated, NotRequired
 from langchain_core.messages import AnyMessage
 from langgraph.graph.message import add_messages
 
-# class ConversationState(TypedDict):
-#     question: str
-#     answer: str
-#     rag_result: str
-#     pk_hash: Optional[str]
-
-# class QueryState(TypedDict):
-#     query: str
-#     result: Optional[str]
-
-# class AppState(TypedDict):
-#     messages: Annotated[list[AnyMessage], add_messages]
-#     conversation: ConversationState
-#     query_data: QueryState
-
-# class SqlChainOutputModel(BaseModel):
-#     messages: List[AnyMessage] = Field(...)
-#     conversation: ConversationState = Field(...)
-
 
 class AppState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     question: str
     rag_result: str
     answer: str
-    last_answer: Optional[str]
-    last_tool: Optional[str]
-    idsr_disclaimer: bool
+    last_answer: NotRequired[str | None]
+    last_tool: NotRequired[str | None]
+    idsr_disclaimer: NotRequired[bool]
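Switching the last three fields to NotRequired means an AppState literal may omit them entirely, which is how app.py and chat.py already build their input states. A sketch of a dict that is valid under the new definition (imports assumed available):

    from langchain_core.messages import HumanMessage
    from chatlib.state_types import AppState

    # NotRequired keys (last_answer, last_tool, idsr_disclaimer) can be left out.
    state: AppState = {
        "messages": [HumanMessage(content="summarize the patient's clinical visits")],
        "question": "",
        "rag_result": "",
        "answer": "",
    }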
main.py
ADDED
@@ -0,0 +1,5 @@
+def main():
+    print("Hello from clinicalassistant!")
+
+if __name__ == "__main__":
+    main()
pyproject.toml
ADDED
@@ -0,0 +1,26 @@
+[project]
+name = "clinicalassistant"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "black>=25.1.0",
+    "dateparser>=1.2.2",
+    "faiss-cpu>=1.11.0",
+    "gradio>=5.36.2",
+    "langchain-community>=0.3.27",
+    "langchain-openai>=0.3.27",
+    "langgraph>=0.5.2",
+    "llama-index>=0.12.48",
+    "pandas>=2.3.1",
+    "pylint>=3.3.7",
+    "python-dotenv>=1.1.1",
+]
+
+[dependency-groups]
+dev = [
+    "black>=25.1.0",
+    "mypy>=1.16.1",
+    "pytest>=8.4.1",
+]
requirements.txt
DELETED
@@ -1,15 +0,0 @@
-dateparser==1.2.2
-gradio==5.36.2
-langchain_community==0.3.27
-langchain_core==0.3.68
-langchain_openai==0.3.27
-langgraph==0.5.2
-llama_index==0.12.48
-pandas==2.3.1
-pydantic==2.11.7
-python-dotenv==1.1.1
-python_dateutil==2.9.0.post0
-typing_extensions==4.14.1
-pylint==3.3.7
-black==25.1.0
-faiss-cpu==1.11.0
uv.lock
ADDED
The diff for this file is too large to render. See raw diff.