YoniFriedman commited on
Commit
7c10e5d
·
verified ·
1 Parent(s): b3f0072

removing agents

Browse files
Files changed (1) hide show
  1. app.py +42 -189
app.py CHANGED
@@ -1,206 +1,59 @@
1
  import os
2
  os.environ["OPENAI_API_KEY"]
3
 
4
- from llama_index import (
5
- VectorStoreIndex,
6
- SummaryIndex,
7
- SimpleKeywordTableIndex,
8
- SimpleDirectoryReader,
9
- ServiceContext,
10
- StorageContext,
11
- load_index_from_storage
12
- )
13
- from llama_index.schema import IndexNode
14
- from llama_index.tools import QueryEngineTool, ToolMetadata
15
- from llama_index.llms import OpenAI
16
-
17
- llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
18
- service_context = ServiceContext.from_defaults(llm=llm)
19
-
20
- arv_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv/"))
21
- # arv_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv_summary/"))
22
- arv_vector_query_engine = arv_index.as_query_engine(similarity_top_k = 2)
23
- # arv_summary_query_engine = arv_summary_index.as_query_engine()
24
-
25
- nishauri_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri/"))
26
- # nishauri_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri_summary/"))
27
- nishauri_vector_query_engine = nishauri_index.as_query_engine(similarity_top_k = 2)
28
- # nishauri_summary_query_engine = nishauri_summary_index.as_query_engine()
29
-
30
- from llama_index.agent import OpenAIAgent
31
-
32
- agents = {}
33
-
34
- # define tools
35
- query_engine_tools = [
36
- QueryEngineTool(
37
- query_engine=arv_vector_query_engine,
38
- metadata=ToolMetadata(
39
- name="arv_vector_tool",
40
- description=(
41
- "Useful for retrieving specific context about HIV care and treatment."
42
- ),
43
- ),
44
- ),
45
- # QueryEngineTool(
46
- # query_engine=arv_summary_query_engine,
47
- # metadata=ToolMetadata(
48
- # name="arv_summary_tool",
49
- # description=(
50
- # "Useful for summarization questions related to HIV care and treatment."
51
- # ),
52
- # ),
53
- # ),
54
- ]
55
-
56
- # build agent
57
- function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
58
- agent = OpenAIAgent.from_tools(
59
- query_engine_tools,
60
- llm=function_llm,
61
- verbose=True,
62
- )
63
-
64
- agents["arv"] = agent
65
-
66
- # define tools
67
- query_engine_tools = [
68
- QueryEngineTool(
69
- query_engine=nishauri_vector_query_engine,
70
- metadata=ToolMetadata(
71
- name="nishauri_vector_tool",
72
- description=(
73
- "Useful for retrieving specific context about the Nishauri mobile application through which users are asking questions"
74
- ),
75
- ),
76
- ),
77
- # QueryEngineTool(
78
- # query_engine=nishauri_summary_query_engine,
79
- # metadata=ToolMetadata(
80
- # name="nishauri_summary_tool",
81
- # description=(
82
- # "Useful for summarization questions related to the Nishauri mobile application through which users are asking questions"
83
- # ),
84
- # ),
85
- # ),
86
- ]
87
-
88
- # build agent
89
- function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
90
- agent = OpenAIAgent.from_tools(
91
- query_engine_tools,
92
- llm=function_llm,
93
- verbose=True,
94
- )
95
-
96
- agents["nishauri"] = agent
97
-
98
- # define top-level nodes
99
- nodes = []
100
-
101
- arv_summary = (
102
- "This content contains care and treatment guidance for people living with HIV."
103
- " Use this source to answer questions about ARV medications, side effects from medication,"
104
- " understanding viral loads, and any question about HIV care and treatment."
105
- " This is the default source to use for answering any question that isn't about how to find"
106
- " information in the Nishauri app."
107
- )
108
- node = IndexNode(text=arv_summary, index_id="arv")
109
- nodes.append(node)
110
-
111
- nishauri_summary = (
112
- "This content contains guidance on the Nishauri mobile application through which users are asking questions."
113
- " Reference this document when users ask questions such as how to find their viral load"
114
- " or lab histories, how to find their appointment histories,"
115
- " and want to know how to change their upcoming appointments."
116
- " Do not use this to answer any other questions."
117
- )
118
- node = IndexNode(text=nishauri_summary, index_id="nishauri")
119
- nodes.append(node)
120
-
121
- # define top-level retriever
122
- vector_index = VectorStoreIndex(nodes)
123
- vector_retriever = vector_index.as_retriever(similarity_top_k=2)
124
 
125
- # define recursive retriever
126
- from llama_index.retrievers import RecursiveRetriever
127
- from llama_index.query_engine import RetrieverQueryEngine
128
- from llama_index.response_synthesizers import get_response_synthesizer
129
-
130
- # note: can pass `agents` dict as `query_engine_dict` since every agent can be used as a query engine
131
- recursive_retriever = RecursiveRetriever(
132
- "vector",
133
- retriever_dict={"vector": vector_retriever},
134
- query_engine_dict=agents,
135
- verbose=True,
136
- )
137
-
138
- response_synthesizer = get_response_synthesizer(
139
- # service_context=service_context,
140
- response_mode="compact",
141
- )
142
- query_engine = RetrieverQueryEngine.from_args(
143
- recursive_retriever,
144
- response_synthesizer=response_synthesizer,
145
- service_context=service_context,
146
- )
147
 
 
 
 
 
