aach456 committed on
Commit
0a8bbc5
·
verified ·
1 Parent(s): f02d10f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +353 -0
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain.chains import ConversationalRetrievalChain
11
+ from langchain_community.llms import HuggingFacePipeline
12
+
13
# Model selection and retrieval/generation tuning knobs
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "google/flan-t5-large"

# Run on the GPU whenever torch reports CUDA as available, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

THRESHOLD = 0.7  # passed to the retriever as its score_threshold
CHUNK_SIZE = 1000  # chunk_size handed to the text splitter
CHUNK_OVERLAP = 200  # chunk_overlap handed to the text splitter
TEMPERATURE = 0.1  # forwarded to the text-generation pipeline
MAX_NEW_TOKENS = 512  # forwarded to the text-generation pipeline
TOP_K = 3  # "k": how many chunks the retriever is asked for
23
+
24
# Mutable module-level state shared by the UI callbacks.
# conversation_history maps a session id to that session's list of turns.
conversation_history = {}
# Nothing is loaded at start-up: no session, no vector store, no document name.
current_session_id = current_document_store = current_document_name = None

# Document loader class to instantiate for each supported (lower-case) extension.
FILE_EXTENSIONS = {
    ".pdf": PyPDFLoader,
    ".txt": TextLoader,
    ".docx": Docx2txtLoader,
    ".pptx": UnstructuredPowerPointLoader,
}
35
+
36
class DocumentAIBot:
    """Question answering over one document via embeddings, FAISS and a seq2seq LLM.

    Constructing an instance loads the embedding model, the tokenizer and the
    language model, so it is expensive; callers are expected to reuse a single
    instance.
    """

    def __init__(self):
        self.setup_models()

    def setup_models(self):
        """Load the embedding model, the LLM pipeline and the text splitter."""
        print("Setting up models...")
        # Set up embedding model (embeddings are normalized on encode)
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={"device": DEVICE},
            encode_kwargs={"normalize_embeddings": True}
        )

        # Set up LLM model
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
        self.llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL).to(DEVICE)

        # Create text generation pipeline
        # NOTE(review): temperature is passed without do_sample; it presumably
        # has no effect unless sampling is enabled -- confirm the intent.
        self.text_generation_pipeline = pipeline(
            "text2text-generation",
            model=self.llm_model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            device=0 if DEVICE == "cuda" else -1
        )

        # Create HuggingFace pipeline for LangChain
        self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)

        # Text splitter for document chunking (sizes are measured with len)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len
        )

        print("Models loaded successfully!")

    def process_document(self, file_path):
        """Load the file at *file_path*, chunk it and index it in a FAISS store.

        Args:
            file_path: Path to the document; its extension selects the loader.

        Returns:
            The FAISS vector store built from the document's chunks.

        Raises:
            ValueError: If the extension is not in FILE_EXTENSIONS or the
                document yields no chunks.
        """
        print(f"Processing document: {file_path}")
        file_extension = os.path.splitext(file_path)[1].lower()

        if file_extension not in FILE_EXTENSIONS:
            raise ValueError(f"Unsupported file format: {file_extension}")

        # Select appropriate loader
        loader_class = FILE_EXTENSIONS[file_extension]
        loader = loader_class(file_path)

        # Load and split the document
        documents = loader.load()
        chunks = self.text_splitter.split_documents(documents)

        if not chunks:
            raise ValueError("No content extracted from the document")

        print(f"Document split into {len(chunks)} chunks")

        # Create vector store
        vector_store = FAISS.from_documents(chunks, self.embedding_model)
        return vector_store

    def setup_retrieval_chain(self, vector_store):
        """Build a ConversationalRetrievalChain over *vector_store*.

        The retriever asks for at most TOP_K chunks and applies THRESHOLD as
        its score threshold; the chain also returns the source documents.
        """
        retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": TOP_K,
                "score_threshold": THRESHOLD
            }
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=retriever,
            return_source_documents=True,
            verbose=True
        )

        return chain

    @staticmethod
    def _format_sources(source_documents):
        """Render the de-duplicated "Sources:" footer for the retrieved chunks.

        Returns an empty string when there are no source documents.
        """
        lines = []
        seen_sources = set()

        for doc in source_documents or []:
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source")
            page = metadata.get("page")

            source_key = f"{source}-{page}"
            if source_key in seen_sources:
                continue
            seen_sources.add(source_key)

            label = os.path.basename(source) if source else "Document chunk"
            if isinstance(page, int) and not isinstance(page, bool):
                # Integer page metadata comes from the PDF loader, which counts
                # pages from 0; show the number a reader would expect.
                label += f" (page {page + 1})"
            elif page is not None:
                label += f" (page {page})"
            # No page metadata (e.g. plain text): cite the file alone.
            lines.append(f"- {label}")

        if not lines:
            return ""
        return "\n\nSources:\n" + "\n".join(lines)

    def get_answer(self, question, session_id, vector_store, chat_history):
        """Answer *question* from the document indexed in *vector_store*.

        Args:
            question: The user's question; blank or None yields a prompt to
                enter one and leaves the history untouched.
            session_id: Currently unused; kept for interface compatibility.
            vector_store: Vector store to retrieve context from.
            chat_history: Earlier (question, answer) pairs, or None.

        Returns:
            A tuple ``(answer, updated_history)`` where *answer* carries a
            "Sources:" footer when chunks were retrieved and *updated_history*
            is the previous history plus the new (question, answer) pair.
        """
        if not question or not question.strip():
            return "Please enter a question related to the document.", chat_history

        previous_turns = list(chat_history or [])

        # The chain is rebuilt per call because it is bound to the vector store.
        retrieval_chain = self.setup_retrieval_chain(vector_store)

        # The chain expects the history as (question, answer) tuples.
        formatted_chat_history = [(q, a) for q, a in previous_turns]

        response = retrieval_chain(
            {"question": question, "chat_history": formatted_chat_history}
        )

        answer = response["answer"]
        answer += self._format_sources(response.get("source_documents") or [])

        return answer, previous_turns + [(question, answer)]
