Mythus committed on
Commit
22c5eeb
·
verified ·
1 Parent(s): fd2c532

Upload 7 files

Browse files
Files changed (7) hide show
  1. LICENSE +21 -0
  2. README.md +26 -12
  3. app_chat.py +112 -0
  4. constants.py +17 -0
  5. langchain_utils.py +103 -0
  6. requirements.txt +6 -0
  7. search_indexing.py +45 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Elton Vieira
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,26 @@
1
- ---
2
- title: BooksCheating
3
- emoji: 📈
4
- colorFrom: red
5
- colorTo: purple
6
- sdk: streamlit
7
- sdk_version: 1.34.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ![image](https://github.com/ergv03/chat-with-pdf-llm/assets/23053920/969edf03-4451-4909-98d9-601d92a17e83)
3
+
4
+ ## Overview:
5
+
6
+ Simple web-based chat app, built using [Streamlit](https://streamlit.io/) and [Langchain](https://python.langchain.com/). The app backend follows the Retrieval Augmented Generation (RAG) framework.
7
+
8
+ Allows the user to provide a list of PDFs, and ask questions to a LLM (today only OpenAI GPT is implemented) that can be answered by these PDF documents.
9
+
10
+ User needs to provide their own OpenAI API key.
11
+
12
+ ## Installation:
13
+
14
+ Just clone the repo and install the requirements using ```pip install -r requirements.txt```
15
+
16
+ ## How to run locally:
17
+
18
+ Run ```streamlit run app_chat.py``` in your terminal.
19
+
20
+ Add the URLs of the PDF documents that are relevant to your queries, and start chatting with the bot.
21
+
22
+ ## How it works:
23
+
24
+ The provided PDFs will be downloaded and properly split into chunks, and finally embedding vectors for each chunk will be generated using OpenAI service. These vectors are then indexed using FAISS, and can be quickly retrieved.
25
+
26
+ As the user interacts with the bot, new relevant document chunks/snippets are retrieved and added to the session memory, alongside the past few messages. These snippets and messages are part of the prompt sent to the LLM; this way, the model will have as context not just the latest message and retrieved snippet, but past ones as well.
app_chat.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from constants import search_number_messages
4
+ from langchain_utils import initialize_chat_conversation
5
+ from search_indexing import download_and_index_pdf
6
+ import re
7
+
8
+
9
+ def remove_url(url_to_remove):
10
+ """
11
+ Remove URLs from the session_state. Triggered by the respective button
12
+ """
13
+ if url_to_remove in st.session_state.urls:
14
+ st.session_state.urls.remove(url_to_remove)
15
+
16
+
17
+ # Page title
18
+ st.set_page_config(page_title='Talk with PDFs using LLMs - Beta')
19
+ st.title('Talk with PDFs using LLMs - (Beta)')
20
+
21
+ # Initialize the faiss_index key in the session state. This can be used to avoid having to download and embed the same PDF
22
+ # every time the user asks a question
23
+ if 'faiss_index' not in st.session_state:
24
+ st.session_state['faiss_index'] = {
25
+ 'indexed_urls': [],
26
+ 'index': None
27
+ }
28
+
29
+ # Initialize conversation memory used by Langchain
30
+ if 'conversation_memory' not in st.session_state:
31
+ st.session_state['conversation_memory'] = None
32
+
33
+ # Initialize chat history used by StreamLit (for display purposes)
34
+ if "messages" not in st.session_state:
35
+ st.session_state.messages = []
36
+
37
+ # Store the URLs added by the user in the UI
38
+ if 'urls' not in st.session_state:
39
+ st.session_state.urls = []
40
+
41
+ with st.sidebar:
42
+
43
+ openai_api_key = st.text_input('Step 1 - OpenAI API Key:', type='password')
44
+
45
+ # Add/Remove URLs form
46
+ with st.form('urls-form', clear_on_submit=True):
47
+ url = st.text_input('Step 2 - URLs to relevant PDFs: ')
48
+ add_url_button = st.form_submit_button('Add')
49
+ if add_url_button:
50
+ if url not in st.session_state.urls:
51
+ st.session_state.urls.append(url)
52
+
53
+ # Display a container with the URLs added by the user so far
54
+ with st.container():
55
+ if st.session_state.urls:
56
+ st.header('URLs added:')
57
+ for url in st.session_state.urls:
58
+ st.write(url)
59
+ st.button(label='Remove', key=f"Remove {url}", on_click=remove_url, kwargs={'url_to_remove': url})
60
+ st.divider()
61
+
62
+ # Display chat messages from history on app rerun
63
+ for message in st.session_state.messages:
64
+ with st.chat_message(message["role"]):
65
+ st.markdown(message["content"])
66
+
67
+ # React to user input
68
+ if query_text := st.chat_input("Your message"):
69
+
70
+ os.environ['OPENAI_API_KEY'] = openai_api_key
71
+
72
+ # Display user message in chat message container, and append to session state
73
+ st.chat_message("user").markdown(query_text)
74
+ st.session_state.messages.append({"role": "user", "content": query_text})
75
+
76
+ # Check if FAISS index already exists, or if it needs to be created as it includes new URLs
77
+ session_urls = st.session_state.urls
78
+ if st.session_state['faiss_index']['index'] is None or set(st.session_state['faiss_index']['indexed_urls']) != set(session_urls):
79
+ st.session_state['faiss_index']['indexed_urls'] = session_urls
80
+ with st.spinner('Downloading and indexing PDFs...'):
81
+ faiss_index = download_and_index_pdf(session_urls)
82
+ st.session_state['faiss_index']['index'] = faiss_index
83
+ else:
84
+ faiss_index = st.session_state['faiss_index']['index']
85
+
86
+ # Check if conversation memory has already been initialized and is part of the session state
87
+ if st.session_state['conversation_memory'] is None:
88
+ conversation = initialize_chat_conversation(faiss_index)
89
+ st.session_state['conversation_memory'] = conversation
90
+ else:
91
+ conversation = st.session_state['conversation_memory']
92
+
93
+ # Search PDF snippets using the last few user messages
94
+ user_messages_history = [message['content'] for message in st.session_state.messages[-search_number_messages:] if message['role'] == 'user']
95
+ user_messages_history = '\n'.join(user_messages_history)
96
+
97
+ with st.spinner('Querying OpenAI GPT...'):
98
+ response = conversation.predict(input=query_text, user_messages_history=user_messages_history)
99
+
100
+ # Display assistant response in chat message container
101
+ with st.chat_message("assistant"):
102
+ st.markdown(response)
103
+ snippet_memory = conversation.memory.memories[1]
104
+ for page_number, snippet in zip(snippet_memory.pages, snippet_memory.snippets):
105
+ with st.expander(f'Snippet from page {page_number + 1}'):
106
+ # Remove the <START> and <END> tags from the snippets before displaying them
107
+ snippet = re.sub("<START_SNIPPET_PAGE_\d+>", '', snippet)
108
+ snippet = re.sub("<END_SNIPPET_PAGE_\d+>", '', snippet)
109
+ st.markdown(snippet)
110
+
111
+ # Add assistant response to chat history
112
+ st.session_state.messages.append({"role": "assistant", "content": response})
constants.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Number of snippets that will be added to the prompt. Too many snippets and you risk both the prompt going over the
2
+ # token limit, and the model not being able to find the correct answer
3
+ prompt_number_snippets = 3
4
+
5
+ # GPT related constants
6
+ gpt_model_to_use = 'gpt-4'
7
+ gpt_max_tokens = 1000
8
+
9
+ # Number of past user messages that will be used to search relevant snippets
10
+ search_number_messages = 4
11
+
12
+ # PDF Chunking constants
13
+ chunk_size = 500
14
+ chunk_overlap = 50
15
+
16
+ # Number of snippets to be retrieved by FAISS
17
+ number_snippets_to_retrieve = 3
langchain_utils.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import FAISS
2
+ from langchain.chat_models import ChatOpenAI
3
+ from langchain.chains import ConversationChain
4
+ from langchain.memory import ConversationBufferWindowMemory, CombinedMemory
5
+ from langchain import PromptTemplate
6
+ from constants import prompt_number_snippets, gpt_model_to_use, gpt_max_tokens
7
+ from search_indexing import search_faiss_index
8
+
9
+
10
+ class SnippetsBufferWindowMemory(ConversationBufferWindowMemory):
11
+ """
12
+ MemoryBuffer used to hold the document snippets. Inherits from ConversationBufferWindowMemory, and overwrites the
13
+ load_memory_variables method
14
+ """
15
+
16
+ index: FAISS = None
17
+ pages: list = []
18
+ memory_key = 'snippets'
19
+ snippets: list = []
20
+
21
+ def __init__(self, *args, **kwargs):
22
+ ConversationBufferWindowMemory.__init__(self, *args, **kwargs)
23
+ self.index = kwargs['index']
24
+
25
+ def load_memory_variables(self, inputs) -> dict:
26
+ """
27
+ Based on the user inputs, search the index and add the similar snippets to memory (but only if they aren't in the
28
+ memory already)
29
+ """
30
+
31
+ # Search snippets
32
+ similar_snippets = search_faiss_index(self.index, inputs['user_messages_history'])
33
+ # In order to respect the buffer size and make its pruning work, need to reverse the list, and then un-reverse it later
34
+ # This way, the most relevant snippets are kept at the start of the list
35
+ self.snippets = [snippet for snippet in reversed(self.snippets)]
36
+ self.pages = [page for page in reversed(self.pages)]
37
+
38
+ for snippet in similar_snippets:
39
+ page_number = snippet.metadata['page']
40
+ # Load into memory only new snippets
41
+ snippet_to_add = f"The following snippet was extracted from the following document: "
42
+ if snippet.metadata['title'] == snippet.metadata['source']:
43
+ snippet_to_add += f"{snippet.metadata['source']}\n"
44
+ else:
45
+ snippet_to_add += f"[{snippet.metadata['title']}]({snippet.metadata['source']})\n"
46
+
47
+ snippet_to_add += f"<START_SNIPPET_PAGE_{page_number + 1}>\n"
48
+ snippet_to_add += f"{snippet.page_content}\n"
49
+ snippet_to_add += f"<END_SNIPPET_PAGE_{page_number + 1}>\n"
50
+ if snippet_to_add not in self.snippets:
51
+ self.pages.append(page_number)
52
+ self.snippets.append(snippet_to_add)
53
+
54
+ # Reverse list of snippets and pages, in order to keep the most relevant at the top
55
+ # Also prune the list to keep the buffer within the define size (k)
56
+ self.snippets = [snippet for snippet in reversed(self.snippets)][:self.k]
57
+ self.pages = [page for page in reversed(self.pages)][:self.k]
58
+ to_return = ''.join(self.snippets)
59
+
60
+ return {'snippets': to_return}
61
+
62
+
63
+ def construct_conversation(prompt: str, llm, memory) -> ConversationChain:
64
+ """
65
+ Construct a ConversationChain object
66
+ """
67
+
68
+ prompt = PromptTemplate.from_template(
69
+ template=prompt,
70
+ )
71
+
72
+ conversation = ConversationChain(
73
+ llm=llm,
74
+ memory=memory,
75
+ verbose=False,
76
+ prompt=prompt
77
+ )
78
+
79
+ return conversation
80
+
81
+
82
+ def initialize_chat_conversation(index: FAISS,
83
+ model_to_use: str = gpt_model_to_use,
84
+ max_tokens: int = gpt_max_tokens) -> ConversationChain:
85
+
86
+ prompt_header = """You are an expert, tasked with helping customers with their questions. They will ask you questions and provide technical snippets that may or may not contain the answer, and it's your job to find the answer if possible, while taking into account the entire conversation context.
87
+ The following snippets can be used to help you answer the questions:
88
+ {snippets}
89
+ The following is a friendly conversation between a customer and you. Please answer the customer's needs based on the provided snippets and the conversation history. Make sure to take the previous messages in consideration, as they contain additional context.
90
+ If the provided snippets don't include the answer, please say so, and don't try to make up an answer instead. Include in your reply the title of the document and the page from where your answer is coming from, if applicable.
91
+
92
+ {history}
93
+ Customer: {input}
94
+ """
95
+
96
+ llm = ChatOpenAI(model_name=model_to_use, max_tokens=max_tokens)
97
+ conv_memory = ConversationBufferWindowMemory(k=3, input_key="input")
98
+ snippets_memory = SnippetsBufferWindowMemory(k=prompt_number_snippets, index=index, memory_key='snippets', input_key="snippets")
99
+ memory = CombinedMemory(memories=[conv_memory, snippets_memory])
100
+
101
+ conversation = construct_conversation(prompt_header, llm, memory)
102
+
103
+ return conversation
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ faiss-cpu==1.7.4
2
+ langchain==0.0.248
3
+ openai==0.27.7
4
+ streamlit==1.25.0
5
+ pypdfium2==4.18.0
6
+ tiktoken==0.4.0
search_indexing.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import FAISS
2
+ from langchain.document_loaders import PyPDFium2Loader
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.text_splitter import CharacterTextSplitter
5
+ import pypdfium2 as pdfium
6
+ from constants import chunk_size, chunk_overlap, number_snippets_to_retrieve
7
+
8
+
9
+ def download_and_index_pdf(urls: list[str]) -> FAISS:
10
+ """
11
+ Download and index a list of PDFs based on the URLs
12
+ """
13
+
14
+ def __update_metadata(pages, url):
15
+ """
16
+ Add to the document metadata the title and original URL
17
+ """
18
+ for page in pages:
19
+ pdf = pdfium.PdfDocument(page.metadata['source'])
20
+ title = pdf.get_metadata_dict().get('Title', url)
21
+ page.metadata['source'] = url
22
+ page.metadata['title'] = title
23
+ return pages
24
+
25
+ all_pages = []
26
+ for url in urls:
27
+ loader = PyPDFium2Loader(url)
28
+ splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
29
+ pages = loader.load_and_split(splitter)
30
+ pages = __update_metadata(pages, url)
31
+ all_pages += pages
32
+
33
+ faiss_index = FAISS.from_documents(all_pages, OpenAIEmbeddings())
34
+
35
+ return faiss_index
36
+
37
+
38
+ def search_faiss_index(faiss_index: FAISS, query: str, top_k: int = number_snippets_to_retrieve) -> list:
39
+ """
40
+ Search a FAISS index, using the passed query
41
+ """
42
+
43
+ docs = faiss_index.similarity_search(query, k=top_k)
44
+
45
+ return docs