JDFPalladium commited on
Commit ·
0669c52
1
Parent(s): dc55b14
improving reranking and expanding system prompt for HIV adjacent questions
Browse files- app.py +57 -9
- requirements.txt +1 -1
- utils/helpers.py +8 -1
app.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
#%%
|
| 4 |
# Import libraries
|
| 5 |
import os
|
|
|
|
|
|
|
| 6 |
from lingua import Language, LanguageDetectorBuilder
|
| 7 |
import gradio as gr
|
| 8 |
from openai import OpenAI as OpenAIOG
|
|
@@ -27,20 +29,41 @@ client = OpenAIOG()
|
|
| 27 |
# Load index for retrieval
|
| 28 |
storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
|
| 29 |
index = load_index_from_storage(storage_context)
|
| 30 |
-
retriever = index.as_retriever(similarity_top_k=10,
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# Use LLM reranking to filter results
|
| 34 |
-
reranker=LLMRerank(top_n=3))
|
| 35 |
|
| 36 |
#%%
|
| 37 |
# Define Gradio function
|
| 38 |
-
def nishauri(
|
| 39 |
|
| 40 |
"""Process user query, detect language, handle greetings, acknowledgments, and retrieve relevant information."""
|
| 41 |
# context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
|
| 42 |
# formatted_history = convert_conversation_format(conversation_history)
|
| 43 |
# summary = summarize_conversation(formatted_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# detect language of user
|
| 46 |
lang_question = helpers.detect_language(question, Language, LanguageDetectorBuilder, client)
|
|
@@ -78,8 +101,29 @@ def nishauri(question, conversation_history: list[str]):
|
|
| 78 |
question = GoogleTranslator(source='sw', target='en').translate(question)
|
| 79 |
|
| 80 |
# Retrieve relevant sources
|
| 81 |
-
sources = retriever.retrieve(question)
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# Combine into new user question - conversation history, new question, retrieved sources
|
| 85 |
question_final = (
|
|
@@ -96,7 +140,8 @@ def nishauri(question, conversation_history: list[str]):
|
|
| 96 |
"You are a helpful assistant who only answers questions about HIV.\n"
|
| 97 |
"- Only answers questions about HIV (Human Immunodeficiency Virus).\n"
|
| 98 |
"- Recognize that users may type 'HIV' with any capitalization (e.g., HIV, hiv, Hiv, etc.) or make minor typos (e.g., hvi, hiv/aids).\n"
|
| 99 |
-
"-
|
|
|
|
| 100 |
"- Do not answer questions about other topics (e.g., malaria or tuberculosis).\n"
|
| 101 |
"- If a question is unrelated to HIV, politely respond that you can only answer HIV-related questions.\n\n"
|
| 102 |
|
|
@@ -115,6 +160,9 @@ def nishauri(question, conversation_history: list[str]):
|
|
| 115 |
"- A suppressed viral load is one below 200 copies/ml.\n\n"
|
| 116 |
)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
| 118 |
# Start with context
|
| 119 |
messages = [{"role": "system", "content": system_prompt}]
|
| 120 |
|
|
|
|
| 3 |
#%%
|
| 4 |
# Import libraries
|
| 5 |
import os
|
| 6 |
+
import json
|
| 7 |
+
from datetime import datetime
|
| 8 |
from lingua import Language, LanguageDetectorBuilder
|
| 9 |
import gradio as gr
|
| 10 |
from openai import OpenAI as OpenAIOG
|
|
|
|
| 29 |
# Load index for retrieval
|
| 30 |
storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
|
| 31 |
index = load_index_from_storage(storage_context)
|
| 32 |
+
# retriever = index.as_retriever(similarity_top_k=10,
|
| 33 |
+
# # Similarity threshold for filtering
|
| 34 |
+
# similarity_threshold=0.5)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
#%%
|
| 37 |
# Define Gradio function
|
| 38 |
+
def nishauri(user_params: str, conversation_history: list[str]):
|
| 39 |
|
| 40 |
"""Process user query, detect language, handle greetings, acknowledgments, and retrieve relevant information."""
|
| 41 |
# context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
|
| 42 |
# formatted_history = convert_conversation_format(conversation_history)
|
| 43 |
# summary = summarize_conversation(formatted_history)
|
| 44 |
+
user_params = json.loads(user_params)
|
| 45 |
+
|
| 46 |
+
# Extract user information
|
| 47 |
+
consent = user_params.get("CONSENT")
|
| 48 |
+
person_info = user_params.get("PERSON_INFO", {})
|
| 49 |
+
gender = person_info.get("GENDER", "")
|
| 50 |
+
age = person_info.get("AGE", "")
|
| 51 |
+
vl_result = person_info.get("VIRAL_LOAD", "")
|
| 52 |
+
vl_date = helpers.convert_to_date(person_info.get("VIRAL_LOAD_DATETIME", ""), datetime)
|
| 53 |
+
next_appt_date = helpers.convert_to_date(person_info.get("APPOINTMENT_DATETIME", ""), datetime)
|
| 54 |
+
regimen = person_info.get("REGIMEN", "")
|
| 55 |
+
question = user_params.get("QUESTION", "")
|
| 56 |
+
|
| 57 |
+
info_pieces = [
|
| 58 |
+
"Here is information about the person asking the question."
|
| 59 |
+
f"The person is {gender}." if gender else "",
|
| 60 |
+
f"The person is age {age}." if age else "",
|
| 61 |
+
f"The person's next clinical check-in is scheduled for {next_appt_date}." if next_appt_date else "",
|
| 62 |
+
f"The person is on the following regimen for HIV: {regimen}." if regimen else "",
|
| 63 |
+
f"The person's most recent viral load result was {vl_result}." if vl_result else "",
|
| 64 |
+
f"The person's most recent viral load was taken on {vl_date}." if vl_date else "",
|
| 65 |
+
]
|
| 66 |
+
full_text = " ".join(filter(None, info_pieces))
|
| 67 |
|
| 68 |
# detect language of user
|
| 69 |
lang_question = helpers.detect_language(question, Language, LanguageDetectorBuilder, client)
|
|
|
|
| 101 |
question = GoogleTranslator(source='sw', target='en').translate(question)
|
| 102 |
|
| 103 |
# Retrieve relevant sources
|
| 104 |
+
# sources = retriever.retrieve(question)
|
| 105 |
+
# Summarize the conversation history
|
| 106 |
+
history_summary = " ".join(
|
| 107 |
+
[f"User: {turn['user']} Assistant: {turn['chatbot']}" for turn in conversation_history]
|
| 108 |
+
)
|
| 109 |
+
query_with_context = f"Current question: {question}\n\nSummary of prior context: {history_summary}"
|
| 110 |
+
|
| 111 |
+
# Initialize the LLMRerank postprocessor
|
| 112 |
+
reranker = LLMRerank(top_n=3)
|
| 113 |
+
|
| 114 |
+
# Attach the reranker to the retriever
|
| 115 |
+
retriever_with_rerank = index.as_retriever(
|
| 116 |
+
similarity_top_k=10,
|
| 117 |
+
similarity_threshold=0.6,
|
| 118 |
+
postprocessors=[reranker]
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Retrieve and re-rank sources
|
| 122 |
+
sources = retriever_with_rerank.retrieve(query_with_context)
|
| 123 |
+
|
| 124 |
+
# Combine the top-ranked sources
|
| 125 |
+
retrieved_text = "\n\n".join([f"Source {i+1}: {source.text}" for i, source in enumerate(sources)])
|
| 126 |
+
|
| 127 |
|
| 128 |
# Combine into new user question - conversation history, new question, retrieved sources
|
| 129 |
question_final = (
|
|
|
|
| 140 |
"You are a helpful assistant who only answers questions about HIV.\n"
|
| 141 |
"- Only answers questions about HIV (Human Immunodeficiency Virus).\n"
|
| 142 |
"- Recognize that users may type 'HIV' with any capitalization (e.g., HIV, hiv, Hiv, etc.) or make minor typos (e.g., hvi, hiv/aids).\n"
|
| 143 |
+
"- If a question is ambiguous or might be indirectly related to HIV (e.g., symptoms, illness, or general health concerns), assume it could be relevant to HIV and respond accordingly.\n"
|
| 144 |
+
"- If a question is about using the Nishauri app, such as finding viral load results, regimen details, or the next appointment, provide clear instructions on how to navigate the app to find this information.\n"
|
| 145 |
"- Do not answer questions about other topics (e.g., malaria or tuberculosis).\n"
|
| 146 |
"- If a question is unrelated to HIV, politely respond that you can only answer HIV-related questions.\n\n"
|
| 147 |
|
|
|
|
| 160 |
"- A suppressed viral load is one below 200 copies/ml.\n\n"
|
| 161 |
)
|
| 162 |
|
| 163 |
+
if consent == "YES":
|
| 164 |
+
system_prompt = f"{system_prompt} {full_text}."
|
| 165 |
+
|
| 166 |
# Start with context
|
| 167 |
messages = [{"role": "system", "content": system_prompt}]
|
| 168 |
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
gradio==4.44.1
|
| 2 |
-
llama_index==0.
|
| 3 |
langdetect==1.0.9
|
| 4 |
deep-translator==1.11.4
|
| 5 |
lingua-language-detector==2.0.2
|
|
|
|
| 1 |
gradio==4.44.1
|
| 2 |
+
llama_index==0.12.34
|
| 3 |
langdetect==1.0.9
|
| 4 |
deep-translator==1.11.4
|
| 5 |
lingua-language-detector==2.0.2
|
utils/helpers.py
CHANGED
|
@@ -100,4 +100,11 @@ def detect_intention(user_input, client):
|
|
| 100 |
temperature=0 # for deterministic output
|
| 101 |
)
|
| 102 |
|
| 103 |
-
return completion.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
temperature=0 # for deterministic output
|
| 101 |
)
|
| 102 |
|
| 103 |
+
return completion.choices[0].message.content
|
| 104 |
+
|
| 105 |
+
def convert_to_date(date_str, datetime):
|
| 106 |
+
"""Convert date string in YYYYMMDD format to YYYY-MM-DD."""
|
| 107 |
+
try:
|
| 108 |
+
return datetime.strptime(date_str, "%Y%m%d").strftime("%Y-%m-%d")
|
| 109 |
+
except ValueError:
|
| 110 |
+
return "Unknown Date"
|