148
 
149
- preamble = (" The person asking the following prompt is a person living with HIV in Kenya."
150
- " For every response, recognize that they already have HIV and do not suggest that they have to get tested"
151
  " for HIV or take post-exposure prophylaxis, as that is not relevant, though their partners perhaps should."
152
  " Do not suggest anything that is not relevant to someone who already has HIV."
153
- " They are asking questions through a mobile application called Nishauri"
154
- " through which they can see their lab results, appointment histories, and upcoming appointments."
155
- " Here is some information that is authoritative and should guide responses, when relevant."
156
- " For questions about viral load, be sure to provide specific information"
157
- " about cutoffs for viral load categories. Under 50 copies/ml is low detectable level,"
158
  " 50 - 199 copies/ml is low level viremia, 200 - 999 is high level viremia, and "
159
  " 1000 and above is suspected treatment failure."
160
  " A high viral load or non-suppressed viral load is any viral load above 200 copies/ml."
161
- " A suppressed viral load is one below 200 copies / ml."
162
- " An established client is one who is on their current ART regimen for a period greater"
163
- " than 6 months, had no active OI or in the previous 6 months, has adhered to scheduled"
164
- " clinic visits for the previous 6 months and Viral load results has been less than 200 copies/ml"
165
- " within the last 6 months."
166
- " For questions about when patients should get their viral loads taken,"
167
- " if they are newly initiated on ART, the first viral load sample should be taken after 3 months of"
168
- " taking ART. Otherwise, if they are not new on ART, then if their previous result was below 50 to 199 cp/ml,"
169
- " their viral load should be taken after every 12 months. If their previous result was above 200cp/ml,"
170
- " then viral load sample should be taken after three months."
171
- " Please answer the prompt using the information retrieved"
172
- " and do not rely at all on your prior knowledge."
173
- " Please keep your reply to no longer than three sentences, and please use simple language. ")
174
-
175
- prompt_intro = (" Here is the prompt: ")
176
-
177
- import gradio as gr
178
-
179
- # num_queries = 0
180
- # conversation_history = []
181
- # context = ""
182
-
183
- def nishauri(question: str, conversation_history: list[str]):
184
 
185
- # global num_queries, context
186
-
187
- # if num_queries == 0:
188
-
189
- # response = query_engine.query(preamble + prompt_intro + question)
190
-
191
- # if num_queries > 0:
192
-
193
- context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
194
- response = query_engine.query(preamble +
195
- "the user previously asked and received the following: " +
196
- context +
197
- prompt_intro +
198
- question)
199
-
200
  conversation_history.append({"user": question, "chatbot": response.response})
201
 
202
- # num_queries += 1
203
- return response, conversation_history
204
 
205
  demo = gr.Interface(
206
  title = "Nishauri Chatbot Demo",
 
1
  import os
2
  os.environ["OPENAI_API_KEY"]
3
 
4
+ from llama_index.llms.openai import OpenAI
5
+ from llama_index.core.schema import MetadataMode
6
+ import openai
7
+ from openai import OpenAI as OpenAIOG
8
+ import logging
9
+ import sys
10
+ llm = OpenAI(temperature=0.0, model="gpt-3.5-turbo")
11
+ client = OpenAIOG()
12
+
13
+ # Load index
14
+ from llama_index.core import VectorStoreIndex
15
+ from llama_index.core import StorageContext
16
+ from llama_index.core import load_index_from_storage
17
+ storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
18
+ index = load_index_from_storage(storage_context)
19
+ query_engine = index.as_query_engine(similarity_top_k=3, llm=llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def nishauri(question: str, conversation_history: list[str]):
24
+
25
+ context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
26
+ response = query_engine.query(question)
27
 
28
+ background = ("The person who asked the question is a person living with HIV."
29
+ " Recognize that they already have HIV and do not suggest that they have to get tested"
30
  " for HIV or take post-exposure prophylaxis, as that is not relevant, though their partners perhaps should."
31
  " Do not suggest anything that is not relevant to someone who already has HIV."
32
+ " Do not mention in the response that the person is living with HIV."
33
+ " The following information about viral loads is authoritative and should override the initial response if appropriate:"
34
+ " Under 50 copies/ml is low detectable level,"
 
 
35
  " 50 - 199 copies/ml is low level viremia, 200 - 999 is high level viremia, and "
36
  " 1000 and above is suspected treatment failure."
37
  " A high viral load or non-suppressed viral load is any viral load above 200 copies/ml."
38
+ " A suppressed viral load is one below 200 copies / ml.")
39
+
40
+ question_final = (
41
+ f"The user previously asked and answered the following: {context}"
42
+ f" The user just asked the following question: {question}"
43
+ f" The following response was generated in response: {response}"
44
+ f" Please update the response provided only if needed, based on the following background information {background}"
45
+ )
46
+
47
+ completion = client.chat.completions.create(
48
+ model="gpt-3.5-turbo",
49
+ messages=[
50
+ {"role": "user", "content": question_final}
51
+ ]
52
+ )
 
 
 
 
 
 
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  conversation_history.append({"user": question, "chatbot": response.response})
55
 
56
+ return completion.choices[0].message.content, conversation_history
 
57
 
58
  demo = gr.Interface(
59
  title = "Nishauri Chatbot Demo",