ash2203 commited on
Commit
ab5a4af
·
verified ·
1 Parent(s): 31e8e92

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import streamlit as st
4
+ from langchain_groq import ChatGroq
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_community.document_loaders import TextLoader, PyMuPDFLoader, Docx2txtLoader
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from typing import List
10
+ from langchain_core.documents import Document
11
+ from langchain_openai import OpenAIEmbeddings
12
+ from langchain_core.runnables import RunnablePassthrough
13
+ from langchain_community.retrievers import BM25Retriever
14
+ from langchain.retrievers import EnsembleRetriever
15
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
16
+ from langchain_pinecone import PineconeVectorStore
17
+ from pinecone import Pinecone, ServerlessSpec
18
+ from pinecone import PineconeApiException, NotFoundException
19
+ import shutil
20
+
21
from dotenv import load_dotenv
# Pull secrets/config (e.g. PINECONE_API_KEY, GROQ_API_KEY, VECTORDB_NAME —
# all read via os.getenv below) from a local .env file into the environment.
load_dotenv()
23
+
24
# Set page configuration — Streamlit requires this to be the first st.* call.
st.set_page_config(page_title="Document Analyzer", layout="wide", )

st.title("📚 Document Analyzer")

# Usage notes in a collapsible expander so they don't crowd the main UI.
with st.expander("ℹ️ Click here to view instructions"):
    st.markdown("""
- Upload files by clicking on "Browse Files"
- Avoid interrupting when file/files are under processing, this interrupts the execution and you would have to refresh the page to run the webapp again
- You can add more files anytime, just avoid adding/removing files when it's processing the uploaded documents
- The processing will trigger whenever you make any changes to the files
""")
37
+
38
# Seed every session-state key we rely on so later reads never KeyError.
_STATE_DEFAULTS = {
    "initialized": False,
    "processing": False,
    "last_processed_files": set(),
    "chat_history": [],
    "chat_enabled": False,
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

if not st.session_state.initialized:
    # Fresh session (first run or page refresh): wipe any stale uploads and
    # start with an empty workspace.
    if os.path.exists("data"):
        shutil.rmtree("data")
    os.makedirs("data")
    st.session_state.uploaded_files = {}
    st.session_state.previous_files = set()
    st.session_state.vectorstore = None
    st.session_state.retriever = None
    st.session_state.initialized = True
60
+
61
def save_uploaded_file(uploaded_file):
    """Persist a Streamlit UploadedFile into the local ``data`` directory.

    Returns the saved file path on success; on failure, surfaces an error in
    the UI and returns None.
    """
    try:
        destination = os.path.join("data", uploaded_file.name)

        # Write the upload's raw bytes to disk.
        contents = uploaded_file.getvalue()
        with open(destination, "wb") as out:
            out.write(contents)

        # Guard clause: confirm the write actually landed on disk.
        if not os.path.exists(destination):
            st.error(f"File not saved: {destination}")
            return None
        return destination

    except Exception as e:
        st.error(f"Error saving file: {str(e)}")
        return None
82
+
83
# Maps a lowercased file extension to the LangChain loader that parses it.
_EXTENSION_LOADERS = {
    ".pdf": PyMuPDFLoader,
    ".txt": TextLoader,
    ".docx": Docx2txtLoader,
}

def process_documents(uploaded_files_dict):
    """Load, chunk, embed, and index the uploaded documents in Pinecone.

    Args:
        uploaded_files_dict: mapping of filename -> {"path": ..., "type": ...}
            as built by the uploader in process_and_chat().

    Returns:
        True when the index was (re)built successfully, False otherwise.

    Side effects: toggles st.session_state.chat_enabled and renders
    progress/error messages in the Streamlit UI.
    """
    warning_placeholder = st.empty()
    warning_placeholder.warning("⚠️ Document processing in progress. Please wait before adding or removing files.")
    success_placeholder = st.empty()

    try:
        with st.spinner('Processing documents...'):
            docs = []
            for filename, file_info in uploaded_files_dict.items():
                file_path = file_info["path"]

                if not os.path.exists(file_path):
                    st.error(f"File not found: {file_path}")
                    continue

                # Fix: match extensions case-insensitively — the original
                # endswith(".pdf") checks silently skipped files such as
                # "Report.PDF".
                ext = os.path.splitext(filename)[1].lower()
                loader_cls = _EXTENSION_LOADERS.get(ext)
                if loader_cls is None:
                    # Should not happen (the uploader restricts types), but
                    # surface it instead of skipping silently.
                    st.error(f"File not found: {file_path}")
                    continue
                docs.extend(loader_cls(file_path).load())

            if not docs:
                st.error("No documents were successfully processed")
                return False

            # Split documents into overlapping chunks for retrieval.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=2000,
                chunk_overlap=400,
                length_function=len
            )
            chunks = text_splitter.split_documents(docs)

            # 512-dim embeddings — must match the index dimension created below.
            embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)

            # Initialize Pinecone client and target index name from env vars.
            pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
            index_name = os.getenv("VECTORDB_NAME")

            try:
                # Drop and recreate the index so each processing pass starts
                # from a clean slate (every change re-indexes all files).
                if index_name in pc.list_indexes().names():
                    pc.delete_index(index_name)

                pc.create_index(
                    name=index_name,
                    dimension=512,
                    metric='cosine',
                    spec=ServerlessSpec(cloud='aws', region='us-east-1')
                )

                # Block until the serverless index reports ready.
                while not pc.describe_index(index_name).status['ready']:
                    time.sleep(1)

                pc_index = pc.Index(index_name)

                # Create vectorstore and upsert all chunks.
                vectorstore = PineconeVectorStore(index=pc_index, embedding=embed_func)
                vectorstore.add_documents(documents=chunks)

                st.session_state.chat_enabled = True
                success_placeholder.success('Documents processed successfully!')
                time.sleep(2)  # Show success message for 2 seconds
                success_placeholder.empty()  # Clear the success message
                return True

            except PineconeApiException:
                st.error("File upload failed! Avoid interrupting document processing by uploading or removing files. Kindly refresh the app to continue.")
                st.session_state.chat_enabled = False
                return False

    except Exception as e:
        st.error(f"An error occurred during processing: {str(e)}")
        st.session_state.chat_enabled = False
        return False
    finally:
        # Always clear the in-progress warning, success or failure.
        warning_placeholder.empty()