158
+
159
def generate_session_id():
    """Return a new random (version 4) UUID rendered as a string."""
    from uuid import uuid4

    fresh_id = uuid4()
    return str(fresh_id)
163
+
164
def save_uploaded_file(file):
    """Copy an uploaded file into a private temporary directory.

    Args:
        file: Either a filesystem path (str / os.PathLike) or a file-like
            object. For objects, a ``name`` attribute pointing at an existing
            file is preferred over ``read()``.

    Returns:
        Path of the saved copy. The original base name (and so the extension
        used to pick a loader) is preserved.
    """
    import shutil

    if isinstance(file, (str, bytes, os.PathLike)):
        source_path = os.fsdecode(file)
    else:
        source_path = getattr(file, "name", None)
        if not isinstance(source_path, str):
            # e.g. in-memory streams, or files opened from a descriptor
            source_path = None

    # Keep only the base name so a client-supplied name cannot steer the
    # destination outside the temporary directory.
    file_name = os.path.basename(source_path) if source_path else ""
    if not file_name:
        file_name = "upload"

    # A fresh directory per upload avoids clobbering same-named files.
    target_dir = tempfile.mkdtemp(prefix="document_upload_")
    temp_path = os.path.join(target_dir, file_name)

    if source_path and os.path.isfile(source_path):
        # Copy from disk instead of read(): joining the temp dir with an
        # absolute name yields the upload's own path, and opening that for
        # writing would truncate the document before it could be read.
        shutil.copyfile(source_path, temp_path)
    else:
        data = file.read()
        if isinstance(data, str):
            data = data.encode("utf-8")
        with open(temp_path, "wb") as f:
            f.write(data)

    return temp_path
173
+
174
def clear_conversation():
    """Reset the stored history of the current session.

    Returns:
        A tuple ``(chat_value, status_message)``: an empty list for the chat
        display and a status line. The loaded document, if any, stays loaded.
    """
    # Only reads the globals and mutates the dict in place, so no ``global``
    # declaration is needed.
    if current_session_id and current_session_id in conversation_history:
        conversation_history[current_session_id] = []

    if not current_document_name:
        # Nothing has been processed yet; avoid telling the user to keep
        # asking about a document called 'None'.
        return [], "Conversation cleared. Please upload a document to start asking questions."

    document_name = os.path.basename(current_document_name)
    return [], f"Conversation cleared. You can continue asking questions about '{document_name}'."
