YoniFriedman commited on
Commit
6a1a818
·
verified ·
1 Parent(s): e952d98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OPENAI_API_KEY"]
3
+
4
+ from llama_index import (
5
+ VectorStoreIndex,
6
+ SummaryIndex,
7
+ SimpleKeywordTableIndex,
8
+ SimpleDirectoryReader,
9
+ ServiceContext,
10
+ StorageContext,
11
+ load_index_from_storage
12
+ )
13
+ from llama_index.schema import IndexNode
14
+ from llama_index.tools import QueryEngineTool, ToolMetadata
15
+ from llama_index.llms import OpenAI
16
+
17
+ llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
18
+ service_context = ServiceContext.from_defaults(llm=llm)
19
+
20
+ arv_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv/"))
21
+ # arv_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv_summary/"))
22
+ arv_vector_query_engine = arv_index.as_query_engine(similarity_top_k = 3)
23
+ # arv_summary_query_engine = arv_summary_index.as_query_engine()
24
+
25
+ nishauri_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri/"))
26
+ # nishauri_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri_summary/"))
27
+ nishauri_vector_query_engine = nishauri_index.as_query_engine(similarity_top_k = 3)
28
+ # nishauri_summary_query_engine = nishauri_summary_index.as_query_engine()
29
+
30
+ from llama_index.agent import OpenAIAgent
31
+
32
+ agents = {}
33
+
34
+ # define tools
35
+ query_engine_tools = [
36
+ QueryEngineTool(
37
+ query_engine=arv_vector_query_engine,
38
+ metadata=ToolMetadata(
39
+ name="arv_vector_tool",
40
+ description=(
41
+ "Useful for retrieving specific context about HIV care and treatment."
42
+ ),
43
+ ),
44
+ ),
45
+ # QueryEngineTool(
46
+ # query_engine=arv_summary_query_engine,
47
+ # metadata=ToolMetadata(
48
+ # name="arv_summary_tool",
49
+ # description=(
50
+ # "Useful for summarization questions related to HIV care and treatment."
51
+ # ),
52
+ # ),
53
+ # ),
54
+ ]
55
+
56
+ # build agent
57
+ function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
58
+ agent = OpenAIAgent.from_tools(
59
+ query_engine_tools,
60
+ llm=function_llm,
61
+ verbose=True,
62
+ )
63
+
64
+ agents["arv"] = agent
65
+
66
+ # define tools
67
+ query_engine_tools = [
68
+ QueryEngineTool(
69
+ query_engine=nishauri_vector_query_engine,
70
+ metadata=ToolMetadata(
71
+ name="nishauri_vector_tool",
72
+ description=(
73
+ "Useful for retrieving specific context about the Nishauri mobile application through which users are asking questions"
74
+ ),
75
+ ),
76
+ ),
77
+ # QueryEngineTool(
78
+ # query_engine=nishauri_summary_query_engine,
79
+ # metadata=ToolMetadata(
80
+ # name="nishauri_summary_tool",
81
+ # description=(
82
+ # "Useful for summarization questions related to the Nishauri mobile application through which users are asking questions"
83
+ # ),
84
+ # ),
85
+ # ),
86
+ ]
87
+
88
+ # build agent
89
+ function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
90
+ agent = OpenAIAgent.from_tools(
91
+ query_engine_tools,
92
+ llm=function_llm,
93
+ verbose=True,
94
+ )
95
+
96
+ agents["nishauri"] = agent
97
+
98
+ # define top-level nodes
99
+ nodes = []
100
+
101
+ arv_summary = (
102
+ "This content contains care and treatment guidance for people living with HIV."
103
+ " Use this source to answer questions about ARV medications, side effects from medication,"
104
+ " understanding viral loads, and any question about HIV care and treatment."
105
+ " This is the default source to use for answering any question that isn't about how to find"
106
+ " information in the Nishauri app."
107
+ )
108
+ node = IndexNode(text=arv_summary, index_id="arv")
109
+ nodes.append(node)
110
+
111
+ nishauri_summary = (
112
+ "This content contains guidance on the Nishauri mobile application through which users are asking questions."
113
+ " Reference this document when users ask questions such as how to find their viral load"
114
+ " or lab histories, how to find their appointment histories,"
115
+ " and want to know how to change their upcoming appointments."
116
+ " Do not use this to answer any other questions."
117
+ )
118
+ node = IndexNode(text=nishauri_summary, index_id="nishauri")
119
+ nodes.append(node)
120
+
121
+ # define top-level retriever
122
+ vector_index = VectorStoreIndex(nodes)
123
+ vector_retriever = vector_index.as_retriever(similarity_top_k=1)
124
+
125
+ # define recursive retriever
126
+ from llama_index.retrievers import RecursiveRetriever
127
+ from llama_index.query_engine import RetrieverQueryEngine
128
+ from llama_index.response_synthesizers import get_response_synthesizer
129
+
130
+ # note: can pass `agents` dict as `query_engine_dict` since every agent can be used as a query engine
131
+ recursive_retriever = RecursiveRetriever(
132
+ "vector",
133
+ retriever_dict={"vector": vector_retriever},
134
+ query_engine_dict=agents,
135
+ verbose=True,
136
+ )
137
+
138
+ response_synthesizer = get_response_synthesizer(
139
+ # service_context=service_context,
140
+ response_mode="compact",
141
+ )
142
+ query_engine = RetrieverQueryEngine.from_args(
143
+ recursive_retriever,
144
+ response_synthesizer=response_synthesizer,
145
+ service_context=service_context,
146
+ )
147
+
148
+
149
+ preamble = (" The person asking the following prompt is a person living with HIV in Kenya."
150
+ " For every response, recognize that they already have HIV and do not suggest that they have to get tested"
151
+ " for HIV or take post-exposure prophylaxis, as that is not relevant, though their partners perhaps should."
152
+ " Do not suggest anything that is not relevant to someone who already has HIV."
153
+ " They are asking questions through a mobile application called Nishauri"
154
+ " through which they can see their lab results, appointment histories, and upcoming appointments."
155
+ " Here is some information that is authoritative and should guide responses, when relevant."
156
+ " For questions about viral load, be sure to provide specific information"
157
+ " about cutoffs for viral load categories. Under 50 copies/ml is low detectable level,"
158
+ " 50 - 199 copies/ml is low level viremia, 200 - 999 is high level viremia, and "
159
+ " 1000 and above is suspected treatment failure."
160
+ " A high viral load or non-suppressed viral load is any viral load above 200 copies/ml."
161
+ " A suppressed viral load is one below 200 copies / ml."
162
+ " An established client is one who is on their current ART regimen for a period greater"
163
+ " than 6 months, had no active OI or in the previous 6 months, has adhered to scheduled"
164
+ " clinic visits for the previous 6 months and Viral load results has been less than 200 copies/ml"
165
+ " within the last 6 months."
166
+ " For questions about when patients should get their viral loads taken,"
167
+ " if they are newly initiated on ART, the first viral load sample should be taken after 3 months of"
168
+ " taking ART. Otherwise, if they are not new on ART, then if their previous result was below 50 to 199 cp/ml,"
169
+ " their viral load should be taken after every 12 months. If their previous result was above 200cp/ml,"
170
+ " then viral load sample should be taken after three months."
171
+ " Please answer the prompt using the information retrieved"
172
+ " and do not rely at all on your prior knowledge."
173
+ " Please keep your reply to no longer than three sentences, and please use simple language. ")
174
+
175
+ prompt_intro = (" Here is the prompt: ")
176
+
177
+ import gradio as gr
178
+
179
+ # num_queries = 0
180
+ # conversation_history = []
181
+ # context = ""
182
+
183
+ def nishauri(question: str, conversation_history: list[str]):
184
+
185
+ # global num_queries, context
186
+
187
+ # if num_queries == 0:
188
+
189
+ # response = query_engine.query(preamble + prompt_intro + question)
190
+
191
+ # if num_queries > 0:
192
+
193
+ context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
194
+ response = query_engine.query(preamble +
195
+ "the user previously asked and received the following: " +
196
+ context +
197
+ prompt_intro +
198
+ question)
199
+
200
+ conversation_history.append({"user": question, "chatbot": response.response})
201
+
202
+ # num_queries += 1
203
+ return response, conversation_history
204
+
205
+ demo = gr.Interface(
206
+ title = "Nishauri Chatbot Demo",
207
+ fn=nishauri,
208
+ inputs=["text", gr.State(value=[])],
209
+ outputs=["text", gr.State()],
210
+ )
211
+
212
+ demo.launch()