Emmanuel Chinonye Nnajiofor committed
Commit 4c42be0 · 1 Parent(s): 8fbb0e2

Migrate project from pip to UV package manager


- Remove requirements.txt and the old pip-based install commands
- Add pyproject.toml and uv.lock for dependency declaration and locking
- Repoint `make install` from pip to `uv sync`
- Update Makefile targets (lint, test, format, run) to run tools from the project `.venv`
- Add chatlib/logger.py to enable consistent logging across the project
- Update README.md to include the new logger.py file
- Clean out stale comments
- Other minor changes

This change improves reproducibility by using UV's lockfile and
isolated virtual environments, and centralizes tool invocation on the
project's `.venv` so commands always execute in the correct environment.
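
With this setup, a typical developer loop looks like the following (illustrative commands that mirror the new Makefile targets below):

    uv venv .venv   # create the project virtual environment
    uv sync         # install locked dependencies from uv.lock
    make test       # run pytest against the project environment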

.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
Makefile CHANGED
@@ -1,15 +1,39 @@
- install:
- 	pip install --upgrade pip &&\
- 		pip install -r requirements.txt
-
- lint:
- 	pylint --disable=R,C app.py chatlib
+ SHELL := /bin/bash
+
+ VENV := .venv
+
+ ifeq ($(OS),Windows_NT)
+     VENV_BIN := $(VENV)/Scripts
+     PYTHON := $(VENV_BIN)/python.exe
+     RM := del /s /q
+ else
+     VENV_BIN := $(VENV)/bin
+     PYTHON := $(VENV_BIN)/python
+     RM := rm -rf
+ endif
+
+ .PHONY: venv install install-dev lint test format run clean
+
+ venv:
+ 	uv venv $(VENV)
+
+ install: venv
+ 	$(VENV_BIN)/uv sync
+
+ lint:
+ 	$(VENV_BIN)/pylint --disable=R,C app.py chatlib
 
  test:
- 	PYTHONPATH=. pytest -vv
+ 	PYTHONPATH=. $(VENV_BIN)/pytest -vv
 
  format:
- 	black app.py chatlib
+ 	$(VENV_BIN)/black app.py chatlib
 
  run:
- 	python app.py
+ 	$(PYTHON) app.py
+
+ clean:
+ 	$(RM) $(VENV)
+ 	$(RM) .pytest_cache
+ 	$(RM) __pycache__
+ 	$(RM) .mypy_cache
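Note: the rewritten targets call tools straight from `$(VENV_BIN)` rather than through `uv run`, so no environment activation is needed. The `install` target invokes `$(VENV_BIN)/uv sync`, which assumes uv itself is present inside the virtual environment; with a system-wide uv install, a plain `uv sync` accomplishes the same thing.
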
README.md CHANGED
@@ -34,6 +34,7 @@ A conversational assistant designed to help clinicians in Kenya access patient d
  │ ├── assistant_node.py
  │ ├── guidlines_rag_agent_li.py
  │ ├── idsr_check.py
+ │ ├── logger.py
  │ ├── patient_all_data.py
  │ ├── patient_sql_agent.py
  │ ├── phi_filter.py
app.py CHANGED
@@ -8,7 +8,7 @@ from langchain_core.messages import HumanMessage, SystemMessage
  from langgraph.prebuilt import tools_condition, ToolNode
  from langgraph.checkpoint.memory import MemorySaver
 
- # Initialize your graph and checkpointer once - eventually make this persistent
+
  memory = MemorySaver()
 
  if os.path.exists("config.env"):
@@ -29,9 +29,7 @@ def rag_retrieve_tool(query):
      """Retrieve relevant HIV clinical guidelines for the given query."""
      result = rag_retrieve(query, llm=llm)
      return {
-         "rag_result": result.get(
-             "rag_result", ""
-         ),  # adjust based on your rag_retrieve output
+         "rag_result": result.get("rag_result", ""),
          "last_tool": "rag_retrieve",
      }
 
@@ -89,7 +87,7 @@ builder.add_edge("tools", "assistant")
  react_graph = builder.compile(checkpointer=memory)
 
 
- def chat_with_patient(question: str, thread_id: str = None):
+ def chat_with_patient(question: str, thread_id: str = None):  # type: ignore
      # Generate or reuse thread_id for session persistence
      if thread_id is None or thread_id == "":
          thread_id = str(uuid.uuid4())
@@ -97,8 +95,7 @@ def chat_with_patient(question: str, thread_id: str = None):
      # Check input for PHI and redact if necessary
      question = detect_and_redact_phi(question)["redacted_text"]
      print(question)
-     # Prepare input state with new user message and pk_hash
-     # initialize state with patient pk hash
+
      input_state: AppState = {
          "messages": [HumanMessage(content=question)],
          "question": "",
@@ -111,13 +108,11 @@ def chat_with_patient(question: str, thread_id: str = None):
 
      config = {"configurable": {"thread_id": thread_id, "user_id": thread_id}}
 
-     # Invoke the graph with persistent state
-     output_state = react_graph.invoke(input_state, config)
+     output_state = react_graph.invoke(input_state, config)  # type: ignore
 
      for m in output_state["messages"]:
          m.pretty_print()
 
-     # Extract the last AImessage
      assistant_message = output_state["messages"][-1].content
 
      return assistant_message, thread_id
@@ -125,7 +120,7 @@ def chat_with_patient(question: str, thread_id: str = None):
 
  with gr.Blocks() as app:
      question_input = gr.Textbox(label="Question")
-     thread_id_state = gr.State()  # to store thread_id between calls
+     thread_id_state = gr.State()
      output_chat = gr.Textbox(label="Assistant Response")
 
      submit_btn = gr.Button("Ask")
chat.py CHANGED
@@ -16,15 +16,13 @@ from chatlib.state_types import AppState
  from chatlib.guidlines_rag_agent_li import rag_retrieve
  from chatlib.patient_all_data import sql_chain
  from chatlib.idsr_check import idsr_check
-
- # from langchain_ollama.chat_models import ChatOllama
- # llm = ChatOllama(model="mistral:latest", temperature=0)
+ from chatlib.logger import get_logger
 
  tools = [rag_retrieve, sql_chain, idsr_check]
  llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
  llm_with_tools = llm.bind_tools([rag_retrieve, sql_chain, idsr_check])
 
- # System message
+
  sys_msg = SystemMessage(content="""
  You are a helpful assistant supporting clinicians during patient visits. You have three tools:
@@ -57,7 +55,7 @@ Do not include any text outside the JSON response.
  """)
 
 
- # Assistant Node
+
  def assistant(state: AppState) -> AppState:
 
      pk_hash = state.get("pk_hash", None)
@@ -68,9 +66,8 @@ def assistant(state: AppState) -> AppState:
      else:
          messages = [sys_msg] + state["messages"]
 
-     # Get the LLM/tool response
      new_message = llm_with_tools.invoke(messages)
-     # Extract the question from the latest HumanMessage, if present
+
 
      latest_question = ""
      for msg in reversed(messages):
@@ -78,39 +75,37 @@ def assistant(state: AppState) -> AppState:
          latest_question = msg.content
          break
 
-     state['messages'] = state['messages'] + [new_message]
-     state['question'] = latest_question
+     state['messages'] = state['messages'] + [new_message]  # type: ignore
+     state['question'] = latest_question  # type: ignore
      return state
-     # return {**state, "messages": state['messages'] + [new_message], "question": latest_question}
 
  # Graph
  builder = StateGraph(AppState)
 
- # Define nodes: these do the work
+
  builder.add_node("assistant", assistant)
  builder.add_node("tools", ToolNode(tools))
 
- # Define edges: these determine how the control flow moves
+
  builder.add_edge(START, "assistant")
  builder.add_conditional_edges("assistant", tools_condition)
  builder.add_edge("tools", "assistant")
  react_graph = builder.compile(checkpointer=memory)
 
- # Specify a thread
+
  config = {"configurable": {"thread_id": "30"}}
 
- # initialize state with patient pk hash
+
  input_state:AppState = {
      "messages": [HumanMessage(content="summarize the patient's clinical visits")],
      "question": "",
      "rag_result": "",
      "answer": "",
-     "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
+     "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"  # type: ignore
  }
 
 
- # messages = [HumanMessage(content="how many appointments has this patient had?")]
- message_output = react_graph.invoke(input_state, config)
+ message_output = react_graph.invoke(input_state, config)  # type: ignore
 
  for m in message_output['messages']:
      m.pretty_print()
chatlib/assistant_node.py CHANGED
@@ -4,7 +4,6 @@ from langchain_core.messages import ToolMessage
  import json
 
 
- # Assistant Node
  def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
 
      if state.get("messages") and isinstance(state["messages"][-1], ToolMessage):
@@ -14,11 +13,10 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
          try:
              tool_content_dict = json.loads(tool_content)
              state.update(tool_content_dict)
-             # print("Merged tool content into state:", tool_content_dict)
          except json.JSONDecodeError:
              print("Failed to parse tool content as JSON")
      elif isinstance(tool_content, dict):
-         state.update(tool_content)
+         state.update(tool_content)  # type: ignore
 
      pk_hash = state.get("pk_hash", None)
@@ -30,20 +28,18 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
      else:
          messages = [sys_msg] + state.get("messages", [])
 
-     # Extract latest human question
      latest_question = ""
      for msg in reversed(messages):
          if isinstance(msg, HumanMessage):
              latest_question = msg.content
              break
 
-     # Generate AIMessage only if answer is new
      if "answer" in state and state["answer"]:
          if state.get("last_answer") != state["answer"]:
              last_tool = state.get("last_tool")
-
+
              if last_tool == "idsr_check":
-
+
                  disclaimer_needed = not state.get("idsr_disclaimer_shown", False)
                  print(disclaimer_needed)
                  format_instructions = """
@@ -72,20 +68,22 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
                      "Disclaimer: This is not a diagnosis. This is meant to help\n"
                      "identify possible matches based on priority IDSR diseases for clinician awareness.\n"
                  )
-                 state["idsr_disclaimer_shown"] = True
+                 state["idsr_disclaimer_shown"] = True  # type: ignore
              else:
                  disclaimer_text = ""
 
-             prompt = format_instructions.format(disclaimer=disclaimer_text) + f"\n\nResponse:\n{state['answer']}"
+             prompt = (
+                 format_instructions.format(disclaimer=disclaimer_text)
+                 + f"\n\nResponse:\n{state['answer']}"
+             )
              print("Prompt sent to LLM:\n", prompt)
-             # Call LLM to reformat the answer
+
              llm_response = llm.invoke(prompt)
              formatted_answer = llm_response.content.strip()
 
              ai_message = AIMessage(content=formatted_answer)
 
-             # Set the flag so disclaimer is not shown again
-             state["idsr_disclaimer_shown"] = True
+             state["idsr_disclaimer_shown"] = True  # type: ignore
 
          else:
              # For other tools, use the raw answer as is
@@ -93,13 +91,12 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
 
          messages = messages + [ai_message]
          state["messages"] = messages
-         state["question"] = latest_question
-         state["last_answer"] = state["answer"]  # track processed answer
+         state["question"] = latest_question  # type: ignore
+         state["last_answer"] = state["answer"]
          return state
 
-     # Otherwise, normal LLM with tools invocation
      new_message = llm_with_tools.invoke(messages)
      messages = messages + [new_message]
      state["messages"] = messages
-     state["question"] = latest_question
+     state["question"] = latest_question  # type: ignore
      return state
chatlib/guidlines_rag_agent_li.py CHANGED
@@ -1,12 +1,11 @@
  from llama_index.core import StorageContext, load_index_from_storage
  from .state_types import AppState
 
- # Load index for retrieval
+
  storage_context = StorageContext.from_defaults(persist_dir="guidance_docs/arv_metadata")
  index = load_index_from_storage(storage_context)
  retriever = index.as_retriever(
      similarity_top_k=3,
-     # Similarity threshold for filtering
      similarity_threshold=0.5,
  )
 
@@ -26,21 +25,6 @@ def rag_retrieve(query: str, llm) -> AppState:
          f"Guideline Text:\n{retrieved_text}"
      )
 
-     # Call your LLM to generate the summary
      summary_response = llm.invoke(summarization_prompt)
 
-     return {"rag_result": summary_response.content, "last_tool": "rag_retrieve"}
-
-
- # if __name__ == "__main__":
- #     # Test the function
- #     test_state = AppState(
- #         messages=[],
- #         question="What are the first-line treatments for HIV in Kenya?",
- #         rag_result="",
- #         query="",
- #         result="",
- #         answer=""
- #     )
- #     updated_state = rag_retrieve(test_state)
- #     print(updated_state["rag_result"])
+     return {"rag_result": summary_response.content, "last_tool": "rag_retrieve"}  # type: ignore
chatlib/idsr_check.py CHANGED
@@ -10,26 +10,23 @@
  import math
  from collections import Counter
 
- ## Keywords
- # load keywords from file
+
  with open("./guidance_docs/idsr_keywords.txt", "r", encoding="utf-8") as f:
      keywords = [line.strip() for line in f if line.strip()]
 
- # load vectorstore
  vectorstore = FAISS.load_local(
      "./guidance_docs/disease_vectorstore",
      OpenAIEmbeddings(),
      allow_dangerous_deserialization=True,
  )
 
- # load tagged documents from JSON for keyword matching to document metadata
+
  with open("./guidance_docs/tagged_documents.json", "r", encoding="utf-8") as f:
      doc_dicts = json.load(f)
 
  tagged_documents = [Document(**d) for d in doc_dicts]
 
- # Set up metrics for keywords
- # Count how many documents each keyword appears in
+
  keyword_doc_counts = Counter()
  total_docs = len(tagged_documents)
@@ -38,10 +35,9 @@ for doc in tagged_documents:
      for kw in seen:
          keyword_doc_counts[kw] += 1
 
- # Use log-scaled inverse frequency to avoid extreme values
+
  keyword_weights = {
-     kw: math.log(total_docs / (1 + count))  # add 1 to avoid div-by-zero
-     for kw, count in keyword_doc_counts.items()
+     kw: math.log(total_docs / (1 + count)) for kw, count in keyword_doc_counts.items()
  }
@@ -51,7 +47,6 @@ def score_doc(doc_to_score, matched_keywords):
      return sum(keyword_weights.get(kw, 0) for kw in overlap)
 
 
- ## Define helper functions
  class KeywordsOutput(BaseModel):
      keywords: List[str] = Field(
          description="List of relevant keywords extracted from the query"
@@ -75,7 +70,6 @@ Return the matching keywords as a JSON object with a single key "keywords" whose
  """
      )
 
-     # Compose the chain as a RunnableSequence: prompt -> llm -> parser
      chain = prompt | llm | parser
 
      output = chain.invoke(
@@ -86,24 +80,17 @@ Return the matching keywords as a JSON object with a single key "keywords" whose
          }
      )
 
-     # output is a list of strings, not a KeywordsOutput instance
      return output.keywords
 
 
- # function to perform hybrid search combining semantic search and keyword matching
  def hybrid_search_with_query_keywords(
      query, vstore, documents, keyword_list, llm, top_k=5
  ):
 
-     # Step 1: Semantic search
      semantic_hits = vstore.similarity_search(query, k=top_k)
 
-     # Step 2: Use GPT to extract keywords from the query
      matched_keywords = extract_keywords_with_gpt(query, llm, keyword_list)
 
-     # print("Matched keywords:", matched_keywords)
-
-     # Step 3: Filter docs whose metadata has any of those keywords
      keyword_hits = [
          doc
          for doc in documents
@@ -114,33 +101,22 @@ def hybrid_search_with_query_keywords(
          )
      ]
 
-     # print("Keyword hits:", len(keyword_hits))
-
-     # Step 4: Score keyword-matching documents by keyword rarity
      scored_docs = [
          (
              doc,
              score_doc(doc, matched_keywords),
-         )  # original (unnormalized) list used for scoring
+         )
          for doc in keyword_hits
      ]
 
-     # # print doc metadata and scores
-     # for doc, score in scored_docs:
-     #     print(f"Document: {doc.metadata.get('disease_name', 'Unknown')}, Score: {score}")
-     #     print(f"Matched Keywords: {doc.metadata.get('matched_keywords', [])}")
-
-     # Step 5: Rank and select top documents by score
      ranked_docs = sorted(scored_docs, key=lambda x: -x[1])
      top_docs = [doc for doc, score in ranked_docs if score > 0]
      top_3_docs = top_docs[:3]
 
-     # Step 4: Merge by unique content
      merged = {doc.page_content: doc for doc in semantic_hits + top_3_docs}
      return list(merged.values())
 
 
- # Main function to perform the IDSR check
  def idsr_check(query: str, llm) -> AppState:
      """
      Perform hybrid search combining semantic search and keyword matching.
@@ -151,7 +127,7 @@ def idsr_check(query: str, llm) -> AppState:
      Returns:
          AppState: Updated state with search results.
      """
-     # Perform hybrid search
+
      results = hybrid_search_with_query_keywords(
          query, vectorstore, tagged_documents, keywords, llm
      )
@@ -163,7 +139,6 @@ def idsr_check(query: str, llm) -> AppState:
          ]
      )
 
-     # Prepare prompt for the LLM
      prompt = """
      You are a medical assistant reviewing a brief clinical case in Kenya to help identify which diseases the patient may plausibly have. You have access to several disease definitions.
@@ -202,7 +177,6 @@ def idsr_check(query: str, llm) -> AppState:
          query=query, disease_definitions=disease_definitions
      )
 
-     # Call the LLM to generate the answer, passing the case description and disease definitions
      llm_response = llm.invoke(prompt)
      answer_text = (
          llm_response.content.strip()
@@ -210,4 +184,4 @@ def idsr_check(query: str, llm) -> AppState:
          else "No relevant disease information found."
      )
 
-     return {"answer": answer_text, "last_tool": "idsr_check"}
+     return {"answer": answer_text, "last_tool": "idsr_check"}  # type: ignore
chatlib/logger.py ADDED
@@ -0,0 +1,65 @@
+ import logging
+ import sys
+ import os
+ from pathlib import Path
+ from logging.handlers import RotatingFileHandler
+
+
+ LOG_DIR = Path(os.getenv("LOG_DIR", Path(__file__).resolve().parent.parent / "logs"))
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
+
+ LOG_FILE = LOG_DIR / "app.log"
+
+ LOG_CONFIG = {
+     "console_format": "%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s",
+     "file_format": "%(asctime)s | %(levelname)-8s | %(name)-20s | %(funcName)s:%(lineno)d | %(message)s",
+     "date_format": "%Y-%m-%d %H:%M:%S",
+     "max_bytes": 5 * 1024 * 1024,  # 5MB
+     "backup_count": 3,
+ }
+
+ _logger_cache = {}
+
+
+ def get_logger(name: str = "text2sql-app") -> logging.Logger:
+     """Get a configured logger instance with console and file handlers
+
+     Args:
+         name: Logger name (usually __name__ of calling module)
+
+     Returns:
+         Configured Logger instance
+     """
+     if name in _logger_cache:
+         return _logger_cache[name]
+
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.DEBUG)
+
+     if logger.hasHandlers():
+         _logger_cache[name] = logger
+         return logger
+
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setLevel(logging.INFO)
+     console_formatter = logging.Formatter(
+         LOG_CONFIG["console_format"], LOG_CONFIG["date_format"]
+     )
+     console_handler.setFormatter(console_formatter)
+
+     file_handler = RotatingFileHandler(
+         LOG_FILE,
+         maxBytes=LOG_CONFIG["max_bytes"],
+         backupCount=LOG_CONFIG["backup_count"],
+     )
+     file_handler.setLevel(logging.DEBUG)
+     file_formatter = logging.Formatter(
+         LOG_CONFIG["file_format"], LOG_CONFIG["date_format"]
+     )
+     file_handler.setFormatter(file_formatter)
+
+     logger.addHandler(console_handler)
+     logger.addHandler(file_handler)
+
+     _logger_cache[name] = logger
+     return logger
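For reference, a minimal usage sketch of the new helper (the module name and messages here are illustrative):

    from chatlib.logger import get_logger

    logger = get_logger(__name__)
    logger.info("graph compiled")     # shown on the console (INFO and up) and written to logs/app.log
    logger.debug("full tool output")  # written to logs/app.log only; the console handler filters DEBUG
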
chatlib/patient_all_data.py CHANGED
@@ -3,14 +3,12 @@ import pandas as pd
  import os
 
 
- # define helper functions
  def safe(val):
      if pd.isnull(val) or val in ("", "NULL"):
          return "missing"
      return val
 
 
- # function to return only year of date
  def extract_year(date_str):
      if pd.isnull(date_str) or date_str in ("", "NULL"):
          return "missing"
@@ -20,7 +18,6 @@ def extract_year(date_str):
          return "invalid date"
 
 
- # Define the SQL query tool
  def sql_chain(query: str, llm, rag_result: str) -> dict:
      """
      Annotated function that takes a patient identifer (pk_hash) and returns
@@ -44,7 +41,6 @@ def sql_chain(query: str, llm, rag_result: str) -> dict:
      conn = sqlite3.connect("data/patient_demonstration.sqlite")
      cursor = conn.cursor()
 
-     # Write the SQL query using the QuerySQLDatabaseTool
      cursor.execute(
          "SELECT * FROM clinical_visits WHERE PatientPKHash = :pk_hash",
          {"pk_hash": pk_hash},
@@ -172,10 +168,6 @@ def sql_chain(query: str, llm, rag_result: str) -> dict:
 
      demographic_summary = summarize_demographics(demographic_data)
 
-     # cursor.execute("SELECT * FROM data_dictionary")
-     # rows = cursor.fetchall()
-     # data_dictionary = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])
-
      conn.close()
 
      prompt = (
chatlib/patient_sql_agent.py CHANGED
@@ -11,10 +11,7 @@ from .state_types import AppState
  db = SQLDatabase.from_uri("sqlite:///data/patient_demonstration.sqlite")
  llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
 
- # from langchain_ollama.chat_models import ChatOllama
- # local_llm = ChatOllama(model="mistral:latest", temperature=0)
 
- # setup template for sql query tool
  system_message = """
  Given an input question, create a syntactically correct {dialect} query to
  run to help find the answer. The database contains the following tables and columns:
@@ -109,10 +106,7 @@ def write_query(state: AppState) -> AppState:
      prompt = query_prompt_template.invoke(
          {
              "dialect": db.dialect,
-             # "top_k": 10,
-             "table_info": db.run(
-                 "SELECT * FROM data_dictionary;"
-             ),  # db.get_table_info(),
+             "table_info": db.run("SELECT * FROM data_dictionary;"),
              "input": state["question"],
              "guidelines": state.get("rag_result", "No guidelines provided."),
              "pk_hash": state.get("pk_hash", ""),
@@ -121,19 +115,16 @@
 
      structured_llm = llm.with_structured_output(QueryOutput)
      result = structured_llm.invoke(prompt)
-     # query_data["query"] = result["query"]
-     state["query"] = result["query"]
+     state["query"] = result["query"]  # type: ignore
      return state
-     # return {**state, "query": result["query"]}
 
 
  def execute_query(state: AppState) -> AppState:
      """Execute SQL query."""
 
      execute_query_tool = QuerySQLDatabaseTool(db=db)
-     state["result"] = execute_query_tool.invoke(state["query"])
+     state["result"] = execute_query_tool.invoke(state["query"])  # type: ignore
      return state
-     # return {**state, "result": execute_query_tool.invoke(state["query"])}
 
 
  def generate_answer(state: AppState) -> AppState:
@@ -153,19 +144,14 @@ def generate_answer(state: AppState) -> AppState:
          "In that case, ignore the SQL query too and generate an answer based only on the context. \n\n"
          f'Question: {state["question"]}\n'
          f'Context: {state.get("rag_result", "No guidelines provided.")}\n'
-         f'SQL Query: {state["query"]}\n'
-         f'SQL Result: {state["result"]}'
-         # f'Question: {state["question"]}\n'
-         # f'SQL Query: {state["query"]}\n'
-         # f'SQL Result: {state["result"]}'
+         f'SQL Query: {state["query"]}\n'  # type: ignore
+         f'SQL Result: {state["result"]}'  # type: ignore
      )
      response = llm.invoke(prompt)
-     state["answer"] = response.content
+     state["answer"] = response.content  # type: ignore
      return state
-     # return {**state, "answer": response.content}
 
 
- # now define a stateful tool that does the same thing
  @tool
  def sql_chain(state: AppState) -> dict:
      """
@@ -178,4 +164,4 @@ def sql_chain(state: AppState) -> dict:
      state = execute_query(state)
      state = generate_answer(state)
 
-     return state
+     return state  # type: ignore
chatlib/phi_filter.py CHANGED
@@ -4,7 +4,7 @@ import dateparser.search
  from datetime import datetime
  from dateutil.relativedelta import relativedelta
 
- # List of words indicating relative dates (to filter out)
+
  RELATIVE_INDICATORS = [
      "ago",
      "later",
@@ -28,7 +28,6 @@ def is_relative_date(text_relative):
      return any(word in text_lower for word in RELATIVE_INDICATORS)
 
 
- # Load Kenyan names list (basic txt file, one name per line, all lowercase for comparison)
  def load_kenyan_names(filepath="data/kenyan_names.txt"):
      if not Path(filepath).exists():
          return set()
@@ -78,12 +77,10 @@ def detect_and_redact_phi(text_input):
 
      phi_detected = bool(names_found or dates_found)
 
-     # Redact dates with relative descriptions
      for match, dt in dates_found:
          relative = describe_relative_date(dt)
          text_input = text_input.replace(match, relative)
 
-     # Redact Kenyan names
      for name in names_found:
          pattern = re.compile(rf"\b{name}\b", re.IGNORECASE)
          text_input = pattern.sub("[name]", text_input)
chatlib/state_types.py CHANGED
@@ -1,33 +1,13 @@
- from typing_extensions import TypedDict, Annotated
- from typing import Optional
+ from typing_extensions import TypedDict, Annotated, NotRequired
  from langchain_core.messages import AnyMessage
  from langgraph.graph.message import add_messages
 
- # class ConversationState(TypedDict):
- #     question: str
- #     answer: str
- #     rag_result: str
- #     pk_hash: Optional[str]
-
- # class QueryState(TypedDict):
- #     query: str
- #     result: Optional[str]
-
- # class AppState(TypedDict):
- #     messages: Annotated[list[AnyMessage], add_messages]
- #     conversation: ConversationState
- #     query_data: QueryState
-
- # class SqlChainOutputModel(BaseModel):
- #     messages: List[AnyMessage] = Field(...)
- #     conversation: ConversationState = Field(...)
-
 
  class AppState(TypedDict):
      messages: Annotated[list[AnyMessage], add_messages]
      question: str
      rag_result: str
      answer: str
-     last_answer: Optional[str] = None
-     last_tool: Optional[str] = None
-     idsr_disclaimer: bool = False
+     last_answer: NotRequired[str | None]
+     last_tool: NotRequired[str | None]
+     idsr_disclaimer: NotRequired[bool]
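Since TypedDict fields cannot carry defaults, the old `field: type = default` lines were never valid; with `NotRequired` the optional keys can simply be omitted when a state is constructed. A minimal sketch:

    state: AppState = {
        "messages": [],
        "question": "",
        "rag_result": "",
        "answer": "",
        # last_answer, last_tool, idsr_disclaimer may be left out entirely
    }

Note that `pk_hash` is still not declared on AppState, which is why the call sites above attach it under a `# type: ignore`.
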
main.py ADDED
@@ -0,0 +1,5 @@
+ def main():
+     print("Hello from clinicalassistant!")
+
+ if __name__ == "__main__":
+     main()
pyproject.toml ADDED
@@ -0,0 +1,26 @@
+ [project]
+ name = "clinicalassistant"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "black>=25.1.0",
+     "dateparser>=1.2.2",
+     "faiss-cpu>=1.11.0",
+     "gradio>=5.36.2",
+     "langchain-community>=0.3.27",
+     "langchain-openai>=0.3.27",
+     "langgraph>=0.5.2",
+     "llama-index>=0.12.48",
+     "pandas>=2.3.1",
+     "pylint>=3.3.7",
+     "python-dotenv>=1.1.1",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "black>=25.1.0",
+     "mypy>=1.16.1",
+     "pytest>=8.4.1",
+ ]
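Note: with uv, the `dev` group under `[dependency-groups]` is installed by `uv sync` by default (pass `--no-dev` to skip it), and the `>=` constraints above are resolved to exact pinned versions recorded in uv.lock.
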
requirements.txt DELETED
@@ -1,15 +0,0 @@
- dateparser==1.2.2
- gradio==5.36.2
- langchain_community==0.3.27
- langchain_core==0.3.68
- langchain_openai==0.3.27
- langgraph==0.5.2
- llama_index==0.12.48
- pandas==2.3.1
- pydantic==2.11.7
- python-dotenv==1.1.1
- python_dateutil==2.9.0.post0
- typing_extensions==4.14.1
- pylint==3.3.7
- black==25.1.0
- faiss-cpu==1.11.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff