Spaces:
Runtime error
Runtime error
import contextlib
import functools
import io
import os
import sys
import time
from pathlib import Path
from typing import Tuple

import streamlit as st
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core import VectorStoreIndex
from llama_index.core.agent import ReActAgent
from llama_index.core.objects import ObjectIndex
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
# Function to load per-document vector indexes and wrap them as query tools
@functools.lru_cache(maxsize=None)
def _get_embed_model() -> HuggingFaceEmbedding:
    """Return the shared sentence-transformers embedding model, constructing it on first use.

    Cached because create_doc_tools runs once per document, and constructing the
    model (which loads its weights) on every call is repeated, wasted work.
    """
    return HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")


def create_doc_tools(
    document_fp: str,
    doc_name: str,
    verbose: bool = True,
    persist_root: str = "/home/user/app/agentic_index_st",
) -> QueryEngineTool:
    """Load the persisted vector index for one document and wrap it as a query tool.

    Side effect: sets llama_index's global ``Settings.llm`` to a new Groq client
    and ``Settings.embed_model`` to the shared HuggingFace embedding model.
    NOTE(review): the embedding model presumably has to match the one used when
    the index was persisted; confirm before changing either.

    Args:
        document_fp: Path of the source document. Kept for interface
            compatibility only. The index is loaded from storage, so the file
            itself is not read here.
        doc_name: Document identifier. It selects the sub-directory of
            *persist_root* holding the index and prefixes the tool name.
        verbose: Unused; kept for interface compatibility.
        persist_root: Directory that contains one persisted index per document.

    Returns:
        A ``QueryEngineTool`` named ``"<doc_name>_vector_query_engine_tool"``.
    """
    # Rebuilt on every call, so a GROQ_API_KEY exported after the first call is
    # presumably picked up (Groq() is given no explicit key here).
    # NOTE(review): confirm Groq still serves the "mixtral-8x7b-32768" model.
    Settings.llm = Groq(model="mixtral-8x7b-32768")
    Settings.embed_model = _get_embed_model()

    storage_context = StorageContext.from_defaults(persist_dir=os.path.join(persist_root, doc_name))
    vector_index = load_index_from_storage(storage_context)
    return QueryEngineTool.from_defaults(
        name=f"{doc_name}_vector_query_engine_tool",
        query_engine=vector_index.as_query_engine(),
        description=f"Useful for retrieving specific context from the {doc_name}.",
    )
# Function to find and sort the document files (.tex and .txt by default)
def find_tex_files(directory: str, extensions: Tuple[str, ...] = ('.tex', '.txt')) -> list:
    """Recursively collect document files under *directory*.

    Despite the name, ``.txt`` files are matched by default as well as ``.tex``.

    Args:
        directory: Root directory to walk. A missing directory yields ``[]``,
            because ``os.walk`` silently produces nothing for it.
        extensions: File-name suffix, or tuple of suffixes, to match. It is
            passed straight to ``str.endswith``, so matching is case-sensitive.

    Returns:
        Absolute paths of all matching files, sorted lexicographically so the
        result does not depend on filesystem traversal order.
    """
    return sorted(
        os.path.abspath(os.path.join(root, name))
        for root, _dirs, files in os.walk(directory)
        for name in files
        if name.endswith(extensions)
    )
