Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import embed_pdf | |
| import shutil | |
| from utils import make_discord_trace_text | |
# Streamlit re-executes this whole script on every widget interaction, so an
# unguarded call here would post "RAG UI opened" on every rerun rather than once
# per visit. Remember in the per-session state that the event was already sent.
if not st.session_state.get("_rag_ui_open_traced", False):
    make_discord_trace_text("RAG UI opened")
    st.session_state["_rag_ui_open_traced"] = True
def clear_directory(directory):
    """Empty *directory* in place, keeping the directory itself.

    Files and symlinks (including broken links and links to directories) are
    unlinked; real sub-directories are removed recursively. A failure on one
    entry is printed and the loop moves on to the next entry. Errors from
    listing *directory* itself are not caught.
    """
    for name in os.listdir(directory):
        target = os.path.join(directory, name)
        try:
            # Test islink as well so a link is removed, never followed into.
            unlinkable = os.path.isfile(target) or os.path.islink(target)
            if unlinkable:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f'Failed to delete {target}. Reason: {err}')
def clear_pdf_files(directory):
    """Delete ``.pdf`` entries directly inside *directory* that are files.

    Only the top level is scanned (no recursion), and the suffix test is
    case-sensitive, so ``report.PDF`` is kept. Directories whose names end in
    ``.pdf`` are skipped. A failure on one entry is printed and skipped.
    """
    for name in os.listdir(directory):
        candidate = os.path.join(directory, name)
        if not candidate.endswith('.pdf'):
            continue
        try:
            if os.path.isfile(candidate):
                os.remove(candidate)
        except Exception as err:
            print(f'Failed to delete {candidate}. Reason: {err}')
# Startup cleanup is intentionally disabled: clear_pdf_files("pdf") and
# clear_directory("index") are not called when the app starts.

# Path of the optional local Streamlit secrets file. No live code in this
# script reads it: the alternative flow that took OPENAI_API_KEY from
# st.secrets (and offered an "Embed Documents" button that embedded every PDF
# via embed_pdf.embed_all_pdf_docs) is disabled. The key comes from the
# sidebar input instead.
secrets_file_path = os.path.join(".streamlit", "secrets.toml")
# Read the OpenAI key from the sidebar on every run. .strip() removes stray
# whitespace picked up by copy/paste, which would otherwise make the later
# startswith("sk-") key check fail for a valid key.
# NOTE(review): os.environ is process-wide, so on a shared server the key set
# here is visible to every concurrent session and persists after this one
# ends. Presumably llm_helper reads the key from the environment — verify
# before moving it into st.session_state instead.
os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
    "OpenAI API Key", type="password"
).strip()
st.sidebar.caption(":red[Note:] OpenAI API key will not be stored and will be automatically deleted from the logs at the end of your web session.")
st.sidebar.write("---")
# --- Sidebar: document upload and embedding --------------------------------
uploaded_file = st.sidebar.file_uploader("Upload Document", type=['pdf'], disabled=False)
# True once a PDF has been picked; it gates the embed button.
file_uploaded_bool = uploaded_file is not None

embed_requested = st.sidebar.button("Embed Documents", disabled=not file_uploaded_bool)
if embed_requested:
    st.sidebar.info("Embedding documents...")
    try:
        # Hand the uploaded file object to the embedder.
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        st.sidebar.info("Done!")
    except Exception as embed_error:
        # UI boundary: report the failure in the sidebar instead of crashing.
        st.sidebar.error(embed_error)
        st.sidebar.error("Failed to embed documents.")

st.sidebar.write("---")
st.sidebar.markdown('''
Steps to run app
1. Paste OpenAI API Key and press Enter
2. Upload PDF file
3. Click on Embed Documents button
4. Choose RAG method
5. Start Chatting with your PDF
''')
# --- Main page -------------------------------------------------------------
st.title("Chat with your PDF")

# Nothing below works without a key, so end this run early unless the
# environment holds something shaped like an OpenAI key ("sk-" prefix).
api_key_present = os.getenv('OPENAI_API_KEY', '').startswith("sk-")
if not api_key_present:
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()

# Imported only after the key gate.
# NOTE(review): presumably llm_helper needs the key at import time — confirm
# before hoisting this import to the top of the file.
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

# Radio label -> chain factory. The chosen factory is called later with the
# uploaded file's name (there is no pre-built index picker).
rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain,
}
chosen_rag_method = st.radio("Choose a RAG method", rag_method_map.keys(), index=0)
get_rag_chain_func = rag_method_map[chosen_rag_method]
# --- Chat ------------------------------------------------------------------
# Conversation history lives in per-session state so it survives reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay earlier turns; Streamlit redraws the whole page on each rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# The chain is built from the uploaded file's name. Without an upload,
# uploaded_file is None and `.name` would raise AttributeError, so chatting is
# disabled until a PDF is provided.
no_document = uploaded_file is None
if no_document:
    st.caption("Upload a PDF in the sidebar and embed it to start chatting.")

prompt = st.chat_input("Enter your message...", disabled=no_document)
if prompt and not no_document:
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Show the user's new message immediately.
    with st.chat_message("user"):
        st.markdown(prompt)
    make_discord_trace_text(prompt)

    with st.chat_message("assistant"):
        # Queries reported by the chain are listed above the streamed answer.
        retrieval_container = st.container()
        message_placeholder = st.empty()

        # Queries seen so far (first-seen order) and those already rendered.
        queried_questions = []
        rendered_questions = set()

        def update_retrieval_status():
            """Write any recorded query not yet shown into the container."""
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                retrieval_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            """Callback for the chain: record new queries, return *qs* unchanged."""
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        custom_chain = get_rag_chain_func(uploaded_file.name, retrieval_cb=retrieval_cb)

        # "messages" is guaranteed to exist (initialised above). Drop the last
        # entry, the prompt just appended, since it is passed as "input".
        chat_history = [convert_message(m) for m in st.session_state.messages[:-1]]

        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            # NOTE(review): chunks are handled as either mappings with an
            # "output" key or objects with .content. If the chain ever yields
            # plain strings, `"output" in response` becomes a substring test —
            # confirm against llm_helper.
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content
            # Trailing block character acts as a typing cursor while streaming.
            message_placeholder.markdown(full_response + "▌")
            update_retrieval_status()

        message_placeholder.markdown(full_response)

    make_discord_trace_text(full_response)
    # Persist the assistant turn so it is replayed on the next rerun.
    st.session_state.messages.append({"role": "assistant", "content": full_response})