achuthc1298 commited on
Commit
d19a87c
·
verified ·
1 Parent(s): 9063cbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -54
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import streamlit as st
2
  import os
3
  from pathlib import Path
4
- from llama_index.core.query_engine.router_query_engine import RouterQueryEngine
5
  from llama_index.core.selectors import LLMSingleSelector
6
  from llama_index.core.tools import QueryEngineTool
7
- from llama_index.core import SummaryIndex, VectorStoreIndex
8
- from llama_index.core import VectorStoreIndex, Settings
9
  from llama_index.core import SimpleDirectoryReader
10
  from llama_index.llms.groq import Groq
11
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
@@ -13,6 +12,9 @@ from typing import Tuple
13
  from llama_index.core import StorageContext, load_index_from_storage
14
  from llama_index.core.objects import ObjectIndex
15
  from llama_index.core.agent import ReActAgent
 
 
 
16
 
17
  # Function to process files and create document tools
18
  def create_doc_tools(document_fp: str, doc_name: str, verbose: bool = True) -> Tuple[QueryEngineTool,]:
@@ -45,8 +47,11 @@ def find_tex_files(directory: str):
45
  tex_files.sort()
46
  return tex_files
47
 
 
 
48
  # Main app function
49
  def main():
 
50
  st.title("AMGPT, Powered by LlamaIndex")
51
 
52
  # API Key input
@@ -56,59 +61,98 @@ def main():
56
 
57
  llm = Groq(model="mixtral-8x7b-32768")
58
 
59
- if "history" not in st.session_state:
60
- st.session_state.history = []
61
-
62
- if apikey:
63
- directory = '/home/user/app/rag_docs_final_review_tex_merged'
64
- tex_files = find_tex_files(directory)
65
-
66
- paper_to_tools_dict = {}
67
- for paper in tex_files:
68
- path = Path(paper)
69
- vector_tool = create_doc_tools(doc_name=path.stem, document_fp=path)
70
- paper_to_tools_dict[path.stem] = [vector_tool]
71
-
72
- initial_tools = [t for paper in tex_files for t in paper_to_tools_dict[Path(paper).stem]]
73
-
74
- obj_index = ObjectIndex.from_objects(
75
- initial_tools,
76
- index_cls=VectorStoreIndex,
77
- )
78
-
79
- obj_retriever = obj_index.as_retriever(similarity_top_k=6)
80
 
81
- context = """You are an agent designed to answer scientific queries over a set of given documents.
82
- Please always use the tools provided to answer a question. Do not rely on prior knowledge.
83
- """
84
 