182
+
183
def process_uploaded_document(file):
    """Index an uploaded document and start a fresh chat session for it.

    Args:
        file: The value delivered by the upload component, or None.

    Returns:
        A tuple ``(chat_value, status_message)``. ``chat_value`` is an empty
        list on success and None when nothing was uploaded or processing
        failed; the module-level session state is only replaced on success.
    """
    global current_session_id, current_document_store, current_document_name, conversation_history

    if file is None:
        return None, "Please upload a document first."

    try:
        # Save the uploaded file
        file_path = save_uploaded_file(file)

        # The bot is created lazily and cached on the function object because
        # loading the models is expensive.
        if not hasattr(process_uploaded_document, "bot"):
            process_uploaded_document.bot = DocumentAIBot()

        # Process the document
        vector_store = process_uploaded_document.bot.process_document(file_path)

        # Derive the display name from the saved path: it works for both
        # path-style and object-style uploads and never shows a temp directory.
        document_name = os.path.basename(file_path)

        # Only one session is ever reachable, so drop the superseded history
        # instead of letting the dict grow with every upload.
        conversation_history.pop(current_session_id, None)

        # Create a new session
        session_id = generate_session_id()
        conversation_history[session_id] = []

        # Update global variables
        current_session_id = session_id
        current_document_store = vector_store
        current_document_name = document_name

        return [], f"Document '{document_name}' processed successfully. You can now ask questions about it."

    except Exception as e:
        # UI boundary: log the traceback and report the failure as status text.
        import traceback
        traceback.print_exc()
        return None, f"Error processing document: {str(e)}"
216
+
217
def answer_question(question, history):
    """Answer a question about the current document for the chat display.

    Args:
        question: Text typed by the user.
        history: Current value of the chat display (list of message pairs),
            or None when it is empty.

    Returns:
        The chat display's new value: the previous pairs plus one pair holding
        the question and either the answer or a status/error message. The
        output component is a chat display, which takes the whole list of
        pairs rather than a bare string.
    """
    transcript = list(history or [])
    # A blank question gets no user bubble: None in a pair is not rendered.
    user_message = question if question and question.strip() else None

    def with_reply(message):
        return transcript + [(user_message, message)]

    if current_document_store is None:
        return with_reply("Please upload a document first.")

    if not hasattr(process_uploaded_document, "bot"):
        return with_reply("Document AI bot not initialized. Please reload the page and try again.")

    try:
        # Get current chat history
        chat_history = conversation_history.get(current_session_id, [])

        # Get answer
        answer, updated_history = process_uploaded_document.bot.get_answer(
            question,
            current_session_id,
            current_document_store,
            chat_history
        )

        # Update conversation history
        conversation_history[current_session_id] = updated_history

        return with_reply(answer)

    except Exception as e:
        # UI boundary: log the traceback and show the failure in the chat.
        import traceback
        traceback.print_exc()
        return with_reply(f"Error generating answer: {str(e)}")
248
+
249
def build_interface():
    """Assemble the Gradio Blocks UI, wire its events and return it unlaunched."""
    with gr.Blocks(title="Document AI Chatbot") as interface:
        gr.Markdown("# 📄 Document AI Chatbot")
        gr.Markdown("Upload a document (PDF, TXT, DOCX, PPTX) and ask questions about its content.")

        with gr.Row():
            # Left column: upload controls, status line and configuration summary
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document",
                    file_types=[".pdf", ".txt", ".docx", ".pptx"],
                    type="file"
                )

                upload_button = gr.Button("Process Document", variant="primary")
                upload_status = gr.Textbox(label="Upload Status", interactive=False)

                clear_button = gr.Button("Clear Conversation")

                gr.Markdown("### System Information")
                gr.Markdown(f"""
- Embedding Model: {EMBEDDING_MODEL}
- Language Model: {LLM_MODEL}
- Running on: {DEVICE}
- Chunk Size: {CHUNK_SIZE}
- Relevance Threshold: {THRESHOLD}
""")

            # Right column: the conversation and the question box
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_label=True,
                )

                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="What is the main topic of this document?",
                        lines=2,
                        max_lines=5,
                        interactive=True,
                        show_label=True
                    )

                    submit_button = gr.Button("Submit", variant="primary")

        # Processing a document resets the chat and reports its status.
        upload_button.click(
            process_uploaded_document,
            inputs=[file_input],
            outputs=[chatbot, upload_status]
        )

        # The Submit button and pressing Enter in the textbox behave the same:
        # answer the question, then empty the textbox.
        for register in (submit_button.click, question_input.submit):
            register(
                answer_question,
                inputs=[question_input, chatbot],
                outputs=chatbot
            ).then(
                lambda: "",
                None,
                question_input
            )

        clear_button.click(
            clear_conversation,
            inputs=[],
            outputs=[chatbot, upload_status]
        )

        # On page load, run a small script that sets the page background
        # colour and caps the container width.
        interface.load(
            js="""
() => {
document.querySelector('body').style.backgroundColor = '#f7f7f7';
document.querySelector('.gradio-container').style.maxWidth = '1200px';
}
"""
        )

    return interface
343
+
344
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, request a share link, run in
    # debug mode and hide the API page.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
        "show_api": False,
    }
    demo = build_interface()
    demo.launch(**launch_options)