mohamedachraf commited on
Commit
8dc5c8f
·
1 Parent(s): d941be5

Add application file

Browse files
Files changed (2) hide show
  1. app.py +155 -29
  2. requirements.txt +4 -1
app.py CHANGED
@@ -5,10 +5,11 @@ nltk.download('punkt_tab')
5
 
6
  import gradio as gr
7
  from langchain.text_splitter import CharacterTextSplitter
8
- from langchain_community.document_loaders import UnstructuredFileLoader
9
  from langchain.vectorstores.faiss import FAISS
10
  from langchain.vectorstores.utils import DistanceStrategy
11
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
12
 
13
  from langchain.chains import RetrievalQA
14
  from langchain.prompts.prompt import PromptTemplate
@@ -20,6 +21,8 @@ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
20
 
21
  from transformers import TextIteratorStreamer
22
  from threading import Thread
 
 
23
 
24
 
25
  # Prompt template
@@ -33,7 +36,16 @@ If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up
33
  Question: {question}
34
  Output:\n"""
35
 
 
 
 
 
 
 
 
 
36
  QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
 
37
 
38
  # Load Phi-2 model from hugging face hub
39
  model_id = "microsoft/phi-2"
@@ -51,10 +63,14 @@ embeddings = HuggingFaceEmbeddings(
51
  )
52
 
53
 
54
- # Returns a faiss vector store retriever given a txt file
55
  def prepare_vector_store_retriever(filename):
56
- # Load data
57
- loader = UnstructuredFileLoader(filename)
 
 
 
 
58
  raw_documents = loader.load()
59
 
60
  # Split the text
@@ -69,25 +85,104 @@ def prepare_vector_store_retriever(filename):
69
  documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
70
  )
71
 
72
- return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2})
73
-
74
-
75
- # Retrieveal QA chian
76
- def get_retrieval_qa_chain(text_file, hf_model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  retriever = default_retriever
 
 
78
  if text_file != default_text_file:
79
- retriever = prepare_vector_store_retriever(text_file)
80
-
81
- chain = RetrievalQA.from_chain_type(
82
- llm=hf_model,
83
- retriever=retriever,
84
- chain_type_kwargs={"prompt": QA_PROMPT},
85
- )
86
- return chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  # Generates response using the question answering chain defined earlier
90
- def generate(question, answer, text_file, max_new_tokens):
91
  streamer = TextIteratorStreamer(
92
  tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
93
  )
@@ -102,56 +197,87 @@ def generate(question, answer, text_file, max_new_tokens):
102
  )
103
 
104
  hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
105
- qa_chain = get_retrieval_qa_chain(text_file, hf_model)
106
 
107
  query = f"{question}"
108
 
109
  if len(tokenizer.tokenize(query)) >= 512:
110
  query = "Repeat 'Your question is too long!'"
111
 
112
- thread = Thread(target=qa_chain.invoke, kwargs={"input": {"query": query}})
 
 
 
 
113
  thread.start()
114
 
115
  response = ""
116
  for token in streamer:
117
  response += token
118
  yield response.strip()
 
 
 
 
119
 
120
 
121
- # replaces the retreiver in the question answering chain whenever a new file is uploaded
122
  def upload_file(file):
123
- return file, file
 
 
 
 
 
 
124
 
125
 
126
  with gr.Blocks() as demo:
127
  gr.Markdown(
128
  """
129
  # Retrieval Augmented Generation with Phi-2: Question Answering demo
130
- ### This demo uses the Phi-2 language model and Retrieval Augmented Generation (RAG). It allows you to upload a txt file and ask the model questions related to the content of that file.
 
 
 
 
131
  ### If you don't have one, there is a txt file already loaded, the new Oppenheimer movie's entire wikipedia page. The movie came out very recently in July, 2023, so the Phi-2 model is not aware of it.
132
  The context size of the Phi-2 model is 2048 tokens, so even this medium-sized Wikipedia page (11.5k tokens) does not fit in the context window.
