rbbist committed on
Commit
29058d8
·
verified ·
1 Parent(s): 548663f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -123
app.py CHANGED
@@ -2,8 +2,7 @@ import streamlit as st
2
  import os
3
  import tempfile
4
  from typing import List, Optional
5
- from pathlib import Path
6
- import time
7
 
8
  # Core libraries
9
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
@@ -13,8 +12,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
  from langchain.schema import Document
14
  from langchain import PromptTemplate
15
  from langchain.chains import RetrievalQA
16
- from langchain_pinecone import PineconeVectorStore
17
- from pinecone import Pinecone as PineconeClient
18
 
19
  # Document loaders
20
  from langchain.document_loaders import PyPDFLoader
@@ -81,54 +79,39 @@ if 'chat_history' not in st.session_state:
81
  def setup_llm(model_name="google/flan-t5-small"):
82
  """Setup the language model for text generation"""
83
  with st.spinner("πŸ€– Loading language model..."):
84
- tokenizer = AutoTokenizer.from_pretrained(model_name)
85
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
86
-
87
- pipe = pipeline(
88
- "text2text-generation",
89
- model=model,
90
- tokenizer=tokenizer,
91
- max_new_tokens=300,
92
- temperature=0.3,
93
- do_sample=True
94
- )
95
-
96
- llm = HuggingFacePipeline(pipeline=pipe)
97
- return llm
 
 
 
 
 
98
 
99
  @st.cache_resource
100
  def setup_embeddings(model_name="all-MiniLM-L6-v2"):
101
  """Setup the embedding model for vector generation"""
102
  with st.spinner("πŸ”’ Loading embedding model..."):
103
- embeddings = HuggingFaceEmbeddings(model_name=model_name)
104
- return embeddings
105
-
106
- def setup_pinecone(api_key, environment="us-east-1", index_name="pdf-rag-system"):
107
- """Setup Pinecone vector database connection"""
108
- try:
109
- os.environ["PINECONE_API_KEY"] = api_key
110
- os.environ["PINECONE_ENVIRONMENT"] = environment
111
-
112
- pc = PineconeClient(api_key=api_key, environment=environment)
113
-
114
- existing_indexes = pc.list_indexes()
115
-
116
- if index_name not in [idx.name for idx in existing_indexes]:
117
- st.info(f"πŸ“ Creating new index: {index_name}")
118
- pc.create_index(
119
- name=index_name,
120
- dimension=384,
121
- metric='cosine'
122
- )
123
- time.sleep(30) # Wait for index to be ready
124
-
125
- return pc, index_name
126
- except Exception as e:
127
- st.error(f"❌ Error setting up Pinecone: {e}")
128
- return None, None
129
 
130
- def process_uploaded_files(uploaded_files, embeddings, pc, index_name):
131
- """Process uploaded PDF files and store in vector database"""
132
  if not uploaded_files:
133
  return None, []
134
 
@@ -145,6 +128,11 @@ def process_uploaded_files(uploaded_files, embeddings, pc, index_name):
145
  # Load PDF
146
  loader = PyPDFLoader(tmp_file_path)
147
  docs = loader.load()
 
 
 
 
 
148
  documents.extend(docs)
149
 
150
  # Clean up temporary file
@@ -172,28 +160,23 @@ def process_uploaded_files(uploaded_files, embeddings, pc, index_name):
172
  for i, text in enumerate(text_chunks):
173
  text.metadata.update({
174
  "chunk_id": i,
175
- "source_file": text.metadata.get("source", "unknown"),
176
  "chunk_size": len(text.page_content)
177
  })
178
 
179
  st.info(f"βœ‚οΈ Created {len(text_chunks)} text chunks")
180
 
181
- # Store in Pinecone
182
  try:
183
- vectorstore = PineconeVectorStore.from_documents(
184
- documents=text_chunks,
185
- embedding=embeddings,
186
- index_name=index_name
187
- )
188
- st.success(f"βœ… Successfully stored {len(text_chunks)} chunks in vector database!")
189
  return vectorstore, text_chunks
190
  except Exception as e:
191
- st.error(f"❌ Error storing in vector database: {e}")
192
  return None, []
193
 
194
  def create_qa_chain(llm, vectorstore, k=5):
195
  """Create a question-answering chain with retrieval"""
196
- if not vectorstore:
197
  return None
198
 
199
  prompt_template = """Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."
@@ -209,16 +192,18 @@ Answer: Let me analyze the provided context to answer your question."""
209
  input_variables=["context", "question"]
210
  )
211
 
