ClinicianAssistant / chatlib /idsr_check.py
JDFPalladium
cleaning up organization of scripts and data and updating filepaths in app to processed data
389c5f0
from .state_types import AppState
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from typing import List
from pydantic import BaseModel, Field
from langchain_core.output_parsers import PydanticOutputParser
import json
import math
from collections import Counter
import sqlite3
# import os
# Controlled keyword vocabulary: one keyword per line, blank lines are skipped.
with open("./data/processed/idsr_keywords.txt", "r", encoding="utf-8") as f:
    keywords = [entry for entry in (raw.strip() for raw in f) if entry]

# Pre-built FAISS index of disease documents, queried for semantic similarity.
# NOTE(review): allow_dangerous_deserialization=True opts in to loading pickled
# data; this is only safe if the index directory is produced by a trusted
# pipeline -- confirm the provenance of ./data/processed/disease_vectorstore.
vectorstore = FAISS.load_local(
    "./data/processed/disease_vectorstore",
    OpenAIEmbeddings(),
    allow_dangerous_deserialization=True,
)

# Documents serialized as a JSON list of dicts; each dict is expanded into the
# Document constructor's keyword arguments.
with open("./data/processed/tagged_documents.json", "r", encoding="utf-8") as f:
    doc_dicts = json.load(f)
tagged_documents = [Document(**fields) for fields in doc_dicts]

# Document frequency per keyword: a keyword counts once per document because
# each document's keyword list is collapsed to a set first.
total_docs = len(tagged_documents)
keyword_doc_counts = Counter()
for doc in tagged_documents:
    seen = set(doc.metadata.get("matched_keywords", []))
    keyword_doc_counts.update(seen)

# IDF-style weight: keywords tagged on fewer documents weigh more. Because of
# the +1 in the denominator, a keyword present in every document gets a
# slightly negative weight.
keyword_weights = {}
for kw, count in keyword_doc_counts.items():
    keyword_weights[kw] = math.log(total_docs / (1 + count))
def score_doc(doc_to_score, matched_keywords):
    """Score a document by the keywords it shares with the query.

    Args:
        doc_to_score: Object exposing a ``metadata`` mapping; its
            ``"matched_keywords"`` entry (default: empty) lists the keywords
            the document is tagged with.
        matched_keywords: Iterable of keywords extracted from the query.

    Returns:
        The sum of the module-level ``keyword_weights`` entries for every
        keyword present in both collections. Keywords without a recorded
        weight contribute 0, and the result is 0 when nothing overlaps.
    """
    doc_keywords = set(doc_to_score.metadata.get("matched_keywords", []))
    shared = doc_keywords & set(matched_keywords)
    weights = [keyword_weights.get(kw, 0) for kw in shared]
    return sum(weights)
# Structured-output schema for the keyword-selection LLM call; it is handed to
# PydanticOutputParser in extract_keywords_with_gpt.
# NOTE(review): documented with comments rather than a docstring on purpose --
# a pydantic model docstring is presumably emitted as the schema "description"
# and would then alter the format instructions sent to the LLM; confirm before
# adding one.
class KeywordsOutput(BaseModel):
    # Keywords the model selected for the query.
    keywords: List[str] = Field(
        description="List of relevant keywords extracted from the query"
    )
def extract_keywords_with_gpt(query: str, llm, known_keywords: List[str]) -> List[str]:
    """Ask the LLM to pick the keywords from a fixed vocabulary that fit a query.

    The prompt asks for the 3-5 most relevant entries of ``known_keywords``
    and the reply is parsed into a ``KeywordsOutput``.

    Args:
        query: Free-text query to analyse.
        llm: Language model that can be piped between a prompt and an output
            parser.
        known_keywords: Vocabulary the model is asked to choose from.

    Returns:
        The ``keywords`` list of the parsed reply. It is returned as parsed;
        nothing here checks that the entries belong to ``known_keywords``.
    """
    output_parser = PydanticOutputParser(pydantic_object=KeywordsOutput)
    template = ChatPromptTemplate.from_template(
        """
You are helping identify relevant medical concepts.
Given this query: "{query}"
Select the most relevant 3-5 keywords from this list:
{keyword_list}
Return the matching keywords as a JSON object with a single key "keywords" whose value is a list of strings.
{format_instructions}
"""
    )
    prompt_inputs = {
        "query": query,
        "keyword_list": ", ".join(known_keywords),
        "format_instructions": output_parser.get_format_instructions(),
    }
    pipeline = template | llm | output_parser
    parsed = pipeline.invoke(prompt_inputs)
    return parsed.keywords
