JDFPalladium committed on
Commit
e670011
·
1 Parent(s): 389c5f0

reducing chunks retrieved and adding retrieved chunks to display

Browse files
app.py CHANGED
@@ -31,6 +31,7 @@ def rag_retrieve_tool(query):
31
  result = rag_retrieve(query, llm=llm)
32
  return {
33
  "rag_result": result.get("rag_result", ""),
 
34
  "last_tool": "rag_retrieve",
35
  }
36
 
@@ -164,7 +165,7 @@ def chat_with_patient(question: str, patient_id: str, sitecode: str, thread_id:
164
 
165
  assistant_message = output_state["messages"][-1].content
166
 
167
- return assistant_message, thread_id
168
 
169
  def init_session():
170
  return str(uuid.uuid4())
@@ -195,16 +196,19 @@ with gr.Blocks() as app:
195
  label="Sitecode",
196
  )
197
 
 
198
  question_input = gr.Textbox(label="Question")
199
  thread_id_state = gr.State(init_session())
200
  output_chat = gr.Textbox(label="Assistant Response")
201
 
 
 
202
  submit_btn = gr.Button("Ask")
203
 
204
  submit_btn.click( # pylint: disable=no-member
205
  chat_with_patient,
206
  inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
207
- outputs=[output_chat, thread_id_state],
208
  )
209
 
210
  app.launch(
 
31
  result = rag_retrieve(query, llm=llm)
32
  return {
33
  "rag_result": result.get("rag_result", ""),
34
+ "rag_sources": result.get("rag_sources", []),
35
  "last_tool": "rag_retrieve",
36
  }
37
 
 
165
 
166
  assistant_message = output_state["messages"][-1].content
167
 
168
+ return assistant_message, thread_id, output_state.get("rag_sources", "")
169
 
170
  def init_session():
171
  return str(uuid.uuid4())
 
196
  label="Sitecode",
197
  )
198
 
199
+ gr.Markdown("### Ask a Clinical Question")
200
  question_input = gr.Textbox(label="Question")
201
  thread_id_state = gr.State(init_session())
202
  output_chat = gr.Textbox(label="Assistant Response")
203
 
204
+ retrieved_sources_display = gr.HTML(label="Retrieved Sources (if applicable)")
205
+
206
  submit_btn = gr.Button("Ask")
207
 
208
  submit_btn.click( # pylint: disable=no-member
209
  chat_with_patient,
210
  inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
211
+ outputs=[output_chat, thread_id_state, retrieved_sources_display],
212
  )
213
 
