YoniFriedman commited on
Commit
530b77b
·
verified ·
1 Parent(s): c01ce6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -4
app.py CHANGED
@@ -17,10 +17,133 @@ from llama_index.llms import OpenAI
17
  llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
18
  service_context = ServiceContext.from_defaults(llm=llm)
19
 
20
- PERSIST_DIR = "arv_metadata"
21
- storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
22
- index = load_index_from_storage(storage_context)
23
- query_engine = index.as_query_engine(similarity_top_k=3, llm=OpenAI(model="gpt-3.5-turbo"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  preamble = (" The person asking the following prompt is a person living with HIV in Kenya."
 
17
  llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
18
  service_context = ServiceContext.from_defaults(llm=llm)
19
 
20
+ arv_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv/"))
21
+ # arv_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./arv_summary/"))
22
+ arv_vector_query_engine = arv_index.as_query_engine(similarity_top_k = 3)
23
+ # arv_summary_query_engine = arv_summary_index.as_query_engine()
24
+
25
+ nishauri_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri/"))
26
+ # nishauri_summary_index = load_index_from_storage(StorageContext.from_defaults(persist_dir = "./nishauri_summary/"))
27
+ nishauri_vector_query_engine = nishauri_index.as_query_engine(similarity_top_k = 3)
28
+ # nishauri_summary_query_engine = nishauri_summary_index.as_query_engine()
29
+
30
+ from llama_index.agent import OpenAIAgent
31
+
32
+ agents = {}
33
+
34
+ # define tools
35
+ query_engine_tools = [
36
+ QueryEngineTool(
37
+ query_engine=arv_vector_query_engine,
38
+ metadata=ToolMetadata(
39
+ name="arv_vector_tool",
40
+ description=(
41
+ "Useful for retrieving specific context about HIV care and treatment."
42
+ ),
43
+ ),
44
+ ),
45
+ # QueryEngineTool(
46
+ # query_engine=arv_summary_query_engine,
47
+ # metadata=ToolMetadata(
48
+ # name="arv_summary_tool",
49
+ # description=(
50
+ # "Useful for summarization questions related to HIV care and treatment."
51
+ # ),
52
+ # ),
53
+ # ),
54
+ ]
55
+
56
+ # build agent
57
+ function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
58
+ agent = OpenAIAgent.from_tools(
59
+ query_engine_tools,
60
+ llm=function_llm,
61
+ verbose=True,
62
+ )
63
+
64
+ agents["arv"] = agent
65
+
66
+ # define tools
67
+ query_engine_tools = [
68
+ QueryEngineTool(
69
+ query_engine=nishauri_vector_query_engine,
70
+ metadata=ToolMetadata(
71
+ name="nishauri_vector_tool",
72
+ description=(
73
+ "Useful for retrieving specific context about the Nishauri mobile application through which users are asking questions"
74
+ ),
75
+ ),
76
+ ),
77
+ # QueryEngineTool(
78
+ # query_engine=nishauri_summary_query_engine,
79
+ # metadata=ToolMetadata(
80
+ # name="nishauri_summary_tool",
81
+ # description=(
82
+ # "Useful for summarization questions related to the Nishauri mobile application through which users are asking questions"
83
+ # ),
84
+ # ),
85
+ # ),
86
+ ]
87
+
88
+ # build agent
89
+ function_llm = OpenAI(model="gpt-3.5-turbo-0613", temperature = 0)
90
+ agent = OpenAIAgent.from_tools(
91
+ query_engine_tools,
92
+ llm=function_llm,
93
+ verbose=True,
94
+ )
95
+
96
+ agents["nishauri"] = agent
97
+
98
+ # define top-level nodes
99
+ nodes = []
100
+
101
+ arv_summary = (
102
+ "This content contains care and treatment guidance for people living with HIV."
103
+ " Use this source to answer questions about ARV medications, side effects from medication,"
104
+ " understanding viral loads, and any question about HIV care and treatment."
105
+ " This is the default source to use for answering any question that isn't about how to find"
106
+ " information in the Nishauri app."
107
+ )
108
+ node = IndexNode(text=arv_summary, index_id="arv")
109
+ nodes.append(node)
110
+
111
+ nishauri_summary = (
112
+ "This content contains guidance on the Nishauri mobile application through which users are asking questions."
113
+ " Reference this document when users ask questions such as how to find their viral load"
114
+ " or lab histories, how to find their appointment histories,"
115
+ " and want to know how to change their upcoming appointments."
116
+ " Do not use this to answer any other questions."
117
+ )
118
+ node = IndexNode(text=nishauri_summary, index_id="nishauri")
119
+ nodes.append(node)
120
+
121
+ # define top-level retriever
122
+ vector_index = VectorStoreIndex(nodes)
123
+ vector_retriever = vector_index.as_retriever(similarity_top_k=1)
124
+
125
+ # define recursive retriever
126
+ from llama_index.retrievers import RecursiveRetriever
127
+ from llama_index.query_engine import RetrieverQueryEngine
128
+ from llama_index.response_synthesizers import get_response_synthesizer
129
+
130
+ # note: can pass `agents` dict as `query_engine_dict` since every agent can be used as a query engine
131
+ recursive_retriever = RecursiveRetriever(
132
+ "vector",
133
+ retriever_dict={"vector": vector_retriever},
134
+ query_engine_dict=agents,
135
+ verbose=True,
136
+ )
137
+
138
+ response_synthesizer = get_response_synthesizer(
139
+ # service_context=service_context,
140
+ response_mode="compact",
141
+ )
142
+ query_engine = RetrieverQueryEngine.from_args(
143
+ recursive_retriever,
144
+ response_synthesizer=response_synthesizer,
145
+ service_context=service_context,
146
+ )
147
 
148
 
149
  preamble = (" The person asking the following prompt is a person living with HIV in Kenya."