def hybrid_search_with_query_keywords(
    query, vstore, documents, keyword_list, llm, top_k=3
):
    """Combine vector-similarity hits with keyword-ranked documents.

    Args:
        query: Free-text query.
        vstore: Vector store exposing ``similarity_search(query, k=...)``.
        documents: Documents whose ``metadata["matched_keywords"]`` is compared
            with the keywords extracted from the query.
        keyword_list: Vocabulary passed to ``extract_keywords_with_gpt``.
        llm: Language model used for keyword extraction.
        top_k: Number of semantic hits to request. It does not affect the
            keyword branch, which always keeps at most 5 documents.

    Returns:
        Semantic hits followed by keyword hits, de-duplicated on
        ``page_content``. When two documents share the same content the later
        one is kept, in the position of the first occurrence.
    """
    semantic_hits = vstore.similarity_search(query, k=top_k)
    matched_keywords = extract_keywords_with_gpt(query, llm, keyword_list)

    # Keep only documents that share a keyword with the query and end up with
    # a strictly positive score.
    scored = []
    for candidate in documents:
        tags = candidate.metadata.get("matched_keywords", [])
        if not any(tag in matched_keywords for tag in tags):
            continue
        relevance = score_doc(candidate, matched_keywords)
        if relevance > 0:
            scored.append((candidate, relevance))

    # Highest score first; the sort is stable, so ties keep their input order.
    scored.sort(key=lambda pair: -pair[1])
    keyword_docs = [candidate for candidate, _ in scored[:5]]

    by_content = {}
    for hit in semantic_hits + keyword_docs:
        by_content[hit.page_content] = hit
    return list(by_content.values())
def _fetch_location_context(sitecode, disease_names):
    """Read county, prevalence, rainy-season and epidemic data for a facility.

    Args:
        sitecode: Facility code looked up in ``sitecode_county_xwalk``.
        disease_names: Disease names to filter on; ``None`` entries and
            duplicates are ignored.

    Returns:
        Tuple ``(county_found, county_name, county_info, rainy_season,
        epidemic_info)``. ``county_info`` and ``epidemic_info`` are lists of
        rows; ``rainy_season`` is ``"Unknown"`` when no county or no row for
        the current month is found.
    """
    from datetime import datetime

    # De-duplicate while keeping order; the IN (...) filter is unaffected.
    names = list(dict.fromkeys(n for n in disease_names if n is not None))
    placeholders = ",".join("?" * len(names))

    county_name = None
    county_info = []
    epidemic_info = []
    rainy_season = "Unknown"

    conn = sqlite3.connect("data/processed/location_data.sqlite")
    try:
        cursor = conn.cursor()

        # Map the facility code to its county.
        cursor.execute(
            "SELECT County FROM sitecode_county_xwalk WHERE Code = ?", (sitecode,)
        )
        county = cursor.fetchone()

        # Epidemic alerts for the candidate diseases (no names -> no rows).
        if names:
            cursor.execute(
                f"SELECT Disease, EpidemicInfo FROM who_bulletin WHERE Disease IN ({placeholders})",
                names,
            )
            epidemic_info = cursor.fetchall()

        if county:
            county_name = county[0]

            # County-specific prevalence and notes for the candidate diseases.
            if names:
                cursor.execute(
                    f"SELECT County, Disease, Prevalence, Notes FROM county_disease_info WHERE County = ? AND Disease IN ({placeholders})",
                    (county_name, *names),
                )
                county_info = cursor.fetchall()

            # Rainy-season status for the current month (full month name,
            # e.g. "March", taken from the local clock).
            current_month = datetime.now().strftime("%B")
            cursor.execute(
                "SELECT RainySeason FROM county_rainy_seasons WHERE County = ? and Month = ?",
                (county_name, current_month),
            )
            season_row = cursor.fetchone()
            rainy_season = season_row[0] if season_row else "Unknown"
    finally:
        # Always release the connection, even when a query fails.
        conn.close()

    return bool(county), county_name, county_info, rainy_season, epidemic_info


def _is_mentioned(doc, answer_text):
    """Return True when the document's disease name occurs in ``answer_text``.

    Documents with a missing, empty or non-string ``disease_name`` never match
    (a missing name would otherwise raise ``TypeError`` in the substring test).
    """
    name = doc.metadata.get("disease_name")
    return isinstance(name, str) and bool(name) and name in answer_text