214
  app.launch(
chatlib/assistant_node.py CHANGED
@@ -39,6 +39,7 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
39
  state.setdefault("pk_hash", "")
40
  state.setdefault("sitecode", "")
41
  state.setdefault("rag_result", "")
 
42
  state.setdefault("answer", "")
43
  state.setdefault("last_answer", None)
44
  state.setdefault("last_user_message", None)
@@ -176,9 +177,14 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
176
  elif state.get("rag_result"):
177
  # Use conversation history + a system message to inject RAG guidance
178
  rag_msg = SystemMessage(
179
- content="The following clinical guidelines may help answer the user's question:\n\n"
180
- f"{state['rag_result']}\n\n"
181
- "Use this information when responding."
 
 
 
 
 
182
  )
183
  messages_with_rag = messages + [rag_msg]
184
  llm_response = llm.invoke(messages_with_rag)
 
39
  state.setdefault("pk_hash", "")
40
  state.setdefault("sitecode", "")
41
  state.setdefault("rag_result", "")
42
+ state.setdefault("rag_sources", "")
43
  state.setdefault("answer", "")
44
  state.setdefault("last_answer", None)
45
  state.setdefault("last_user_message", None)
 
177
  elif state.get("rag_result"):
178
  # Use conversation history + a system message to inject RAG guidance
179
  rag_msg = SystemMessage(
180
+ content = (
181
+ "Based on the following clinical guideline excerpts, answer the clinician's question as precisely as possible.\n\n"
182
+ "Focus only on information that directly addresses the question.\n"
183
+ "Do not include background or general recommendations unless they are explicitly relevant.\n\n"
184
+ "Guideline excerpts:\n"
185
+ f"{state['rag_result']}\n\n"
186
+ "Respond with a focused summary tailored to the question about advanced HIV disease."
187
+ )
188
  )
189
  messages_with_rag = messages + [rag_msg]
190
  llm_response = llm.invoke(messages_with_rag)
chatlib/guidlines_rag_agent_li.py CHANGED
@@ -18,7 +18,7 @@ embedding_model = OpenAIEmbedding()
18
  llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
19
 
20
  # Create LLM reranker
21
- reranker = LLMRerank(llm=llm_llama, top_n=5)
22
 
23
  # Define a prompt template for query expansion
24
  query_expansion_prompt = ChatPromptTemplate.from_messages([
@@ -49,6 +49,20 @@ def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.nda
49
  # Dot product gives cosine similarity
50
  return matrix_norm @ query_norm
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def rag_retrieve(query: str, llm) -> AppState:
53
  """Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya."""
54
 
@@ -59,7 +73,7 @@ def rag_retrieve(query: str, llm) -> AppState:
59
  # Embed the expanded query and find similar summaries
60
  query_embedding = embedding_model.get_text_embedding(expanded_query)
61
  similarities = cosine_similarity_numpy(query_embedding, embeddings)
62
- top_indices = similarities.argsort()[-5:][::-1]
63
  selected_paths = df.loc[top_indices, "vectorestore_path"].tolist()
64
  print(f"Selected paths for retrieval: {selected_paths}")
65
 
@@ -81,6 +95,7 @@ def rag_retrieve(query: str, llm) -> AppState:
81
  "rag_result": "No relevant information found in the sources. Please try rephrasing your question.",
82
  "last_tool": "rag_retrieve"
83
  }
 
84
  retrieved_text = "\n\n".join([
85
  f"Source {i+1}: {source.text}" for i, source in enumerate(sources)
86
  ])
@@ -97,4 +112,7 @@ def rag_retrieve(query: str, llm) -> AppState:
97
  print("Prompt length in characters:", len(summarization_prompt))
98
  summary_response = llm.invoke(summarization_prompt)
99
 
100
- return {"rag_result": summary_response.content, "last_tool": "rag_retrieve"} # type: ignore
 
 
 
 
18
  llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
19
 
20
  # Create LLM reranker
21
+ reranker = LLMRerank(llm=llm_llama, top_n=3)
22
 
23
  # Define a prompt template for query expansion
24
  query_expansion_prompt = ChatPromptTemplate.from_messages([
 
49
  # Dot product gives cosine similarity
50
  return matrix_norm @ query_norm
51
 
52
+ def format_sources_for_html(sources):
53
+ html_blocks = []
54
+ for i, source in enumerate(sources):
55
+ text = source.text.replace("\n", "<br>").strip()
56
+ block = f"""
57
+ <details style='margin-bottom: 1em;'>
58
+ <summary><strong>Source {i+1}</strong></summary>
59
+ <div style='margin-top: 0.5em; font-family: monospace;'>{text}</div>
60
+ </details>
61
+ """
62
+ html_blocks.append(block)
63
+ return "\n".join(html_blocks)
64
+
65
+
66
  def rag_retrieve(query: str, llm) -> AppState:
67
  """Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya."""
68
 
 
73
  # Embed the expanded query and find similar summaries
74
  query_embedding = embedding_model.get_text_embedding(expanded_query)
75
  similarities = cosine_similarity_numpy(query_embedding, embeddings)
76
+ top_indices = similarities.argsort()[-3:][::-1]
77
  selected_paths = df.loc[top_indices, "vectorestore_path"].tolist()
78
  print(f"Selected paths for retrieval: {selected_paths}")
79
 
 
95
  "rag_result": "No relevant information found in the sources. Please try rephrasing your question.",
96
  "last_tool": "rag_retrieve"
97
  }
98
+ # Concatenate the retrieved source excerpts into the text used for the summarization prompt
99
  retrieved_text = "\n\n".join([
100
  f"Source {i+1}: {source.text}" for i, source in enumerate(sources)
101
  ])
 
112
  print("Prompt length in characters:", len(summarization_prompt))
113
  summary_response = llm.invoke(summarization_prompt)
114
 
115
+ return {"rag_result": summary_response.content,
116
+ "rag_sources": format_sources_for_html(sources),
117
+ "last_tool": "rag_retrieve"
118
+ } # type: ignore
chatlib/state_types.py CHANGED
@@ -11,6 +11,7 @@ class AppState(TypedDict):
11
  pk_hash: str
12
  sitecode: str
13
  rag_result: str
 
14
  answer: str
15
  last_answer: Optional[str] = None
16
  last_user_message: Optional[str] = None
 
11
  pk_hash: str
12
  sitecode: str
13
  rag_result: str
14
+ rag_sources: Optional[str] # Added to store retrieved sources
15
  answer: str
16
  last_answer: Optional[str] = None
17
  last_user_message: Optional[str] = None