85
- agent = ReActAgent.from_tools(
86
- tool_retriever=obj_retriever,
87
- llm=llm,
88
- verbose=True,
89
- context=context
90
- )
91
-
92
- user_prompt = st.text_input("Enter your question")
93
-
94
- if user_prompt:
95
- st.session_state.history.append({"user": user_prompt})
96
-
97
- with st.spinner("Processing..."):
98
- response = agent.query(user_prompt)
99
- st.session_state.history.append({"agent": response})
100
-
101
- # Display the latest response
102
- st.markdown(f"### Query Response:\n{response}")
103
-
104
- # Display chat history
105
- if st.session_state.history:
106
- st.markdown("## Chat History")
107
- for chat in st.session_state.history:
108
- if "user" in chat:
109
- st.markdown(f"**User:** {chat['user']}")
110
- if "agent" in chat:
111
- st.markdown(f"**Agent:** {chat['agent']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if __name__ == "__main__":
114
  main()
 
 
 
1
  import streamlit as st
2
  import os
3
  from pathlib import Path
 
4
  from llama_index.core.selectors import LLMSingleSelector
5
  from llama_index.core.tools import QueryEngineTool
6
+ from llama_index.core import VectorStoreIndex
7
+ from llama_index.core import Settings
8
  from llama_index.core import SimpleDirectoryReader
9
  from llama_index.llms.groq import Groq
10
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
12
  from llama_index.core import StorageContext, load_index_from_storage
13
  from llama_index.core.objects import ObjectIndex
14
  from llama_index.core.agent import ReActAgent
15
+ import time
16
+ import sys
17
+ import io
18
 
19
  # Function to process files and create document tools
20
  def create_doc_tools(document_fp: str, doc_name: str, verbose: bool = True) -> Tuple[QueryEngineTool,]:
 
47
  tex_files.sort()
48
  return tex_files
49
 
50
+
51
+
52
  # Main app function
53
  def main():
54
+
55
  st.title("AMGPT, Powered by LlamaIndex")
56
 
57
  # API Key input
 
61
 
62
  llm = Groq(model="mixtral-8x7b-32768")
63
 
64
+ with st.sidebar:
65
+ verbose_toggle = st.toggle("Verbose") # get verbose or only LLM response
66
+ reset = st.button('Reset Chat!') # reset the chat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
68
 
69
+ if apikey:
70
+ if "tools_loaded" not in st.session_state:
71
+ try:
72
+ directory = '/home/user/app/rag_docs_final_review_tex_merged'
73
+ tex_files = find_tex_files(directory)
74
+
75
+ paper_to_tools_dict = {}
76
+ for paper in tex_files:
77
+ path = Path(paper)
78
+ vector_tool = create_doc_tools(doc_name=path.stem, document_fp=path)
79
+ paper_to_tools_dict[path.stem] = [vector_tool]
80
+
81
+ initial_tools = [t for paper in tex_files for t in paper_to_tools_dict[Path(paper).stem]]
82
+
83
+ obj_index = ObjectIndex.from_objects(
84
+ initial_tools,
85
+ index_cls=VectorStoreIndex,
86
+ )
87
+
88
+ obj_retriever = obj_index.as_retriever(similarity_top_k=6)
89
+
90
+
91
+ context = """You are an agent designed to answer scientific queries over a set of given documents.
92
+ Please always use the tools provided to answer a question. Do not rely on prior knowledge.
93
+ """
94
+
95
+ agent = ReActAgent.from_tools(
96
+ tool_retriever=obj_retriever,
97
+ llm=llm,
98
+ verbose=True,
99
+ context=context
100
+ )
101
+
102
+ # store session state variables
103
+ st.session_state["tools_loaded"] = True
104
+ st.session_state["agent"] = agent
105
+ except Exception as e:
106
+ st.error(e)
107
+
108
+
109
+
110
+
111
+
112
+ if "messages" not in st.session_state or reset==True:
113
+ st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
114
+
115
+ for msg in st.session_state.messages:
116
+ st.chat_message(msg["role"]).write(msg["content"])
117
+
118
+ if prompt := st.chat_input():
119
+
120
+ # if the user started chatting without setting the OPENAI API KEY
121
+ if not apikey:
122
+ st.info("Please add your Groq API key to continue.")
123
+ st.stop()
124
+
125
+ st.session_state.messages.append({"role": "user", "content": prompt})
126
+ st.chat_message("user").write(prompt)
127
+ try:
128
+ with st.spinner('Wait for output...'):
129
+ # Redirect stdout
130
+ original_stdout = sys.stdout
131
+ sys.stdout = io.StringIO()
132
+
133
+ # query the agent
134
+ response = st.session_state.agent.query(prompt)
135
+
136
+ # Get the captured output and restore stdout
137
+ output = sys.stdout.getvalue()
138
+ sys.stdout = original_stdout
139
+
140
+ # format the received verbose output
141
+ verbose = ''
142
+ for output_string in output.split('==='):
143
+ verbose+=output_string
144
+ verbose+='\n'
145
+
146
+ # assistant response
147
+ msg = f'{verbose}' if verbose_toggle else f'{response.response[10:]}'
148
+
149
+ # write the response
150
+ st.session_state.messages.append({"role": "assistant", "content": msg})
151
+ st.chat_message("assistant").write(msg)
152
+ except Exception as e:
153
+ st.error(e)
154
 
155
  if __name__ == "__main__":
156
  main()
157
+
158
+