rbbist committed
Commit 0446596 · verified · 1 Parent(s): b1e9890

Update app.py

Files changed (1): app.py (+389 -40)
app.py CHANGED
@@ -1,47 +1,396 @@
 import streamlit as st
-import PyPDF2
-from langchain.embeddings import SentenceTransformerEmbeddings
-from langchain.vectorstores import FAISS
+import os
+import tempfile
+from typing import List, Optional
+from pathlib import Path
+import time
+
+# Core libraries
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+from langchain.llms import HuggingFacePipeline
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.schema import Document
+from langchain import PromptTemplate
 from langchain.chains import RetrievalQA
-from langchain.llms import HuggingFacePipeline
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from langchain_pinecone import PineconeVectorStore
+from pinecone import Pinecone as PineconeClient, ServerlessSpec  # ServerlessSpec added: create_index() in the v3 client requires an index spec
+
+# Document loaders
+from langchain.document_loaders import PyPDFLoader
+
+# Configure Streamlit page
+st.set_page_config(
+    page_title="PDF RAG System",
+    page_icon="📚",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
 
-st.set_page_config(page_title="PDF QA App (CPU)", layout="wide")
-st.title("📘 Ask Questions from Uploaded PDFs (Free & CPU Friendly)")
+# Custom CSS for better styling
+st.markdown("""
+<style>
+.main-header {
+    font-size: 2.5rem;
+    color: #1f77b4;
+    text-align: center;
+    margin-bottom: 2rem;
+}
+.sidebar-header {
+    font-size: 1.5rem;
+    color: #ff7f0e;
+    margin-bottom: 1rem;
+}
+.success-message {
+    padding: 1rem;
+    background-color: #d4edda;
+    border: 1px solid #c3e6cb;
+    border-radius: 0.5rem;
+    color: #155724;
+    margin: 1rem 0;
+}
+.error-message {
+    padding: 1rem;
+    background-color: #f8d7da;
+    border: 1px solid #f5c6cb;
+    border-radius: 0.5rem;
+    color: #721c24;
+    margin: 1rem 0;
+}
+.source-box {
+    background-color: #f8f9fa;
+    border-left: 4px solid #007bff;
+    padding: 1rem;
+    margin: 0.5rem 0;
+    border-radius: 0 0.5rem 0.5rem 0;
+}
+</style>
+""", unsafe_allow_html=True)
 
-uploaded_files = st.file_uploader("Upload multiple PDF files", type=["pdf"], accept_multiple_files=True)
+# Initialize session state
+if 'qa_chain' not in st.session_state:
+    st.session_state.qa_chain = None
+if 'vectorstore' not in st.session_state:
+    st.session_state.vectorstore = None
+if 'documents_processed' not in st.session_state:
+    st.session_state.documents_processed = False
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []
 
 @st.cache_resource
-def load_llm():
-    model_id = "google/flan-t5-base"
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
-    return HuggingFacePipeline(pipeline=pipe)
-
-if uploaded_files:
-    st.info("Reading and processing PDFs...")
-    all_text = ""
-    for file in uploaded_files:
-        reader = PyPDF2.PdfReader(file)
-        for page in reader.pages:
-            text = page.extract_text()
-            if text:
-                all_text += text
-
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-    texts = text_splitter.split_text(all_text)
-
-    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-    db = FAISS.from_texts(texts, embeddings)
-
-    retriever = db.as_retriever()
-    llm = load_llm()
-    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
-
-    question = st.text_input("Ask a question based on the uploaded PDFs:")
-    if question:
-        with st.spinner("Generating answer..."):
-            answer = qa_chain.run(question)
-            st.success(answer)
+def setup_llm(model_name="google/flan-t5-small"):
+    """Setup the language model for text generation"""
+    with st.spinner("🤖 Loading language model..."):
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+        pipe = pipeline(
+            "text2text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=300,
+            temperature=0.3,
+            do_sample=True
+        )
+
+        llm = HuggingFacePipeline(pipeline=pipe)
+        return llm
+
+@st.cache_resource
+def setup_embeddings(model_name="all-MiniLM-L6-v2"):
+    """Setup the embedding model for vector generation"""
+    with st.spinner("🔒 Loading embedding model..."):
+        embeddings = HuggingFaceEmbeddings(model_name=model_name)
+        return embeddings
+
+def setup_pinecone(api_key, environment="us-east-1", index_name="pdf-rag-system"):
+    """Setup Pinecone vector database connection"""
+    try:
+        os.environ["PINECONE_API_KEY"] = api_key
+        os.environ["PINECONE_ENVIRONMENT"] = environment
+
+        pc = PineconeClient(api_key=api_key, environment=environment)
+
+        existing_indexes = pc.list_indexes()
+
+        if index_name not in [idx.name for idx in existing_indexes]:
+            st.info(f"📝 Creating new index: {index_name}")
+            pc.create_index(
+                name=index_name,
+                dimension=384,  # matches all-MiniLM-L6-v2; all-mpnet-base-v2 would need 768
+                metric='cosine',
+                spec=ServerlessSpec(cloud="aws", region=environment)  # spec is required by the v3 client; serverless-on-AWS is an assumption
+            )
+            time.sleep(30)  # Wait for index to be ready
+
+        return pc, index_name
+    except Exception as e:
+        st.error(f"❌ Error setting up Pinecone: {e}")
+        return None, None
+
+def process_uploaded_files(uploaded_files, embeddings, pc, index_name):
+    """Process uploaded PDF files and store in vector database"""
+    if not uploaded_files:
+        return None, []
+
+    documents = []
+
+    # Process each uploaded file
+    for uploaded_file in uploaded_files:
+        try:
+            # Create temporary file
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
+                tmp_file.write(uploaded_file.read())
+                tmp_file_path = tmp_file.name
+
+            # Load PDF
+            loader = PyPDFLoader(tmp_file_path)
+            docs = loader.load()
+            documents.extend(docs)
+
+            # Clean up temporary file
+            os.unlink(tmp_file_path)
+
+            st.success(f"✅ Processed: {uploaded_file.name} ({len(docs)} pages)")
+
+        except Exception as e:
+            st.error(f"❌ Error processing {uploaded_file.name}: {e}")
+
+    if not documents:
+        return None, []
+
+    # Split documents into chunks
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+        length_function=len,
+        separators=["\n\n", "\n", " ", ""]
+    )
+
+    text_chunks = text_splitter.split_documents(documents)
+
+    # Add metadata to chunks
+    for i, text in enumerate(text_chunks):
+        text.metadata.update({
+            "chunk_id": i,
+            "source_file": text.metadata.get("source", "unknown"),
+            "chunk_size": len(text.page_content)
+        })
+
+    st.info(f"✂️ Created {len(text_chunks)} text chunks")
+
+    # Store in Pinecone
+    try:
+        vectorstore = PineconeVectorStore.from_documents(
+            documents=text_chunks,
+            embedding=embeddings,
+            index_name=index_name
+        )
+        st.success(f"✅ Successfully stored {len(text_chunks)} chunks in vector database!")
+        return vectorstore, text_chunks
+    except Exception as e:
+        st.error(f"❌ Error storing in vector database: {e}")
+        return None, []
+
+def create_qa_chain(llm, vectorstore, k=5):
+    """Create a question-answering chain with retrieval"""
+    if not vectorstore:
+        return None
+
+    prompt_template = """Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."
+
+Context: {context}
+
+Question: {question}
+
+Answer: Let me analyze the provided context to answer your question."""
+
+    PROMPT = PromptTemplate(
+        template=prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    qa_chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
+        chain_type_kwargs={"prompt": PROMPT},
+        return_source_documents=True,
+        verbose=True
+    )
+
+    return qa_chain
+
+def ask_question(qa_chain, question):
+    """Ask a question and get an answer with sources"""
+    if not qa_chain:
+        return None
+
+    try:
+        result = qa_chain({"query": question})
+
+        response = {
+            "question": question,
+            "answer": result["result"],
+            "source_documents": result.get("source_documents", [])
+        }
+
+        return response
+
+    except Exception as e:
+        st.error(f"❌ Error processing question: {e}")
+        return None
+
+# Main App Interface
+def main():
+    st.markdown('<h1 class="main-header">📚 PDF RAG System</h1>', unsafe_allow_html=True)
+    st.markdown("Upload PDF documents and ask questions about their content using AI-powered retrieval!")
+
+    # Sidebar for configuration
+    with st.sidebar:
+        st.markdown('<h2 class="sidebar-header">⚙️ Configuration</h2>', unsafe_allow_html=True)
+
+        # Pinecone configuration
+        st.subheader("🌲 Pinecone Settings")
+        pinecone_api_key = st.text_input(
+            "Pinecone API Key",
+            type="password",
+            help="Enter your Pinecone API key",
+            value=st.secrets.get("PINECONE_API_KEY", "") if "PINECONE_API_KEY" in st.secrets else ""
+        )
+
+        index_name = st.text_input(
+            "Index Name",
+            value="pdf-rag-system",
+            help="Name for your Pinecone index"
+        )
+
+        # Model configuration
+        st.subheader("🤖 Model Settings")
+        llm_model = st.selectbox(
+            "Language Model",
+            ["google/flan-t5-small", "google/flan-t5-base"],
+            help="Choose the language model (smaller models are faster)"
+        )
+
+        embedding_model = st.selectbox(
+            "Embedding Model",
+            ["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
+            help="Choose the embedding model"
+        )
+
+        retrieval_k = st.slider(
+            "Number of chunks to retrieve",
+            min_value=1,
+            max_value=10,
+            value=5,
+            help="How many relevant chunks to use for answering questions"
+        )
+
+    # Main content area
+    col1, col2 = st.columns([1, 1])
+
+    with col1:
+        st.subheader("📁 Upload Documents")
+        uploaded_files = st.file_uploader(
+            "Choose PDF files",
+            type=['pdf'],
+            accept_multiple_files=True,
+            help="Upload one or more PDF files to analyze"
+        )
+
+        if st.button("🚀 Process Documents", type="primary"):
+            if not uploaded_files:
+                st.warning("Please upload at least one PDF file.")
+            elif not pinecone_api_key:
+                st.warning("Please enter your Pinecone API key.")
+            else:
+                with st.spinner("Processing documents..."):
+                    # Setup models
+                    llm = setup_llm(llm_model)
+                    embeddings = setup_embeddings(embedding_model)
+
+                    # Setup Pinecone
+                    pc, idx_name = setup_pinecone(pinecone_api_key, index_name=index_name)
+
+                    if pc:
+                        # Process files
+                        vectorstore, text_chunks = process_uploaded_files(
+                            uploaded_files, embeddings, pc, idx_name
+                        )
+
+                        if vectorstore:
+                            # Create QA chain
+                            qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
+
+                            # Store in session state
+                            st.session_state.qa_chain = qa_chain
+                            st.session_state.vectorstore = vectorstore
+                            st.session_state.documents_processed = True
+
+                            st.balloons()
+                            st.success("🎉 Documents processed successfully! You can now ask questions.")
+
+    with col2:
+        st.subheader("💬 Ask Questions")
+
+        if st.session_state.documents_processed:
+            question = st.text_input(
+                "Your question:",
+                placeholder="What are the main topics discussed in the documents?",
+                help="Ask any question about your uploaded documents"
+            )
+
+            if st.button("🔍 Get Answer"):
+                if question:
+                    with st.spinner("Searching for answer..."):
+                        result = ask_question(st.session_state.qa_chain, question)
+
+                        if result:
+                            # Add to chat history
+                            st.session_state.chat_history.append({
+                                "question": question,
+                                "answer": result["answer"],
+                                "sources": result["source_documents"]
+                            })
+
+                            # Display answer
+                            st.subheader("💡 Answer:")
+                            st.write(result["answer"])
+
+                            # Display sources
+                            if result["source_documents"]:
+                                st.subheader("📚 Sources:")
+                                for i, doc in enumerate(result["source_documents"][:3]):
+                                    with st.expander(f"Source {i+1}: {doc.metadata.get('source', 'Unknown')}"):
+                                        st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
+                else:
+                    st.warning("Please enter a question.")
+        else:
+            st.info("👆 Please upload and process documents first to start asking questions.")
+
+    # Chat History
+    if st.session_state.chat_history:
+        st.subheader("📝 Chat History")
+
+        for i, chat in enumerate(reversed(st.session_state.chat_history[-5:])):  # Show last 5
+            with st.expander(f"Q: {chat['question'][:50]}..."):
+                st.write("**Question:**", chat['question'])
+                st.write("**Answer:**", chat['answer'])
+
+                if chat['sources']:
+                    st.write("**Sources:**")
+                    for j, doc in enumerate(chat['sources'][:2]):  # Show top 2 sources
+                        st.write(f"{j+1}. {doc.metadata.get('source', 'Unknown')}")
+
+    # Clear session button
+    if st.session_state.documents_processed:
+        if st.button("🗑️ Clear Session"):
+            st.session_state.qa_chain = None
+            st.session_state.vectorstore = None
+            st.session_state.documents_processed = False
+            st.session_state.chat_history = []
+            st.success("Session cleared! You can upload new documents.")
+            st.experimental_rerun()
+
+if __name__ == "__main__":
+    main()
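
Note on dependencies: the new imports pull in several packages the previous version did not use. A minimal requirements.txt sketch follows; the package names are inferred from the imports in this commit, but the file itself is not part of the commit, so treat it as an assumption rather than the Space's actual pinned requirements:

streamlit
transformers
torch
sentence-transformers
langchain
langchain-pinecone
pinecone-client
pypdf

With a Pinecone API key at hand, the app starts like any Streamlit script: streamlit run app.py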