def idsr_check(query: str, llm, sitecode) -> AppState:
    """
    Match a clinical case against disease definitions with local context.

    Retrieves candidate disease definitions with a hybrid (semantic + keyword)
    search, enriches them with county prevalence, rainy-season and epidemic
    information from the location database, and asks the LLM which diseases
    are plausible.

    Args:
        query (str): Free-text description of the clinical case.
        llm: Language model used for keyword extraction and for the final
            assessment (must expose ``invoke``).
        sitecode: Facility code used to resolve the patient's county.

    Returns:
        AppState: Mapping with ``answer`` (LLM reply text), ``last_tool``
        (always ``"idsr_check"``) and ``context`` (list of text sections:
        definitions of the diseases named in the answer, county disease
        information, and epidemic information, each only when available).
    """
    results = hybrid_search_with_query_keywords(
        query, vectorstore, tagged_documents, keywords, llm
    )

    disease_names = [doc.metadata.get("disease_name") for doc in results]
    (
        county_found,
        county_name,
        county_info,
        rainy_season,
        epidemic_info,
    ) = _fetch_location_context(sitecode, disease_names)

    disease_definitions = "\n\n".join(
        [
            f"### Disease: {doc.metadata.get('disease_name', 'Unknown Disease')}:\n{doc.page_content}"
            for doc in results
        ]
    )

    if county_found:
        county_info_text = "\n".join(
            [
                f"- {row[0]}, {row[1]}, Prevalence: {row[2]}, Notes: {row[3]}"
                for row in county_info
            ]
        )
    else:
        county_info_text = "No county information available."

    if epidemic_info:
        epidemic_info_text = "\n".join(
            [f"- {row[0]}: {row[1]}" for row in epidemic_info]
        )
    else:
        epidemic_info_text = "No epidemic information available."

    prompt = """
You are a medical assistant reviewing a brief clinical case in Kenya to help identify which diseases the patient may plausibly have.
You have access to several disease definitions. You also have access to information about the prevalence of each disease in the county
where the patient is located. The prevalence of some diseases varies by season, and some diseases are also more likely when there is a
declared epidemic. Information on the timing of the rainy season and any declared epidemics is also provided.
## Instructions:
1. Carefully compare the case description to each disease definition, taking into account the prevalence and seasonality information.
2. If a disease seems like a possible match based on the available information, list it and explain why.
3. Only include rare diseases, or diseases that don't fit seasonally, if the match is extremely strong. Prioritize common and plausible conditions.
4. Only list diseases if there are plausible matches based on the case and context. If no plausible matches are found, do not list any diseases.
5. If the information provided is insufficient or ambiguous, prioritize asking clarifying questions before making any recommendations.
6. Clarifying questions may include inquiries about specific symptoms, patient demographics, exposures, travel history, or other relevant clinical details.
7. Provide a brief recommendation on next steps only if confident matches are identified or after clarifications are obtained.
## Case:
{query}
## Diseases:
{disease_definitions}
## Locational context:
In {county_name}, the current rainy season status is {rainy_season}.
The above diseases have the following prevalence in the county where the patient is located:
{county_info}
Here are any relevant epidemic alerts for these diseases:
{epidemic_info}
## Expected Output
If applicable, list possible disease matches with explanations.
If needed, list clarifying questions to better understand the case.
Provide a brief recommendation on next steps if appropriate.
If no matches or recommendations are possible, focus on clarifying questions.
""".format(
        query=query,
        disease_definitions=disease_definitions,
        county_name=county_name if county_found else "Unknown County",
        rainy_season=rainy_season if county_found else "Unknown",
        county_info=county_info_text,
        epidemic_info=epidemic_info_text,
    )

    # NOTE(review): debug output kept from the original; the prompt contains
    # the clinical case text, so consider routing this through logging.
    print(prompt)

    llm_response = llm.invoke(prompt)
    answer_text = (
        llm_response.content.strip()
        if llm_response
        else "No relevant disease information found."
    )

    # Keep only the definitions of diseases whose name appears verbatim
    # (case-sensitive substring match) in the answer.
    filtered_results = [doc for doc in results if _is_mentioned(doc, answer_text)]

    # Assemble the context sections returned alongside the answer.
    context_parts = []
    if filtered_results:
        context_parts.append(
            "### Disease Definitions:\n"
            + "\n\n".join(
                [
                    f"### Disease: {doc.metadata.get('disease_name', 'Unknown Disease')}:\n{doc.page_content}"
                    for doc in filtered_results
                ]
            )
        )
    if county_found and county_info:
        context_parts.append(
            "### County Disease Information:\n"
            + "\n".join(
                [
                    f"- {row[0]}, {row[1]}, Prevalence: {row[2]}, Seasonality: {row[3]}"
                    for row in county_info
                ]
            )
        )
    if epidemic_info:
        context_parts.append(
            "### Epidemic Information:\n"
            + "\n".join([f"- {row[0]}: {row[1]}" for row in epidemic_info])
        )

    return {"answer": answer_text, "last_tool": "idsr_check", "context": context_parts}  # type: ignore