171
+
172
def doc2str(docs):
    """Join an iterable of document strings into one blank-line-separated block.

    Args:
        docs: iterable of strings (chunk texts).

    Returns:
        A single string with entries separated by "\\n\\n".
    """
    # str.join consumes the iterable directly; the original's pass-through
    # generator ("doc for doc in docs") added nothing.
    return "\n\n".join(docs)
174
+
175
def format_reranked_docs(pc, retriever, question):
    """Retrieve candidate chunks for *question* and rerank them via Pinecone.

    Returns the top-3 reranked passages joined into one context string.
    """
    # Pull candidates from the retriever, dropping near-empty chunks.
    candidates = []
    for doc in retriever.invoke(question):
        if len(doc.page_content) > 5:
            candidates.append(doc.page_content)

    # Ask Pinecone's hosted reranking model to order candidates by relevance.
    rerank_result = pc.inference.rerank(
        model="pinecone-rerank-v0",
        query=question,
        documents=candidates,
        top_n=3,
        return_documents=True
    )

    top_passages = [entry.document.text for entry in rerank_result.data]
    return doc2str(top_passages)
190
+
191
def run_chatbot(retriever, pc, llm):
    """Run the chatbot with the given components.

    Args:
        retriever: LangChain retriever producing candidate document chunks.
        pc: Pinecone client, used for reranking inside format_reranked_docs.
        llm: chat model invoked with the assembled prompt.

    Renders the chat history and input box; appends each exchange to
    st.session_state.messages.
    """
    # st.markdown("<h4>💬 Chat with your Documents</h4>", unsafe_allow_html=True)

    # Initialize chat prompt; {context} and {question} are filled per query.
    prompt = ChatPromptTemplate.from_template("""
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.

<context>
{context}
</context>

Important: You cannot quote the context in the responses. If you do that, there will be a strict penalty for it.

Answer the following question:

{question}""")

    # Create the QA chain with reranking: retrieve+rerank context lazily per
    # question, then prompt -> LLM -> plain string.
    qa_chain = (
        RunnablePassthrough.assign(context=lambda input: format_reranked_docs(pc, retriever, input["question"]))
        | prompt
        | llm
        | StrOutputParser()
    )

    # Initialize messages in session state if not exists.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the full chat history — Streamlit re-runs this script on every
    # interaction, so the transcript must be re-rendered each time.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input; walrus binds the submitted question (None/"" skips the block).
    if question := st.chat_input("Ask a question about your documents"):
        # Add user message to chat history, then echo it immediately.
        st.session_state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        # Create a spinner outside the chat message.
        with st.spinner("Thinking..."):
            try:
                # Generate response through the retrieve->rerank->LLM chain.
                response = qa_chain.invoke({"question": question})

                # Display response in chat message after generation.
                with st.chat_message("assistant"):
                    st.markdown(response)
                # Add assistant response to chat history.
                st.session_state.messages.append({"role": "assistant", "content": response})
            except Exception as e:
                # Surface the failure in-chat and record it in the history.
                error_msg = f"An error occurred while processing your question: {str(e)}"
                with st.chat_message("assistant"):
                    st.error(error_msg)
                st.session_state.messages.append({"role": "assistant", "content": f"❌ {error_msg}"})
