Spaces:
Sleeping
Sleeping
JDFPalladium commited on
Commit ·
e670011
1
Parent(s): 389c5f0
reducing chunks retrieved and adding retrieved chunks to display
Browse files- app.py +6 -2
- chatlib/assistant_node.py +9 -3
- chatlib/guidlines_rag_agent_li.py +21 -3
- chatlib/state_types.py +1 -0
app.py
CHANGED
|
@@ -31,6 +31,7 @@ def rag_retrieve_tool(query):
|
|
| 31 |
result = rag_retrieve(query, llm=llm)
|
| 32 |
return {
|
| 33 |
"rag_result": result.get("rag_result", ""),
|
|
|
|
| 34 |
"last_tool": "rag_retrieve",
|
| 35 |
}
|
| 36 |
|
|
@@ -164,7 +165,7 @@ def chat_with_patient(question: str, patient_id: str, sitecode: str, thread_id:
|
|
| 164 |
|
| 165 |
assistant_message = output_state["messages"][-1].content
|
| 166 |
|
| 167 |
-
return assistant_message, thread_id
|
| 168 |
|
| 169 |
def init_session():
|
| 170 |
return str(uuid.uuid4())
|
|
@@ -195,16 +196,19 @@ with gr.Blocks() as app:
|
|
| 195 |
label="Sitecode",
|
| 196 |
)
|
| 197 |
|
|
|
|
| 198 |
question_input = gr.Textbox(label="Question")
|
| 199 |
thread_id_state = gr.State(init_session())
|
| 200 |
output_chat = gr.Textbox(label="Assistant Response")
|
| 201 |
|
|
|
|
|
|
|
| 202 |
submit_btn = gr.Button("Ask")
|
| 203 |
|
| 204 |
submit_btn.click( # pylint: disable=no-member
|
| 205 |
chat_with_patient,
|
| 206 |
inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
|
| 207 |
-
outputs=[output_chat, thread_id_state],
|
| 208 |
)
|
| 209 |
|
| 210 |
app.launch(
|
|
|
|
| 31 |
result = rag_retrieve(query, llm=llm)
|
| 32 |
return {
|
| 33 |
"rag_result": result.get("rag_result", ""),
|
| 34 |
+
"rag_sources": result.get("rag_sources", []),
|
| 35 |
"last_tool": "rag_retrieve",
|
| 36 |
}
|
| 37 |
|
|
|
|
| 165 |
|
| 166 |
assistant_message = output_state["messages"][-1].content
|
| 167 |
|
| 168 |
+
return assistant_message, thread_id, output_state.get("rag_sources", "")
|
| 169 |
|
| 170 |
def init_session():
|
| 171 |
return str(uuid.uuid4())
|
|
|
|
| 196 |
label="Sitecode",
|
| 197 |
)
|
| 198 |
|
| 199 |
+
gr.Markdown("### Ask a Clinical Question")
|
| 200 |
question_input = gr.Textbox(label="Question")
|
| 201 |
thread_id_state = gr.State(init_session())
|
| 202 |
output_chat = gr.Textbox(label="Assistant Response")
|
| 203 |
|
| 204 |
+
retrieved_sources_display = gr.HTML(label="Retrieved Sources (if applicable)")
|
| 205 |
+
|
| 206 |
submit_btn = gr.Button("Ask")
|
| 207 |
|
| 208 |
submit_btn.click( # pylint: disable=no-member
|
| 209 |
chat_with_patient,
|
| 210 |
inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
|
| 211 |
+
outputs=[output_chat, thread_id_state, retrieved_sources_display],
|
| 212 |
)
|
| 213 |
|
| 214 |
app.launch(
|
chatlib/assistant_node.py
CHANGED
|
@@ -39,6 +39,7 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
|
|
| 39 |
state.setdefault("pk_hash", "")
|
| 40 |
state.setdefault("sitecode", "")
|
| 41 |
state.setdefault("rag_result", "")
|
|
|
|
| 42 |
state.setdefault("answer", "")
|
| 43 |
state.setdefault("last_answer", None)
|
| 44 |
state.setdefault("last_user_message", None)
|
|
@@ -176,9 +177,14 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
|
|
| 176 |
elif state.get("rag_result"):
|
| 177 |
# Use conversation history + a system message to inject RAG guidance
|
| 178 |
rag_msg = SystemMessage(
|
| 179 |
-
content=
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
)
|
| 183 |
messages_with_rag = messages + [rag_msg]
|
| 184 |
llm_response = llm.invoke(messages_with_rag)
|
|
|
|
| 39 |
state.setdefault("pk_hash", "")
|
| 40 |
state.setdefault("sitecode", "")
|
| 41 |
state.setdefault("rag_result", "")
|
| 42 |
+
state.setdefault("rag_sources", "")
|
| 43 |
state.setdefault("answer", "")
|
| 44 |
state.setdefault("last_answer", None)
|
| 45 |
state.setdefault("last_user_message", None)
|
|
|
|
| 177 |
elif state.get("rag_result"):
|
| 178 |
# Use conversation history + a system message to inject RAG guidance
|
| 179 |
rag_msg = SystemMessage(
|
| 180 |
+
content = (
|
| 181 |
+
"Based on the following clinical guideline excerpts, answer the clinician's question as precisely as possible.\n\n"
|
| 182 |
+
"Focus only on information that directly addresses the question.\n"
|
| 183 |
+
"Do not include background or general recommendations unless they are explicitly relevant.\n\n"
|
| 184 |
+
"Guideline excerpts:\n"
|
| 185 |
+
f"{state['rag_result']}\n\n"
|
| 186 |
+
"Respond with a focused summary tailored to the question about advanced HIV disease."
|
| 187 |
+
)
|
| 188 |
)
|
| 189 |
messages_with_rag = messages + [rag_msg]
|
| 190 |
llm_response = llm.invoke(messages_with_rag)
|
chatlib/guidlines_rag_agent_li.py
CHANGED
|
@@ -18,7 +18,7 @@ embedding_model = OpenAIEmbedding()
|
|
| 18 |
llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
|
| 19 |
|
| 20 |
# Create LLM reranker
|
| 21 |
-
reranker = LLMRerank(llm=llm_llama, top_n=
|
| 22 |
|
| 23 |
# Define a prompt template for query expansion
|
| 24 |
query_expansion_prompt = ChatPromptTemplate.from_messages([
|
|
@@ -49,6 +49,20 @@ def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.nda
|
|
| 49 |
# Dot product gives cosine similarity
|
| 50 |
return matrix_norm @ query_norm
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def rag_retrieve(query: str, llm) -> AppState:
|
| 53 |
"""Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya."""
|
| 54 |
|
|
@@ -59,7 +73,7 @@ def rag_retrieve(query: str, llm) -> AppState:
|
|
| 59 |
# Embed the expanded query and find similar summaries
|
| 60 |
query_embedding = embedding_model.get_text_embedding(expanded_query)
|
| 61 |
similarities = cosine_similarity_numpy(query_embedding, embeddings)
|
| 62 |
-
top_indices = similarities.argsort()[-
|
| 63 |
selected_paths = df.loc[top_indices, "vectorestore_path"].tolist()
|
| 64 |
print(f"Selected paths for retrieval: {selected_paths}")
|
| 65 |
|
|
@@ -81,6 +95,7 @@ def rag_retrieve(query: str, llm) -> AppState:
|
|
| 81 |
"rag_result": "No relevant information found in the sources. Please try rephrasing your question.",
|
| 82 |
"last_tool": "rag_retrieve"
|
| 83 |
}
|
|
|
|
| 84 |
retrieved_text = "\n\n".join([
|
| 85 |
f"Source {i+1}: {source.text}" for i, source in enumerate(sources)
|
| 86 |
])
|
|
@@ -97,4 +112,7 @@ def rag_retrieve(query: str, llm) -> AppState:
|
|
| 97 |
print("Prompt length in characters:", len(summarization_prompt))
|
| 98 |
summary_response = llm.invoke(summarization_prompt)
|
| 99 |
|
| 100 |
-
return {"rag_result": summary_response.content,
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
|
| 19 |
|
| 20 |
# Create LLM reranker
|
| 21 |
+
reranker = LLMRerank(llm=llm_llama, top_n=3)
|
| 22 |
|
| 23 |
# Define a prompt template for query expansion
|
| 24 |
query_expansion_prompt = ChatPromptTemplate.from_messages([
|
|
|
|
| 49 |
# Dot product gives cosine similarity
|
| 50 |
return matrix_norm @ query_norm
|
| 51 |
|
| 52 |
+
def format_sources_for_html(sources):
|
| 53 |
+
html_blocks = []
|
| 54 |
+
for i, source in enumerate(sources):
|
| 55 |
+
text = source.text.replace("\n", "<br>").strip()
|
| 56 |
+
block = f"""
|
| 57 |
+
<details style='margin-bottom: 1em;'>
|
| 58 |
+
<summary><strong>Source {i+1}</strong></summary>
|
| 59 |
+
<div style='margin-top: 0.5em; font-family: monospace;'>{text}</div>
|
| 60 |
+
</details>
|
| 61 |
+
"""
|
| 62 |
+
html_blocks.append(block)
|
| 63 |
+
return "\n".join(html_blocks)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
def rag_retrieve(query: str, llm) -> AppState:
|
| 67 |
"""Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya."""
|
| 68 |
|
|
|
|
| 73 |
# Embed the expanded query and find similar summaries
|
| 74 |
query_embedding = embedding_model.get_text_embedding(expanded_query)
|
| 75 |
similarities = cosine_similarity_numpy(query_embedding, embeddings)
|
| 76 |
+
top_indices = similarities.argsort()[-3:][::-1]
|
| 77 |
selected_paths = df.loc[top_indices, "vectorestore_path"].tolist()
|
| 78 |
print(f"Selected paths for retrieval: {selected_paths}")
|
| 79 |
|
|
|
|
| 95 |
"rag_result": "No relevant information found in the sources. Please try rephrasing your question.",
|
| 96 |
"last_tool": "rag_retrieve"
|
| 97 |
}
|
| 98 |
+
# Format the retrieved sources for the response (and remove lengthy white space or repeated dashes)
|
| 99 |
retrieved_text = "\n\n".join([
|
| 100 |
f"Source {i+1}: {source.text}" for i, source in enumerate(sources)
|
| 101 |
])
|
|
|
|
| 112 |
print("Prompt length in characters:", len(summarization_prompt))
|
| 113 |
summary_response = llm.invoke(summarization_prompt)
|
| 114 |
|
| 115 |
+
return {"rag_result": summary_response.content,
|
| 116 |
+
"rag_sources": format_sources_for_html(sources),
|
| 117 |
+
"last_tool": "rag_retrieve"
|
| 118 |
+
} # type: ignore
|
chatlib/state_types.py
CHANGED
|
@@ -11,6 +11,7 @@ class AppState(TypedDict):
|
|
| 11 |
pk_hash: str
|
| 12 |
sitecode: str
|
| 13 |
rag_result: str
|
|
|
|
| 14 |
answer: str
|
| 15 |
last_answer: Optional[str] = None
|
| 16 |
last_user_message: Optional[str] = None
|
|
|
|
| 11 |
pk_hash: str
|
| 12 |
sitecode: str
|
| 13 |
rag_result: str
|
| 14 |
+
rag_sources: Optional[str] # Added to store retrieved sources
|
| 15 |
answer: str
|
| 16 |
last_answer: Optional[str] = None
|
| 17 |
last_user_message: Optional[str] = None
|