JDFPalladium committed on
Commit
0669c52
·
1 Parent(s): dc55b14

improving reranking and expanding system prompt for HIV-adjacent questions

Browse files
Files changed (3) hide show
  1. app.py +57 -9
  2. requirements.txt +1 -1
  3. utils/helpers.py +8 -1
app.py CHANGED
@@ -3,6 +3,8 @@
3
  #%%
4
  # Import libraries
5
  import os
 
 
6
  from lingua import Language, LanguageDetectorBuilder
7
  import gradio as gr
8
  from openai import OpenAI as OpenAIOG
@@ -27,20 +29,41 @@ client = OpenAIOG()
27
  # Load index for retrieval
28
  storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
29
  index = load_index_from_storage(storage_context)
30
- retriever = index.as_retriever(similarity_top_k=10,
31
- # Similarity threshold for filtering
32
- similarity_threshold=0.5,
33
- # Use LLM reranking to filter results
34
- reranker=LLMRerank(top_n=3))
35
 
36
  #%%
37
  # Define Gradio function
38
- def nishauri(question, conversation_history: list[str]):
39
 
40
  """Process user query, detect language, handle greetings, acknowledgments, and retrieve relevant information."""
41
  # context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
42
  # formatted_history = convert_conversation_format(conversation_history)
43
  # summary = summarize_conversation(formatted_history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # detect language of user
46
  lang_question = helpers.detect_language(question, Language, LanguageDetectorBuilder, client)
@@ -78,8 +101,29 @@ def nishauri(question, conversation_history: list[str]):
78
  question = GoogleTranslator(source='sw', target='en').translate(question)
79
 
80
  # Retrieve relevant sources
81
- sources = retriever.retrieve(question)
82
- retrieved_text = "\n\n".join([f"Source {i+1}: {source.text}" for i, source in enumerate(sources[:3])])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Combine into new user question - conversation history, new question, retrieved sources
85
  question_final = (
@@ -96,7 +140,8 @@ def nishauri(question, conversation_history: list[str]):
96
  "You are a helpful assistant who only answers questions about HIV.\n"
97
  "- Only answers questions about HIV (Human Immunodeficiency Virus).\n"
98
  "- Recognize that users may type 'HIV' with any capitalization (e.g., HIV, hiv, Hiv, etc.) or make minor typos (e.g., hvi, hiv/aids).\n"
99
- "- Use your best judgment to understand when a user intends to refer to HIV. Politely correct any significant misunderstandings, but otherwise proceed to answer normally.\n"
 
100
  "- Do not answer questions about other topics (e.g., malaria or tuberculosis).\n"
101
  "- If a question is unrelated to HIV, politely respond that you can only answer HIV-related questions.\n\n"
102
 
@@ -115,6 +160,9 @@ def nishauri(question, conversation_history: list[str]):
115
  "- A suppressed viral load is one below 200 copies/ml.\n\n"
116
  )
117
 
 
 
 
118
  # Start with context
119
  messages = [{"role": "system", "content": system_prompt}]
120
 
 
3
  #%%
4
  # Import libraries
5
  import os
6
+ import json
7
+ from datetime import datetime
8
  from lingua import Language, LanguageDetectorBuilder
9
  import gradio as gr
10
  from openai import OpenAI as OpenAIOG
 
29
  # Load index for retrieval
30
  storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
31
  index = load_index_from_storage(storage_context)
32
+ # retriever = index.as_retriever(similarity_top_k=10,
33
+ # # Similarity threshold for filtering
34
+ # similarity_threshold=0.5)
 
 
35
 
36
  #%%
37
  # Define Gradio function
38
+ def nishauri(user_params: str, conversation_history: list[str]):
39
 
40
  """Process user query, detect language, handle greetings, acknowledgments, and retrieve relevant information."""
41
  # context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
42
  # formatted_history = convert_conversation_format(conversation_history)
43
  # summary = summarize_conversation(formatted_history)
44
+ user_params = json.loads(user_params)
45
+
46
+ # Extract user information
47
+ consent = user_params.get("CONSENT")
48
+ person_info = user_params.get("PERSON_INFO", {})
49
+ gender = person_info.get("GENDER", "")
50
+ age = person_info.get("AGE", "")
51
+ vl_result = person_info.get("VIRAL_LOAD", "")
52
+ vl_date = helpers.convert_to_date(person_info.get("VIRAL_LOAD_DATETIME", ""), datetime)
53
+ next_appt_date = helpers.convert_to_date(person_info.get("APPOINTMENT_DATETIME", ""), datetime)
54
+ regimen = person_info.get("REGIMEN", "")
55
+ question = user_params.get("QUESTION", "")
56
+
57
# Build a short plain-text profile of the person asking the question.
# The intro sentence is kept separate from the detail sentences: in the
# previous version a missing comma after it caused implicit string
# concatenation with the gender f-string, so the intro vanished whenever
# `gender` was empty and was glued to the next sentence otherwise.
intro = "Here is information about the person asking the question."
details = [
    f"The person is {gender}." if gender else "",
    f"The person is age {age}." if age else "",
    f"The person's next clinical check-in is scheduled for {next_appt_date}." if next_appt_date else "",
    f"The person is on the following regimen for HIV: {regimen}." if regimen else "",
    f"The person's most recent viral load result was {vl_result}." if vl_result else "",
    f"The person's most recent viral load was taken on {vl_date}." if vl_date else "",
]
# Keep only the sentences that were actually filled in.
info_pieces = [piece for piece in details if piece]
# Only announce "here is information" when there is something to report.
if info_pieces:
    info_pieces.insert(0, intro)
