| import os, tempfile |
| |
| from pathlib import Path |
| import traceback |
| from langchain.chains import RetrievalQA, ConversationalRetrievalChain |
| from langchain.embeddings import OpenAIEmbeddings |
| from langchain.vectorstores import Chroma |
| from langchain import OpenAI |
| from langchain.chat_models import ChatOpenAI |
| from langchain.document_loaders import DirectoryLoader |
| from langchain.text_splitter import CharacterTextSplitter |
| from langchain.vectorstores import Chroma |
| from langchain.embeddings.openai import OpenAIEmbeddings |
| from langchain.memory import ConversationBufferMemory |
| from langchain.memory.chat_message_histories import StreamlitChatMessageHistory |
| from dotenv import load_dotenv |
| import streamlit as st |
|
|
# Pull settings such as OPENAI_API_KEY from a local .env file into os.environ.
load_dotenv()

# Working directories live next to this script under ./data.
_DATA_DIR = Path(__file__).resolve().parent.joinpath('data')
TMP_DIR = _DATA_DIR.joinpath('tmp')  # scratch space for uploaded PDFs
LOCAL_VECTOR_STORE_DIR = _DATA_DIR.joinpath('vector_store')  # Chroma persistence directory

# Create both directories once; exist_ok makes this safe to run repeatedly.
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(LOCAL_VECTOR_STORE_DIR, exist_ok=True)

st.set_page_config(page_title="RAG")
st.title("Retrieval Augmented Generation Engine")

# Seed the session with the key from the environment (None when unset);
# input_fields() replaces it when the user types a key in the sidebar.
openai_api_key = os.environ.get('OPENAI_API_KEY')
st.session_state.openai_api_key = openai_api_key
|
|
def load_documents():
    """Load every PDF found (recursively) under ``TMP_DIR``.

    Returns:
        Whatever ``DirectoryLoader.load()`` yields for the ``**/*.pdf`` glob.
    """
    pdf_loader = DirectoryLoader(TMP_DIR.as_posix(), glob='**/*.pdf')
    return pdf_loader.load()
|
|
def split_documents(documents):
    """Split *documents* into chunks for embedding.

    Args:
        documents: Loaded documents, as returned by ``load_documents()``.

    Returns:
        The chunks produced by a ``CharacterTextSplitter`` configured with
        ``chunk_size=1000`` and ``chunk_overlap=0``.
    """
    splitter = CharacterTextSplitter(chunk_overlap=0, chunk_size=1000)
    return splitter.split_documents(documents)
|
|
def embeddings_on_local_vectordb():
    """Open the Chroma store in ``LOCAL_VECTOR_STORE_DIR`` and return a retriever.

    Returns:
        A retriever over the store, created with ``search_kwargs={'k': 5}``.
    """
    store = Chroma(
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
        embedding_function=OpenAIEmbeddings(),
    )
    # Persist straight away, exactly as the store is opened.
    store.persist()
    return store.as_retriever(search_kwargs={'k': 5})
|
|
| |
| |
| |
| |
| |
| |
|
|
def query_llm(retriever, query):
    """Answer *query* with a conversational retrieval chain and record the turn.

    Args:
        retriever: Retriever passed to ``ConversationalRetrievalChain`` to fetch
            context for the question.
        query: The user's question.

    Returns:
        The chain's ``answer`` value, or ``""`` when the chain raised or
        produced no answer.

    Side effects:
        Appends ``(query, result)`` to ``st.session_state.messages``, including
        for failed calls, so the turn still shows up in the chat history.
    """
    try:
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=ChatOpenAI(temperature=0, openai_api_key=st.session_state.openai_api_key),
            retriever=retriever,
            return_source_documents=True,
        )
        response = qa_chain({'question': query, 'chat_history': st.session_state.messages})
        # Fall back to "" so a missing answer never puts None into the history.
        result = response.get('answer') or ""
    except Exception as e:
        # Deliberately do NOT include the API key here: secrets must not end up in logs.
        print(f"Exception {e} with traceback : {traceback.format_exc()} occurred while querying the LLM")
        result = ""
    st.session_state.messages.append((query, result))
    return result
|
|
def input_fields():
    """Render the API-key field and the PDF uploader, storing both in session state.

    Writes ``st.session_state.openai_api_key`` (only when a key is typed) and
    ``st.session_state.source_docs`` (the uploader's current value).
    """
    with st.sidebar:
        typed_key = st.text_input("OpenAI API key", type="password")
        # An empty field leaves the key already held in the session untouched.
        if typed_key != "":
            st.session_state.openai_api_key = typed_key

    # NOTE(review): the source's indentation was lost; the uploader is assumed
    # to render on the main page rather than inside the sidebar — confirm.
    st.session_state.source_docs = st.file_uploader(
        label="Upload Documents",
        type="pdf",
        accept_multiple_files=True,
    )
| |
|
|
# Module-level retriever shared by process_documents() and boot().
# NOTE(review): this is built at import time, before the sidebar key is read,
# so the embeddings presumably rely on OPENAI_API_KEY from the environment — confirm.
retriever = embeddings_on_local_vectordb()


def process_documents():
    """Ingest the uploaded PDFs into the vector store behind ``retriever``.

    Used as the ``on_click`` callback of the "Submit Documents" button. Each
    upload is written to ``TMP_DIR``, loaded, split into chunks, and added to
    the vector store, which is then persisted.

    Shows a Streamlit warning when the API key or the uploads are missing, and
    a Streamlit error (instead of raising) when ingestion fails.
    """
    if not st.session_state.openai_api_key or not st.session_state.source_docs:
        st.warning("Please upload the documents and provide the missing fields.")
        return

    try:
        # Documents must be added through the underlying vector store: the
        # retriever object itself exposes neither add_texts() nor persist().
        vectordb = retriever.vectorstore

        for source_doc in st.session_state.source_docs:
            try:
                # DirectoryLoader reads from disk, so spill the upload to a temp PDF first.
                with tempfile.NamedTemporaryFile(delete=False, dir=TMP_DIR.as_posix(), suffix='.pdf') as tmp_file:
                    tmp_file.write(source_doc.read())

                documents = load_documents()
            finally:
                # Always empty the scratch directory, even when loading fails,
                # so a broken file is not re-ingested on the next submit.
                for leftover in list(TMP_DIR.iterdir()):
                    if leftover.is_file():
                        leftover.unlink()

            texts = split_documents(documents)
            if not texts:
                # Nothing to embed (e.g. no extractable text); skip this file.
                st.warning(f"No text could be extracted from {source_doc.name}.")
                continue

            print(f"Adding {len(texts)} texts to vector DB")
            # `texts` holds Document chunks, hence add_documents rather than add_texts.
            vectordb.add_documents(texts)
            vectordb.persist()
    except Exception as e:
        st.error(f"An error occurred: {e}")
|
|
def boot():
    """Draw the page: inputs, submit button, chat history, and the chat box.

    A submitted question is answered via ``query_llm`` using the module-level
    ``retriever`` and written to the page.
    """
    input_fields()
    st.button("Submit Documents", on_click=process_documents)

    # The history is a list of (question, answer) pairs kept in session state.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay earlier turns.
    for turn in st.session_state.messages:
        human_text, ai_text = turn[0], turn[1]
        st.chat_message('human').write(human_text)
        st.chat_message('ai').write(ai_text)

    query = st.chat_input()
    if query:
        st.chat_message("human").write(query)
        answer = query_llm(retriever, query)
        st.chat_message("ai").write(answer)


if __name__ == '__main__':
    boot()
| |