249
+
250
def process_and_chat():
    """Render the uploader, (re)process documents on changes, and show chat.

    Flow: collect uploads -> persist new files -> if the tracked file set
    changed, rebuild the Pinecone index -> if indexing succeeded, wire up the
    retriever + LLM and hand off to run_chatbot().
    """
    # File uploader section.
    with st.container():
        uploaded_files = st.file_uploader(
            "Upload your documents",
            type=["pdf", "txt", "docx"],
            accept_multiple_files=True,
            key="file_uploader",
            label_visibility="collapsed" if st.session_state.processing else "visible"
        )

    # Persist any files we have not seen before in this session.
    # (The original bound unused locals `current_uploaded_filenames` and
    # `files_added` here; both were dead code and are removed.)
    if uploaded_files:
        for file in uploaded_files:
            # Only process files that haven't been uploaded before.
            if file.name not in st.session_state.uploaded_files:
                file_path = save_uploaded_file(file)
                if file_path:  # Only track files that were saved successfully.
                    st.session_state.uploaded_files[file.name] = {
                        "path": file_path,
                        "type": file.type
                    }

    # Check for changes in files.
    current_files = set(st.session_state.uploaded_files.keys())

    # Process documents only if the tracked set actually changed.
    if current_files != st.session_state.previous_files:
        st.session_state.previous_files = current_files

        if current_files:
            st.session_state.processing = True
            # Process documents and enable chat if successful.
            if process_documents(st.session_state.uploaded_files):
                st.session_state.chat_enabled = True
            st.session_state.processing = False
        else:
            st.warning('Please upload a file to continue')
            st.session_state.chat_enabled = False

    # If files exist and chat is enabled, show chat interface.
    if current_files and st.session_state.chat_enabled:
        try:
            # Recreate LLM / Pinecone handles on every rerun — Streamlit
            # scripts are stateless between runs.
            llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768", groq_api_key=os.getenv("GROQ_API_KEY"))
            pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
            index_name = os.getenv("VECTORDB_NAME")
            pc_index = pc.Index(index_name)

            # Vectorstore over the existing index; embedding config must match
            # what process_documents() used (512-dim text-embedding-3-small).
            embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)
            vectorstore = PineconeVectorStore(index=pc_index, embedding=embed_func)

            # Similarity search with a score floor; results get reranked later
            # inside format_reranked_docs().
            vectorstore_retriever = vectorstore.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"k": 5, "score_threshold": 0.6},
            )

            # Run chatbot with fresh components.
            run_chatbot(vectorstore_retriever, pc, llm)
        except NotFoundException:
            st.error("Vector database not found. Please try uploading your documents again.")
            st.session_state.chat_enabled = False
            # Clear the previous files to force reprocessing.
            st.session_state.previous_files = set()
322
+
323
# Call the main function — Streamlit executes this script top-to-bottom on
# every user interaction, so this runs once per rerun.
process_and_chat()
325
+