133
  Retrieval Augmented Generation (RAG) enables us to retrieve just the few small chunks of the document that are relevant to our query and inject them into our prompt.
134
- The model is then able to answer questions by incorporating knowledge from the newly provided document. RAG can be used with thousands of documents, but this demo is limited to just one txt file.
135
  """
136
  )
137
 
138
  default_text_file = "Oppenheimer-movie-wiki.txt"
139
- default_retriever = prepare_vector_store_retriever(default_text_file)
140
 
141
  text_file = gr.State(default_text_file)
142
 
143
  gr.Markdown(
144
- "## Upload a txt file or Use the Default 'Oppenheimer-movie-wiki.txt' that has already been loaded"
145
  )
146
 
147
  file_name = gr.Textbox(
148
- label="Loaded text file", value=default_text_file, lines=1, interactive=False
149
  )
150
  upload_button = gr.UploadButton(
151
- label="Click to upload a text file", file_types=["text"], file_count="single"
152
  )
153
  upload_button.upload(upload_file, upload_button, [file_name, text_file])
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  gr.Markdown("## Enter your question")
156
  tokens_slider = gr.Slider(
157
  8,
@@ -172,7 +298,7 @@ with gr.Blocks() as demo:
172
  with gr.Column():
173
  clear = gr.ClearButton([ques, ans])
174
 
175
- btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])
176
  examples = gr.Examples(
177
  examples=[
178
  "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
 
5
 
6
  import gradio as gr
7
  from langchain.text_splitter import CharacterTextSplitter
8
+ from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
9
  from langchain.vectorstores.faiss import FAISS
10
  from langchain.vectorstores.utils import DistanceStrategy
11
  from langchain_community.embeddings import HuggingFaceEmbeddings
12
+ from langchain.schema import Document
13
 
14
  from langchain.chains import RetrievalQA
15
  from langchain.prompts.prompt import PromptTemplate
 
21
 
22
  from transformers import TextIteratorStreamer
23
  from threading import Thread
24
+ import os
25
+ import tempfile
26
 
27
 
28
  # Prompt template
 
36
  Question: {question}
37
  Output:\n"""
38
 
39
+ # Multi-query generation prompt
40
+ multi_query_template = """You are an AI language model assistant. Your task is to generate 3
41
+ different versions of the given user question to retrieve relevant documents from a vector
42
+ database. By generating multiple perspectives on the user question, your goal is to help
43
+ the user overcome some of the limitations of the distance-based similarity search.
44
+ Provide these alternative questions separated by newlines.
45
+ Original question: {question}"""
46
+
47
  QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
48
+ MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variables=["question"])
49
 
50
  # Load Phi-2 model from hugging face hub
51
  model_id = "microsoft/phi-2"
 
63
  )
64
 
65
 
66
+ # Returns a faiss vector store retriever given a txt or pdf file
67
  def prepare_vector_store_retriever(filename):
68
+ # Load data based on file extension
69
+ if filename.lower().endswith('.pdf'):
70
+ loader = PyPDFLoader(filename)
71
+ else:
72
+ loader = UnstructuredFileLoader(filename)
73
+
74
  raw_documents = loader.load()
75
 
76
  # Split the text
 
85
  documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
86
  )
87
 