212
- qa_chain = RetrievalQA.from_chain_type(
213
- llm=llm,
214
- chain_type="stuff",
215
- retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
216
- chain_type_kwargs={"prompt": PROMPT},
217
- return_source_documents=True,
218
- verbose=True
219
- )
220
-
221
- return qa_chain
 
 
222
 
223
  def ask_question(qa_chain, question):
224
  """Ask a question and get an answer with sources"""
@@ -240,6 +225,18 @@ def ask_question(qa_chain, question):
240
  st.error(f"❌ Error processing question: {e}")
241
  return None
242
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Main App Interface
244
  def main():
245
  st.markdown('<h1 class="main-header">πŸ“š PDF RAG System</h1>', unsafe_allow_html=True)
@@ -249,21 +246,6 @@ def main():
249
  with st.sidebar:
250
  st.markdown('<h2 class="sidebar-header">βš™οΈ Configuration</h2>', unsafe_allow_html=True)
251
 
252
- # Pinecone configuration
253
- st.subheader("🌲 Pinecone Settings")
254
- pinecone_api_key = st.text_input(
255
- "Pinecone API Key",
256
- type="password",
257
- help="Enter your Pinecone API key",
258
- value=st.secrets.get("PINECONE_API_KEY", "") if "PINECONE_API_KEY" in st.secrets else ""
259
- )
260
-
261
- index_name = st.text_input(
262
- "Index Name",
263
- value="pdf-rag-system",
264
- help="Name for your Pinecone index"
265
- )
266
-
267
  # Model configuration
268
  st.subheader("πŸ€– Model Settings")