# Main app function and its helpers
def _build_agent(llm: Groq) -> ReActAgent:
    """Build a ReAct agent whose tools are vector query engines over the indexed documents.

    One ``QueryEngineTool`` is created per ``.tex``/``.txt`` file found under the
    document directory. An ``ObjectIndex`` retriever then selects the 6 tools
    most similar to each query, rather than handing every tool to the LLM.

    Raises:
        FileNotFoundError: if no documents are found, because an agent without
            tools could not answer anything.
    """
    directory = '/home/user/app/rag_docs_final_review_tex_merged'
    tex_files = find_tex_files(directory)
    if not tex_files:
        raise FileNotFoundError(f"No .tex/.txt documents found under {directory}")

    # Keyed by file stem. Tool names are derived from the stem, so two files that
    # share a stem keep only the last tool instead of registering a duplicate.
    paper_to_tools_dict = {}
    for paper in tex_files:
        path = Path(paper)
        vector_tool = create_doc_tools(doc_name=path.stem, document_fp=str(path))
        paper_to_tools_dict[path.stem] = [vector_tool]
    initial_tools = [tool for tools in paper_to_tools_dict.values() for tool in tools]

    obj_index = ObjectIndex.from_objects(initial_tools, index_cls=VectorStoreIndex)
    obj_retriever = obj_index.as_retriever(similarity_top_k=6)
    context = """You are an agent designed to answer scientific queries over a set of given documents.
Please always use the tools provided to answer a question. Do not rely on prior knowledge.
"""
    # verbose=True is required: the printed ReAct trace is what the sidebar's
    # "Verbose" toggle displays.
    return ReActAgent.from_tools(
        tool_retriever=obj_retriever,
        llm=llm,
        verbose=True,
        context=context,
    )


def _query_agent(agent: ReActAgent, prompt: str):
    """Run *prompt* through *agent* while capturing everything it prints.

    Returns:
        ``(response, verbose)``. *response* is the object returned by
        ``agent.query``. *verbose* is the captured stdout with each ``===``
        separator replaced by a newline, plus one trailing newline.
    """
    buffer = io.StringIO()
    # redirect_stdout restores sys.stdout even when query() raises, so a failed
    # query cannot leave the process's stdout pointing at this buffer.
    # NOTE(review): sys.stdout is process-global, so concurrent Streamlit
    # sessions that query at the same moment may interleave their traces.
    with contextlib.redirect_stdout(buffer):
        response = agent.query(prompt)
    verbose = buffer.getvalue().replace('===', '\n') + '\n'
    return response, verbose


def main():
    """Render the AMGPT Streamlit chat UI.

    On every Streamlit rerun, this function:
    - asks for a Groq API key;
    - builds the agent once per session and keeps it in ``st.session_state``;
    - replays the chat history;
    - answers a newly submitted prompt.
    """
    st.title("AMGPT, By MAIL")
    # API Key input
    apikey = st.text_input("Enter your Groq API Key", type="password")
    if apikey:
        # Export only a non-empty key: writing "" would clobber a key already set
        # in the environment. create_doc_tools builds Groq() without an explicit
        # key, so it relies on this variable.
        os.environ["GROQ_API_KEY"] = apikey
    with st.sidebar:
        verbose_toggle = st.toggle("Verbose")  # show the agent's full trace instead of only its answer
        reset = st.button('Reset Chat!')  # reset the chat

    # Build the agent once per session; later reruns reuse the stored one.
    if apikey and "tools_loaded" not in st.session_state:
        try:
            llm = Groq(model="mixtral-8x7b-32768", api_key=apikey)
            with st.spinner('Please wait, AMGPT is loading....'):
                agent = _build_agent(llm)
                st.success('Done!, you may start asking questions now')
            # store session state variables
            st.session_state["tools_loaded"] = True
            st.session_state["agent"] = agent
        except Exception as e:
            st.error(e)

    if "messages" not in st.session_state or reset:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input()
    if not prompt:
        return
    # The user started chatting without providing a Groq API key.
    if not apikey:
        st.info("Please add your Groq API key to continue.")
        st.stop()
    # Agent construction failed earlier; its error has already been displayed.
    if "agent" not in st.session_state:
        st.info("AMGPT is not ready yet; resolve the loading error above and try again.")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    try:
        with st.spinner('Wait for output...'):
            response, verbose = _query_agent(st.session_state.agent, prompt)
        # assistant response
        msg = verbose if verbose_toggle else str(response.response)
        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").markdown(msg)
    except Exception as e:
        st.error(e)


if __name__ == "__main__":
    main()