88
+ return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore
89
+
90
+
91
+ # Generate multiple queries for better retrieval
92
+ def generate_multiple_queries(question, hf_model):
93
+ """Generate multiple variations of the question for better retrieval"""
94
+ try:
95
+ result = hf_model.invoke(MULTI_QUERY_PROMPT.format(question=question))
96
+ queries = [q.strip() for q in result.split('\n') if q.strip()]
97
+ # Always include the original question
98
+ if question not in queries:
99
+ queries.insert(0, question)
100
+ return queries[:4] # Limit to 4 queries max
101
+ except:
102
+ # Fallback to original question if generation fails
103
+ return [question]
104
+
105
+ # Multi-query retrieval function
106
+ def multi_query_retrieve(queries, retriever):
107
+ """Retrieve documents using multiple queries and combine results"""
108
+ all_docs = []
109
+ seen_content = set()
110
+
111
+ for query in queries:
112
+ try:
113
+ docs = retriever.get_relevant_documents(query)
114
+ for doc in docs:
115
+ if doc.page_content not in seen_content:
116
+ all_docs.append(doc)
117
+ seen_content.add(doc.page_content)
118
+ except:
119
+ continue
120
+
121
+ return all_docs[:6] # Limit to top 6 unique documents
122
+
123
+ # Store Q&A pairs in vector database
124
+ def store_qa_pair(question, answer, vectorstore):
125
+ """Store the question-answer pair as a new document in the vector database"""
126
+ try:
127
+ qa_content = f"Question: {question}\nAnswer: {answer}"
128
+ qa_doc = Document(page_content=qa_content, metadata={"type": "qa_pair"})
129
+
130
+ # Add the Q&A pair to the existing vectorstore
131
+ vectorstore.add_documents([qa_doc])
132
+ return True
133
+ except Exception as e:
134
+ print(f"Error storing Q&A pair: {e}")
135
+ return False
136
+ # Retrieval QA chain with multi-query support
137
+ def get_retrieval_qa_chain(text_file, hf_model, use_multi_query=False):
138
  retriever = default_retriever
139
+ vectorstore = default_vectorstore
140
+
141
  if text_file != default_text_file:
142
+ retriever, vectorstore = prepare_vector_store_retriever(text_file)
143
+
144
+ if use_multi_query:
145
+ # Custom retrieval function for multi-query
146
+ class MultiQueryRetriever:
147
+ def __init__(self, retriever, vectorstore, hf_model):
148
+ self.retriever = retriever
149
+ self.vectorstore = vectorstore
150
+ self.hf_model = hf_model
151
+
152
+ def get_relevant_documents(self, query):
153
+ # Generate multiple queries
154
+ queries = generate_multiple_queries(query, self.hf_model)
155
+ # Retrieve documents using all queries
156
+ return multi_query_retrieve(queries, self.retriever)
157
+
158
+ multi_retriever = MultiQueryRetriever(retriever, vectorstore, hf_model)
159
+
160
+ # Custom chain that uses multi-query retrieval
161
+ class MultiQueryRetrievalQA:
162
+ def __init__(self, llm, retriever, prompt):
163
+ self.llm = llm
164
+ self.retriever = retriever
165
+ self.prompt = prompt
166
+
167
+ def invoke(self, input_dict):
168
+ query = input_dict["query"]
169
+ docs = self.retriever.get_relevant_documents(query)
170
+ context = "\n\n".join([doc.page_content for doc in docs])
171
+ prompt_text = self.prompt.format(context=context, question=query)
172
+ return self.llm.invoke(prompt_text)
173
+
174
+ return MultiQueryRetrievalQA(hf_model, multi_retriever, QA_PROMPT), vectorstore
175
+ else:
176
+ chain = RetrievalQA.from_chain_type(
177
+ llm=hf_model,
178
+ retriever=retriever,
179
+ chain_type_kwargs={"prompt": QA_PROMPT},
180
+ )
181
+ return chain, vectorstore
182
 
183
 
184
  # Generates response using the question answering chain defined earlier
185
+ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store_qa):
186
  streamer = TextIteratorStreamer(
187
  tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
188
  )
 
197
  )
198
 
199
  hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
200
+ qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model, use_multi_query)
201
 
202
  query = f"{question}"
203
 
204
  if len(tokenizer.tokenize(query)) >= 512:
205
  query = "Repeat 'Your question is too long!'"
206
 
207
+ def run_chain():
208
+ result = qa_chain.invoke({"input": {"query": query}} if hasattr(qa_chain, 'retriever') else {"query": query})
209
+ return result
210
+
211
+ thread = Thread(target=run_chain)
212
  thread.start()