269
  llm_model = st.selectbox(
@@ -274,7 +256,7 @@ def main():
274
 
275
  embedding_model = st.selectbox(
276
  "Embedding Model",
277
- ["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
278
  help="Choose the embedding model"
279
  )
280
 
@@ -285,6 +267,19 @@ def main():
285
  value=5,
286
  help="How many relevant chunks to use for answering questions"
287
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  # Main content area
290
  col1, col2 = st.columns([1, 1])
@@ -301,34 +296,32 @@ def main():
301
  if st.button("πŸš€ Process Documents", type="primary"):
302
  if not uploaded_files:
303
  st.warning("Please upload at least one PDF file.")
304
- elif not pinecone_api_key:
305
- st.warning("Please enter your Pinecone API key.")
306
  else:
307
  with st.spinner("Processing documents..."):
308
  # Setup models
309
  llm = setup_llm(llm_model)
310
  embeddings = setup_embeddings(embedding_model)
311
 
312
- # Setup Pinecone
313
- pc, idx_name = setup_pinecone(pinecone_api_key, index_name=index_name)
314
-
315
- if pc:
316
  # Process files
317
- vectorstore, text_chunks = process_uploaded_files(
318
- uploaded_files, embeddings, pc, idx_name
319
- )
320
 
321
  if vectorstore:
322
  # Create QA chain
323
  qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
324
 
325
- # Store in session state
326
- st.session_state.qa_chain = qa_chain
327
- st.session_state.vectorstore = vectorstore
328
- st.session_state.documents_processed = True
329
-
330
- st.balloons()
331
- st.success("πŸŽ‰ Documents processed successfully! You can now ask questions.")
 
 
 
 
 
332
 
333
  with col2:
334
  st.subheader("πŸ’¬ Ask Questions")
@@ -340,31 +333,46 @@ def main():
340
  help="Ask any question about your uploaded documents"
341
  )
342
 
343
- if st.button("πŸ” Get Answer"):
344
- if question:
345
- with st.spinner("Searching for answer..."):
346
- result = ask_question(st.session_state.qa_chain, question)
347
-
348
- if result:
349
- # Add to chat history
350
- st.session_state.chat_history.append({
351
- "question": question,
352
- "answer": result["answer"],
353
- "sources": result["source_documents"]
354
- })
355
 
356
- # Display answer
357
- st.subheader("πŸ’‘ Answer:")
358
- st.write(result["answer"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
- # Display sources
361
- if result["source_documents"]:
362
- st.subheader("πŸ“š Sources:")
363
- for i, doc in enumerate(result["source_documents"][:3]):
364
- with st.expander(f"Source {i+1}: {doc.metadata.get('source', 'Unknown')}"):
365
- st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
366
- else:
367
- st.warning("Please enter a question.")
368
  else:
369
  st.info("πŸ‘† Please upload and process documents first to start asking questions.")
370
 
@@ -380,7 +388,7 @@ def main():
380
  if chat['sources']:
381
  st.write("**Sources:**")
382
  for j, doc in enumerate(chat['sources'][:2]): # Show top 2 sources
383
- st.write(f"{j+1}. {doc.metadata.get('source', 'Unknown')}")
384
 
385
  # Clear session button
386
  if st.session_state.documents_processed:
@@ -390,7 +398,7 @@ def main():
390
  st.session_state.documents_processed = False
391
  st.session_state.chat_history = []
392
  st.success("Session cleared! You can upload new documents.")
393
- st.experimental_rerun()
394
 
395
  if __name__ == "__main__":
396
  main()
 
2
  import os
3
  import tempfile
4
  from typing import List, Optional
5
+ import pickle
 
6
 
7
  # Core libraries
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 
12
  from langchain.schema import Document
13
  from langchain import PromptTemplate
14
  from langchain.chains import RetrievalQA
15
+ from langchain.vectorstores import FAISS
 
16
 
17
  # Document loaders
18
  from langchain.document_loaders import PyPDFLoader
 
79
  def setup_llm(model_name="google/flan-t5-small"):
80
  """Setup the language model for text generation"""
81
  with st.spinner("πŸ€– Loading language model..."):
82
+ try:
83
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
84
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
85
+
86
+ pipe = pipeline(
87
+ "text2text-generation",
88
+ model=model,
89
+ tokenizer=tokenizer,
90
+ max_new_tokens=300,
91
+ temperature=0.3,
92
+ do_sample=True,
93
+ device_map="auto" if st.secrets.get("DEVICE", "cpu") == "gpu" else None
94
+ )
95
+
96
+ llm = HuggingFacePipeline(pipeline=pipe)
97
+ return llm
98
+ except Exception as e:
99
+ st.error(f"Error loading model: {e}")
100
+ return None
101
 
102
@st.cache_resource
def setup_embeddings(model_name="all-MiniLM-L6-v2"):
    """Build the sentence-embedding model used to vectorize document chunks.

    Cached via st.cache_resource so the model loads once per session;
    returns None (after reporting via st.error) when loading fails.
    """
    with st.spinner("🔢 Loading embedding model..."):
        try:
            return HuggingFaceEmbeddings(model_name=model_name)
        except Exception as e:
            st.error(f"Error loading embeddings: {e}")
            return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ def process_uploaded_files(uploaded_files, embeddings):
114
+ """Process uploaded PDF files and create FAISS vector store"""
115
  if not uploaded_files:
116
  return None, []
117
 
 
128
  # Load PDF
129
  loader = PyPDFLoader(tmp_file_path)
130
  docs = loader.load()
131
+
132
+ # Add file name to metadata
133
+ for doc in docs:
134
+ doc.metadata['source_file'] = uploaded_file.name
135
+
136
  documents.extend(docs)
137
 
138
  # Clean up temporary file
 
160
  for i, text in enumerate(text_chunks):
161
  text.metadata.update({
162
  "chunk_id": i,
 
163
  "chunk_size": len(text.page_content)
164
  })
165
 
166
  st.info(f"βœ‚οΈ Created {len(text_chunks)} text chunks")
167
 
168
+ # Create FAISS vector store
169
  try:
170
+ vectorstore = FAISS.from_documents(text_chunks, embeddings)
171
+ st.success(f"βœ… Successfully created vector database with {len(text_chunks)} chunks!")
 
 
 
 
172
  return vectorstore, text_chunks
173
  except Exception as e:
174
+ st.error(f"❌ Error creating vector database: {e}")
175
  return None, []
176
 
177
  def create_qa_chain(llm, vectorstore, k=5):
178
  """Create a question-answering chain with retrieval"""
179
+ if not vectorstore or not llm:
180
  return None
181
 
182
  prompt_template = """Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."
 
192
  input_variables=["context", "question"]
193
  )
194
 
195
+ try:
196
+ qa_chain = RetrievalQA.from_chain_type(
197
+ llm=llm,
198
+ chain_type="stuff",
199
+ retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
200
+ chain_type_kwargs={"prompt": PROMPT},
201
+ return_source_documents=True
202
+ )
203
+ return qa_chain
204
+ except Exception as e:
205
+ st.error(f"Error creating QA chain: {e}")
206
+ return None
207
 
208
  def ask_question(qa_chain, question):
209
  """Ask a question and get an answer with sources"""
 
225
  st.error(f"❌ Error processing question: {e}")
226
  return None
227
 
228
def search_similar_chunks(vectorstore, query, k=5):
    """Return up to k chunks most similar to the query, with no LLM answer.

    Falls back to an empty list when no vector store is available or the
    similarity search raises (the error is surfaced via st.error).
    """
    if not vectorstore:
        return []

    try:
        return vectorstore.similarity_search(query, k=k)
    except Exception as e:
        st.error(f"Error searching: {e}")
        return []
239
+
240
  # Main App Interface
241
  def main():
242
  st.markdown('<h1 class="main-header">πŸ“š PDF RAG System</h1>', unsafe_allow_html=True)
 
246
  with st.sidebar:
247
  st.markdown('<h2 class="sidebar-header">βš™οΈ Configuration</h2>', unsafe_allow_html=True)
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  # Model configuration
250
  st.subheader("πŸ€– Model Settings")
251
  llm_model = st.selectbox(
 
256
 
257
  embedding_model = st.selectbox(
258
  "Embedding Model",
259
+ ["all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"],
260
  help="Choose the embedding model"
261
  )
262
 
 
267
  value=5,
268
  help="How many relevant chunks to use for answering questions"
269
  )
270
+
271
+ st.subheader("πŸ’Ύ Vector Store")
272
+ st.info("Using FAISS (local vector storage)")
273
+
274
+ # Option to save/load vector store
275
+ if st.session_state.vectorstore:
276
+ if st.button("πŸ’Ύ Save Vector Store"):
277
+ try:
278
+ # Save vector store to session state or file
279
+ st.session_state.vectorstore.save_local("faiss_index")
280
+ st.success("Vector store saved!")
281
+ except Exception as e:
282
+ st.error(f"Error saving: {e}")
283
 
284
  # Main content area
285
  col1, col2 = st.columns([1, 1])
 
296
  if st.button("πŸš€ Process Documents", type="primary"):
297
  if not uploaded_files:
298
  st.warning("Please upload at least one PDF file.")
 
 
299
  else:
300
  with st.spinner("Processing documents..."):
301
  # Setup models
302
  llm = setup_llm(llm_model)
303
  embeddings = setup_embeddings(embedding_model)
304
 
305
+ if llm and embeddings:
 
 
 
306
  # Process files
307
+ vectorstore, text_chunks = process_uploaded_files(uploaded_files, embeddings)
 
 
308
 
309
  if vectorstore:
310
  # Create QA chain
311
  qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
312
 
313
+ if qa_chain:
314
+ # Store in session state
315
+ st.session_state.qa_chain = qa_chain
316
+ st.session_state.vectorstore = vectorstore
317
+ st.session_state.documents_processed = True
318
+
319
+ st.balloons()
320
+ st.success("πŸŽ‰ Documents processed successfully! You can now ask questions.")
321
+ else:
322
+ st.error("Failed to create QA chain.")
323
+ else:
324
+ st.error("Failed to load models.")
325
 
326
  with col2:
327
  st.subheader("πŸ’¬ Ask Questions")
 
333
  help="Ask any question about your uploaded documents"
334
  )
335
 
336
+ col2a, col2b = st.columns([1, 1])
337
+
338
+ with col2a:
339
+ if st.button("πŸ” Get Answer"):
340
+ if question:
341
+ with st.spinner("Searching for answer..."):
342
+ result = ask_question(st.session_state.qa_chain, question)
 
 
 
 
 
343
 
344
+ if result:
345
+ # Add to chat history
346
+ st.session_state.chat_history.append({
347
+ "question": question,
348
+ "answer": result["answer"],
349
+ "sources": result["source_documents"]
350
+ })
351
+
352
+ # Display answer
353
+ st.subheader("πŸ’‘ Answer:")
354
+ st.write(result["answer"])
355
+
356
+ # Display sources
357
+ if result["source_documents"]:
358
+ st.subheader("πŸ“š Sources:")
359
+ for i, doc in enumerate(result["source_documents"][:3]):
360
+ with st.expander(f"Source {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
361
+ st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
362
+ else:
363
+ st.warning("Please enter a question.")
364
+
365
+ with col2b:
366
+ if st.button("πŸ” Search Similar"):
367
+ if question:
368
+ with st.spinner("Searching for similar content..."):
369
+ results = search_similar_chunks(st.session_state.vectorstore, question, k=5)
370
 
371
+ if results:
372
+ st.subheader("πŸ” Similar Content:")
373
+ for i, doc in enumerate(results):
374
+ with st.expander(f"Match {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
375
+ st.write(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
 
 
 
376
  else:
377
  st.info("πŸ‘† Please upload and process documents first to start asking questions.")
378
 
 
388
  if chat['sources']:
389
  st.write("**Sources:**")
390
  for j, doc in enumerate(chat['sources'][:2]): # Show top 2 sources
391
+ st.write(f"{j+1}. {doc.metadata.get('source_file', 'Unknown')}")
392
 
393
  # Clear session button
394
  if st.session_state.documents_processed:
 
398
  st.session_state.documents_processed = False
399
  st.session_state.chat_history = []
400
  st.success("Session cleared! You can upload new documents.")
401
+ st.rerun()
402
 
403
  if __name__ == "__main__":
404
  main()