dev2607 commited on
Commit
92efe67
·
verified ·
1 Parent(s): 9de82c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +391 -0
app.py CHANGED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import streamlit as st
4
+ import fitz # PyMuPDF
5
+ from typing import List, Dict, Any, Optional
6
+ from langchain_community.llms import HuggingFaceEndpoint
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import Chroma
10
+ from langchain.chains import ConversationalRetrievalChain
11
+ from langchain.memory import ConversationBufferMemory
12
+ from langchain.prompts import PromptTemplate
13
+
14
+ # Configure page
15
+ st.set_page_config(
16
+ page_title="PDF Q&A Assistant",
17
+ page_icon="📚",
18
+ layout="wide"
19
+ )
20
+
21
+ # Initialize session state variables if they don't exist
22
+ if "chat_history" not in st.session_state:
23
+ st.session_state.chat_history = []
24
+ if "conversation_chain" not in st.session_state:
25
+ st.session_state.conversation_chain = None
26
+ if "document_processed" not in st.session_state:
27
+ st.session_state.document_processed = False
28
+ if "file_names" not in st.session_state:
29
+ st.session_state.file_names = []
30
+
31
+ class PDFQAAssistant:
32
+ def __init__(self,
33
+ hf_token: str = None,
34
+ model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
35
+ embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
36
+ """
37
+ Initialize the PDF Q&A Assistant with Hugging Face models.
38
+
39
+ Args:
40
+ hf_token: Hugging Face API token
41
+ model_name: HF model to use for Q&A
42
+ embedding_model_name: HF model to use for embeddings
43
+ """
44
+ self.model_name = model_name
45
+ self.embedding_model_name = embedding_model_name
46
+ self.hf_token = hf_token
47
+
48
+ # Create a temp directory for the vector store
49
+ self.persist_directory = os.path.join(tempfile.gettempdir(), "pdf_qa_vectorstore")
50
+
51
+ # Initialize LLM with Hugging Face
52
+ self.llm = HuggingFaceEndpoint(
53
+ repo_id=model_name,
54
+ huggingfacehub_api_token=hf_token,
55
+ max_length=1024,
56
+ temperature=0.5
57
+ )
58
+
59
+ # Initialize embeddings with Hugging Face
60
+ self.embeddings = HuggingFaceEmbeddings(
61
+ model_name=embedding_model_name,
62
+ model_kwargs={'device': 'cpu'}
63
+ )
64
+
65
+ # Initialize text splitter for chunking documents
66
+ self.text_splitter = RecursiveCharacterTextSplitter(
67
+ chunk_size=1000,
68
+ chunk_overlap=200,
69
+ length_function=len
70
+ )
71
+
72
+ # Vector store and conversation chain will be initialized when documents are loaded
73
+ self.vectorstore = None
74
+ self.memory = ConversationBufferMemory(
75
+ memory_key="chat_history",
76
+ return_messages=True
77
+ )
78
+
79
+ # Create directories if they don't exist
80
+ os.makedirs(self.persist_directory, exist_ok=True)
81
+
82
+ def extract_text_from_pdf(self, pdf_file) -> str:
83
+ """
84
+ Extract text from a PDF file using PyMuPDF.
85
+
86
+ Args:
87
+ pdf_file: Uploaded PDF file
88
+
89
+ Returns:
90
+ Extracted text as a string
91
+ """
92
+ try:
93
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
94
+ tmp_file.write(pdf_file.getvalue())
95
+ tmp_path = tmp_file.name
96
+
97
+ # Open the PDF
98
+ doc = fitz.open(tmp_path)
99
+
100
+ # Extract text from each page
101
+ text = ""
102
+ for page_num, page in enumerate(doc):
103
+ text += page.get_text()
104
+
105
+ # Clean up
106
+ doc.close()
107
+ os.unlink(tmp_path)
108
+
109
+ return text
110
+
111
+ except Exception as e:
112
+ st.error(f"Error extracting text from PDF: {e}")
113
+ raise
114
+
115
+ def process_pdf(self, pdf_file, document_name: str) -> None:
116
+ """
117
+ Process a PDF file and prepare it for question answering.
118
+
119
+ Args:
120
+ pdf_file: Uploaded PDF file
121
+ document_name: Name to identify the document
122
+ """
123
+ # Extract text from PDF
124
+ with st.status("Extracting text from PDF..."):
125
+ text = self.extract_text_from_pdf(pdf_file)
126
+ st.write(f"Extracted {len(text)} characters")
127
+
128
+ # Split text into chunks
129
+ with st.status("Splitting document into chunks..."):
130
+ chunks = self.text_splitter.split_text(text)
131
+ st.write(f"Document split into {len(chunks)} chunks")
132
+
133
+ # Create vector embeddings
134
+ with st.status("Creating vector embeddings..."):
135
+ # Create metadata for each chunk
136
+ metadatas = [{"source": document_name, "chunk": i} for i in range(len(chunks))]
137
+
138
+ # If vectorstore already exists, add to it, otherwise create a new one
139
+ if self.vectorstore is None:
140
+ self.vectorstore = Chroma.from_texts(
141
+ texts=chunks,
142
+ embedding=self.embeddings,
143
+ metadatas=metadatas,
144
+ persist_directory=self.persist_directory
145
+ )
146
+ else:
147
+ self.vectorstore.add_texts(texts=chunks, metadatas=metadatas)
148
+
149
+ # Persist the vector store
150
+ if hasattr(self.vectorstore, 'persist'):
151
+ self.vectorstore.persist()
152
+
153
+ # Initialize the conversation chain
154
+ with st.status("Setting up Q&A system..."):
155
+ retriever = self.vectorstore.as_retriever(
156
+ search_kwargs={"k": 4} # Retrieve top 4 most relevant chunks
157
+ )
158
+
159
+ # Create a custom prompt template that includes the source information
160
+ qa_prompt = PromptTemplate(
161
+ input_variables=["context", "question", "chat_history"],
162
+ template="""
163
+ You are an AI assistant specializing in answering questions about documents.
164
+ Use the following pieces of context to answer the question at the end.
165
+ If you don't know the answer, just say you don't know. Don't try to make up an answer.
166
+ Always cite the specific source or page number when possible.
167
+
168
+ Context:
169
+ {context}
170
+
171
+ Chat History:
172
+ {chat_history}
173
+
174
+ Question:
175
+ {question}
176
+
177
+ Answer:
178
+ """
179
+ )
180
+
181
+ self.conversation_chain = ConversationalRetrievalChain.from_llm(
182
+ llm=self.llm,
183
+ retriever=retriever,
184
+ memory=self.memory,
185
+ combine_docs_chain_kwargs={"prompt": qa_prompt},
186
+ return_source_documents=True
187
+ )
188
+
189
+ # Store the conversation chain in session state
190
+ st.session_state.conversation_chain = self.conversation_chain
191
+
192
+ st.success(f"Successfully processed {document_name}")
193
+ st.session_state.document_processed = True
194
+
195
+ def ask(self, question: str) -> Dict[str, Any]:
196
+ """
197
+ Ask a question about the loaded documents.
198
+
199
+ Args:
200
+ question: The question to ask
201
+
202
+ Returns:
203
+ Dictionary with the answer and source documents
204
+ """
205
+ if self.conversation_chain is None:
206
+ return {"answer": "Please load a document first before asking questions."}
207
+
208
+ try:
209
+ result = self.conversation_chain({"question": question})
210
+
211
+ # Format sources for better readability
212
+ sources = []
213
+ if "source_documents" in result:
214
+ for doc in result["source_documents"]:
215
+ source = doc.metadata.get("source", "Unknown")
216
+ chunk = doc.metadata.get("chunk", "Unknown")
217
+ if source not in [s["source"] for s in sources]:
218
+ sources.append({"source": source, "chunk": chunk})
219
+
220
+ return {
221
+ "answer": result["answer"],
222
+ "sources": sources
223
+ }
224
+
225
+ except Exception as e:
226
+ st.error(f"Error processing question: {e}")
227
+ return {"answer": f"Error processing your question: {e}"}
228
+
229
+ def clear_memory(self) -> None:
230
+ """Clear the conversation memory."""
231
+ self.memory.clear()
232
+
233
+ def get_document_summary(assistant, document_name):
234
+ """Get a summary of the loaded document."""
235
+ st.subheader("Document Summary")
236
+
237
+ with st.status("Generating document summary..."):
238
+ questions = [
239
+ "What is the main topic of this document?",
240
+ "What are the key points from this document?",
241
+ "Could you provide a summary of this document in 3-5 bullet points?"
242
+ ]
243
+
244
+ for question in questions:
245
+ result = assistant.ask(question)
246
+ st.write(f"**{question}**")
247
+ st.write(result["answer"])
248
+ st.divider()
249
+
250
+ # Main app function
251
+ def main():
252
+ st.title("📚 AI-Powered PDF Reader & Q&A Assistant")
253
+
254
+ # Sidebar for settings and uploads
255
+ with st.sidebar:
256
+ st.header("Settings")
257
+
258
+ # Option to use HF token from environment or manual entry
259
+ use_env_token = st.checkbox("Use HF_TOKEN from environment", value=True)
260
+
261
+ if use_env_token:
262
+ hf_token = os.environ.get("HF_TOKEN", None)
263
+ if not hf_token:
264
+ st.warning("HF_TOKEN not found in environment variables.")
265
+ else:
266
+ hf_token = st.text_input("Enter Hugging Face API Token:", type="password")
267
+
268
+ # Model selection
269
+ st.subheader("Model Settings")
270
+ model_name = st.selectbox(
271
+ "Select LLM model:",
272
+ ["mistralai/Mistral-7B-Instruct-v0.2",
273
+ "google/flan-t5-large",
274
+ "tiiuae/falcon-7b-instruct"],
275
+ index=0
276
+ )
277
+
278
+ embedding_model = st.selectbox(
279
+ "Select Embedding model:",
280
+ ["sentence-transformers/all-MiniLM-L6-v2",
281
+ "sentence-transformers/all-mpnet-base-v2"],
282
+ index=0
283
+ )
284
+
285
+ # Document upload
286
+ st.subheader("Upload Documents")
287
+ uploaded_files = st.file_uploader("Upload PDF documents",
288
+ type="pdf",
289
+ accept_multiple_files=True)
290
+
291
+ if uploaded_files:
292
+ process_btn = st.button("Process Documents")
293
+ if process_btn:
294
+ # Initialize the assistant
295
+ assistant = PDFQAAssistant(
296
+ hf_token=hf_token,
297
+ model_name=model_name,
298
+ embedding_model_name=embedding_model
299
+ )
300
+
301
+ # Process each uploaded file
302
+ for pdf_file in uploaded_files:
303
+ file_name = pdf_file.name
304
+ st.session_state.file_names.append(file_name)
305
+ assistant.process_pdf(pdf_file, file_name)
306
+
307
+ # Store the assistant in session state
308
+ st.session_state.assistant = assistant
309
+
310
+ # Document management
311
+ if st.session_state.document_processed:
312
+ st.subheader("Document Management")
313
+
314
+ if st.button("Clear Chat History"):
315
+ st.session_state.assistant.clear_memory()
316
+ st.session_state.chat_history = []
317
+ st.success("Chat history cleared!")
318
+
319
+ if st.button("Generate Document Summary"):
320
+ get_document_summary(st.session_state.assistant,
321
+ st.session_state.file_names[0])
322
+
323
+ # Main area for chat interface
324
+ if not st.session_state.document_processed:
325
+ st.info("👈 Please upload and process a PDF document to get started.")
326
+
327
+ # Display demo information
328
+ st.header("How It Works")
329
+ col1, col2, col3 = st.columns(3)
330
+
331
+ with col1:
332
+ st.subheader("1. Upload PDF")
333
+ st.markdown("Upload any PDF document you want to query.")
334
+
335
+ with col2:
336
+ st.subheader("2. Process Document")
337
+ st.markdown("The AI will extract text and create searchable embeddings.")
338
+
339
+ with col3:
340
+ st.subheader("3. Ask Questions")
341
+ st.markdown("Ask any question about your document and get accurate answers.")
342
+ else:
343
+ # Chat interface
344
+ st.header("Ask Questions About Your Documents")
345
+
346
+ # Display processed files
347
+ st.caption(f"Processed Files: {', '.join(st.session_state.file_names)}")
348
+
349
+ # Display chat history
350
+ for message in st.session_state.chat_history:
351
+ if message["role"] == "user":
352
+ st.chat_message("user").write(message["content"])
353
+ else:
354
+ st.chat_message("assistant").write(message["content"])
355
+ if "sources" in message:
356
+ with st.expander("View Sources"):
357
+ for source in message["sources"]:
358
+ st.write(f"- {source['source']} (chunk {source['chunk']})")
359
+
360
+ # Input for new question
361
+ if question := st.chat_input("Ask a question about your documents..."):
362
+ # Add user question to chat history
363
+ st.session_state.chat_history.append({
364
+ "role": "user",
365
+ "content": question
366
+ })
367
+
368
+ # Display user question
369
+ st.chat_message("user").write(question)
370
+
371
+ # Get the answer
372
+ with st.chat_message("assistant"):
373
+ with st.spinner("Thinking..."):
374
+ result = st.session_state.assistant.ask(question)
375
+ st.write(result["answer"])
376
+
377
+ # Show sources if available
378
+ if result["sources"]:
379
+ with st.expander("View Sources"):
380
+ for source in result["sources"]:
381
+ st.write(f"- {source['source']} (chunk {source['chunk']})")
382
+
383
+ # Add assistant response to chat history
384
+ st.session_state.chat_history.append({
385
+ "role": "assistant",
386
+ "content": result["answer"],
387
+ "sources": result["sources"]
388
+ })
389
+
390
+ if __name__ == "__main__":
391
+ main()