213
 
214
  response = ""
215
  for token in streamer:
216
  response += token
217
  yield response.strip()
218
+
219
+ # Store Q&A pair if requested
220
+ if store_qa and response.strip() and "Your question is too long!" not in response:
221
+ store_qa_pair(question, response.strip(), vectorstore)
222
 
223
 
224
+ # replaces the retriever in the question answering chain whenever a new file is uploaded
225
  def upload_file(file):
226
+ if file is not None:
227
+ # Save uploaded file to temporary location
228
+ temp_path = os.path.join(tempfile.gettempdir(), file.name)
229
+ with open(temp_path, 'wb') as f:
230
+ f.write(file.read())
231
+ return file.name, temp_path
232
+ return None, None
233
 
234
 
235
  with gr.Blocks() as demo:
236
  gr.Markdown(
237
  """
238
  # Retrieval Augmented Generation with Phi-2: Question Answering demo
239
+ ### This demo uses the Phi-2 language model and Retrieval Augmented Generation (RAG). It allows you to upload a txt or PDF file and ask the model questions related to the content of that file.
240
+ ### Features:
241
+ - Support for both PDF and text files
242
+ - Multi-query RAG for improved retrieval
243
+ - Store Q&A pairs in vector database for future reference
244
  ### If you don't have one, there is a txt file already loaded, the new Oppenheimer movie's entire wikipedia page. The movie came out very recently in July, 2023, so the Phi-2 model is not aware of it.
245
  The context size of the Phi-2 model is 2048 tokens, so even this medium-sized Wikipedia page (11.5k tokens) does not fit in the context window.
246
  Retrieval Augmented Generation (RAG) enables us to retrieve just the few small chunks of the document that are relevant to our query and inject them into our prompt.
247
+ The model is then able to answer questions by incorporating knowledge from the newly provided document. RAG can be used with thousands of documents, but this demo is limited to just one file at a time.
248
  """
249
  )
250
 
251
  default_text_file = "Oppenheimer-movie-wiki.txt"
252
+ default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
253
 
254
  text_file = gr.State(default_text_file)
255
 
256
  gr.Markdown(
257
+ "## Upload a txt or PDF file or Use the Default 'Oppenheimer-movie-wiki.txt' that has already been loaded"
258
  )
259
 
260
  file_name = gr.Textbox(
261
+ label="Loaded file", value=default_text_file, lines=1, interactive=False
262
  )
263
  upload_button = gr.UploadButton(
264
+ label="Click to upload a text or PDF file", file_types=[".txt", ".pdf"], file_count="single"
265
  )
266
  upload_button.upload(upload_file, upload_button, [file_name, text_file])
267
 
268
+ gr.Markdown("## RAG Settings")
269
+ with gr.Row():
270
+ use_multi_query = gr.Checkbox(
271
+ label="Use Multi-Query RAG",
272
+ value=False,
273
+ info="Generate multiple query variations for better retrieval"
274
+ )
275
+ store_qa = gr.Checkbox(
276
+ label="Store Q&A pairs",
277
+ value=True,
278
+ info="Add question-answer pairs to vector database"
279
+ )
280
+
281
  gr.Markdown("## Enter your question")
282
  tokens_slider = gr.Slider(
283
  8,
 
298
  with gr.Column():
299
  clear = gr.ClearButton([ques, ans])
300
 
301
+ btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider, use_multi_query, store_qa], outputs=[ans])
302
  examples = gr.Examples(
303
  examples=[
304
  "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
requirements.txt CHANGED
@@ -9,4 +9,7 @@ langchain-community==0.0.13
9
  unstructured==0.12.2
10
  huggingface_hub>=0.20.0
11
  gradio
12
- nltk
 
 
 
 
9
  unstructured==0.12.2
10
  huggingface_hub>=0.20.0
11
  gradio
12
+ nltk
13
+ pypdf2
14
+ pdfplumber
15
+ python-multipart