full_text = " ".join(info_pieces)
67
 
68
  # detect language of user
69
  lang_question = helpers.detect_language(question, Language, LanguageDetectorBuilder, client)
 
101
  question = GoogleTranslator(source='sw', target='en').translate(question)
102
 
103
  # Retrieve relevant sources
104
+ # sources = retriever.retrieve(question)
105
+ # Concatenate the prior conversation turns into a single context string
106
+ history_summary = " ".join(
107
+ [f"User: {turn['user']} Assistant: {turn['chatbot']}" for turn in conversation_history]
108
+ )
109
+ query_with_context = f"Current question: {question}\n\nSummary of prior context: {history_summary}"
110
+
111
+ # Initialize the LLMRerank postprocessor
112
+ reranker = LLMRerank(top_n=3)
113
+
114
+ # Attach the reranker to the retriever
115
+ retriever_with_rerank = index.as_retriever(
116
+ similarity_top_k=10,
117
+ similarity_threshold=0.6,
118
+ postprocessors=[reranker]
119
+ )
120
+
121
+ # Retrieve and re-rank sources
122
+ sources = retriever_with_rerank.retrieve(query_with_context)
123
+
124
+ # Combine the top-ranked sources
125
+ retrieved_text = "\n\n".join([f"Source {i+1}: {source.text}" for i, source in enumerate(sources)])
126
+
127
 
128
  # Combine into new user question - conversation history, new question, retrieved sources
129
  question_final = (
 
140
  "You are a helpful assistant who only answers questions about HIV.\n"
141
  "- Only answers questions about HIV (Human Immunodeficiency Virus).\n"
142
  "- Recognize that users may type 'HIV' with any capitalization (e.g., HIV, hiv, Hiv, etc.) or make minor typos (e.g., hvi, hiv/aids).\n"
143
+ "- If a question is ambiguous or might be indirectly related to HIV (e.g., symptoms, illness, or general health concerns), assume it could be relevant to HIV and respond accordingly.\n"
144
+ "- If a question is about using the Nishauri app, such as finding viral load results, regimen details, or the next appointment, provide clear instructions on how to navigate the app to find this information.\n"
145
  "- Do not answer questions about other topics (e.g., malaria or tuberculosis).\n"
146
  "- If a question is unrelated to HIV, politely respond that you can only answer HIV-related questions.\n\n"
147
 
 
160
  "- A suppressed viral load is one below 200 copies/ml.\n\n"
161
  )
162
 
163
+ if consent == "YES":
164
+ system_prompt = f"{system_prompt} {full_text}."
165
+
166
  # Start with context
167
  messages = [{"role": "system", "content": system_prompt}]
168
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio==4.44.1
2
- llama_index==0.10.51
3
  langdetect==1.0.9
4
  deep-translator==1.11.4
5
  lingua-language-detector==2.0.2
 
1
  gradio==4.44.1
2
+ llama_index==0.12.34
3
  langdetect==1.0.9
4
  deep-translator==1.11.4
5
  lingua-language-detector==2.0.2
utils/helpers.py CHANGED
@@ -100,4 +100,11 @@ def detect_intention(user_input, client):
100
  temperature=0 # for deterministic output
101
  )
102
 
103
- return completion.choices[0].message.content
 
 
 
 
 
 
 
 
100
  temperature=0 # for deterministic output
101
  )
102
 
103
+ return completion.choices[0].message.content
104
+
105
def convert_to_date(date_str, datetime):
    """Convert a date string in YYYYMMDD format to YYYY-MM-DD.

    Args:
        date_str: The date to convert, expected as a ``YYYYMMDD`` string.
        datetime: Object whose ``strptime`` is used for parsing (the caller
            passes the ``datetime.datetime`` class).

    Returns:
        The date formatted as ``YYYY-MM-DD``, or the string
        ``"Unknown Date"`` when ``date_str`` cannot be parsed.
    """
    try:
        return datetime.strptime(date_str, "%Y%m%d").strftime("%Y-%m-%d")
    except (TypeError, ValueError):
        # ValueError: the text does not match YYYYMMDD (including "").
        # TypeError: the value is not a string at all, e.g. None from a
        # JSON null or an integer; treat it the same as an unparseable date
        # instead of letting the exception escape.
        return "Unknown Date"