wellwisherofindia commited on
Commit
3dc7d4f
Β·
1 Parent(s): 9f8df1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -205
app.py CHANGED
@@ -2,283 +2,282 @@ import os
2
  import tempfile
3
  import gradio as gr
4
  import numpy as np
5
- from sklearn.metrics.pairwise import cosine_similarity
6
  from sentence_transformers import SentenceTransformer
7
 
8
  import google.generativeai as genai
9
  import fitz # PyMuPDF
10
- import traceback # Import traceback for detailed error logging
11
 
12
  # Initialize embedding model
13
  sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
14
 
15
  # Data storage
16
  chunks = []
17
- embeddings = np.array([])
18
- # Global state for API key
19
- # stored_api_key = None # Replaced by gr.State
20
 
21
  def extract_text_from_pdf(pdf_file_path, start_page=None, end_page=None):
22
- """
23
- Extract text from PDF file, optionally from a specific page range.
24
- Page numbers are 1-indexed.
25
- """
26
  doc = fitz.open(pdf_file_path)
27
  text = ""
28
-
29
- pages_to_process = []
30
  num_pages_in_doc = doc.page_count
31
 
32
  if start_page is not None and end_page is not None:
33
  start_idx = start_page - 1
34
  end_idx = end_page - 1
35
-
36
  if 0 <= start_idx <= end_idx < num_pages_in_doc:
37
- for i in range(start_idx, end_idx + 1):
38
- pages_to_process.append(doc.load_page(i))
39
  else:
40
- print(f"Warning: Invalid page range received in extract_text_from_pdf. Defaulting to all pages.")
41
- pages_to_process = [doc.load_page(i) for i in range(num_pages_in_doc)]
42
  else:
43
- pages_to_process = [doc.load_page(i) for i in range(num_pages_in_doc)]
44
 
45
- for page_obj in pages_to_process:
46
- text += page_obj.get_text()
47
 
48
  doc.close()
49
  return text, num_pages_in_doc
50
 
51
-
52
  def chunk_text(text, chunk_size=1000, overlap=200):
53
  """Split text into overlapping chunks"""
54
  doc_chunks = []
55
  for i in range(0, len(text), chunk_size - overlap):
56
  chunk = text[i:i + chunk_size]
57
- if len(chunk) > 100: # Ensure chunks are substantial
58
  doc_chunks.append(chunk)
59
  return doc_chunks
60
 
61
- def process_pdf(pdf_file_obj, processing_mode, start_page_ui, end_page_ui):
62
- """Process PDF (full or page range) and create embeddings."""
63
- global chunks, embeddings
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  if pdf_file_obj is None:
66
- chunks = []
67
- embeddings = np.array([])
68
- return "No PDF file provided. Please upload a PDF."
69
 
70
- tmp_path = None
71
  try:
 
72
  with open(pdf_file_obj.name, "rb") as f_in:
73
  pdf_bytes = f_in.read()
74
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
75
  tmp.write(pdf_bytes)
76
  tmp_path = tmp.name
77
- except Exception as e:
78
- if tmp_path and os.path.exists(tmp_path):
79
- os.unlink(tmp_path)
80
- return f"Error handling uploaded PDF file: {str(e)}"
81
-
82
- actual_start_page, actual_end_page = None, None
83
- page_info_str = "full document"
84
- pdf_name = os.path.basename(pdf_file_obj.name)
85
-
86
- try:
87
- doc_for_page_count = fitz.open(tmp_path)
88
- total_pages_in_doc = doc_for_page_count.page_count
89
- doc_for_page_count.close()
90
-
91
- if processing_mode == "Page Range":
92
- if start_page_ui is None or end_page_ui is None:
93
- raise ValueError("For 'Page Range' mode, both Start Page and End Page must be specified.")
94
-
95
- s_page = int(start_page_ui)
96
- e_page = int(end_page_ui)
97
-
98
- if not (1 <= s_page <= total_pages_in_doc and \
99
- 1 <= e_page <= total_pages_in_doc and \
100
- s_page <= e_page):
101
- raise ValueError(f"Invalid page range ({s_page}-{e_page}). Document has {total_pages_in_doc} pages.")
102
- actual_start_page, actual_end_page = s_page, e_page
103
- page_info_str = f"pages {s_page}-{e_page}"
104
-
105
- text, _ = extract_text_from_pdf(tmp_path, start_page=actual_start_page, end_page=actual_end_page)
106
 
 
 
 
107
  if not text.strip():
108
- chunks = []
109
- embeddings = np.array([])
110
- return f"Processed {page_info_str} of '{pdf_name}', but no text content found. Old data cleared."
111
-
112
- current_book_chunks = chunk_text(text)
113
-
114
- if not current_book_chunks:
115
- chunks = []
116
- embeddings = np.array([])
117
- return f"Processed {page_info_str} of '{pdf_name}', but no substantial chunks created. Old data cleared."
118
-
119
- current_book_embeddings = sbert_model.encode(current_book_chunks)
120
-
121
- chunks = current_book_chunks
122
- embeddings = np.array(current_book_embeddings)
123
-
124
- return f"Successfully processed {page_info_str} of '{pdf_name}'. Created {len(chunks)} chunks. Ready for questions."
 
 
 
 
 
 
 
 
 
125
 
126
- except ValueError as ve:
127
- return f"Error: {str(ve)}"
128
  except Exception as e:
129
  chunks = []
130
- embeddings = np.array([])
131
- error_msg = f"Error processing '{pdf_name}' ({page_info_str}): {str(e)}\n{traceback.format_exc()}"
132
- print(error_msg)
133
- return error_msg
134
- finally:
135
- if tmp_path and os.path.exists(tmp_path):
136
- os.unlink(tmp_path)
137
-
138
 
139
- def retrieve_relevant_chunks(query, top_k=5):
140
- """Retrieve most relevant chunks based on query."""
141
- global chunks, embeddings
142
 
143
- if not chunks or not isinstance(embeddings, np.ndarray) or embeddings.size == 0:
144
- return ["No document processed or no content yielded. Please process a PDF."]
145
 
146
- query_embedding = sbert_model.encode([query])[0]
147
- similarities = cosine_similarity([query_embedding], embeddings)[0]
148
-
149
- num_available_chunks = len(chunks)
150
- actual_top_k = min(top_k, num_available_chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- if actual_top_k == 0:
153
- return ["No relevant chunks found."]
 
 
154
 
155
- top_indices = np.argsort(similarities)[-actual_top_k:][::-1]
156
- # Return the actual chunk text
157
- top_chunks_text = [chunks[i] for i in top_indices]
158
 
159
- return top_chunks_text
 
 
160
 
161
- def generate_response(query, current_api_key_state):
162
- """Generate response using Gemini API and RAG, including sources."""
163
- if not current_api_key_state:
164
- return "API key not set. Please enter your API key and click 'Confirm API Key'.", "" # Return empty string for sources
165
 
166
  try:
167
- genai.configure(api_key=current_api_key_state)
168
- context_chunks = retrieve_relevant_chunks(query)
 
 
 
 
 
 
169
 
170
- if not context_chunks or "No document" in context_chunks[0] or "No relevant chunks" in context_chunks[0]:
171
- return f"Could not retrieve relevant context. Ensure a PDF is processed.\n\nDetails: {context_chunks[0]}", "" # Return empty string for sources
 
172
 
173
- context_for_prompt = "\n\n".join(context_chunks)
 
174
 
175
- prompt = f"""
176
- Based on the following context from a book, please answer the query.
177
 
178
- CONTEXT:
179
- {context_for_prompt}
180
 
181
- QUERY:
182
- {query}
183
-
184
- Please provide a helpful and accurate response based only on the information in the context. If the context doesn't provide an answer, say so.
185
- """
186
-
187
- gemini_model_instance = genai.GenerativeModel('gemini-1.5-flash-latest')
188
- response = gemini_model_instance.generate_content(prompt)
189
-
190
- # Prepare sources text
191
- sources_text = "--- Sources (Context Chunks) ---\n"
192
- for i, chunk in enumerate(context_chunks):
193
- sources_text += f"\n[Source {i+1}]:\n{chunk}\n"
194
-
195
- return response.text, sources_text # Return answer and sources separately
196
 
197
  except Exception as e:
198
- return f"Error generating response: {str(e)}\n{traceback.format_exc()}", "" # Return empty string for sources
199
-
200
- with gr.Blocks(title="RAG Book Assistant") as demo:
201
- api_key_state = gr.State(None) # Store the API key
202
-
203
- gr.Markdown("# πŸ“š RAG Book Assistant")
204
- gr.Markdown(
205
- "Upload a PDF book. You can choose to process the full book or a specific page range. "
206
- "Processing a new PDF (or range) will **replace the current one**.\n\n"
207
- )
208
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  with gr.Row():
210
  with gr.Column(scale=2):
211
- with gr.Group() as api_key_group:
212
- api_key_input = gr.Textbox(label="Gemini API Key", type="password", elem_id="api_key_input_id")
213
- confirm_api_key_btn = gr.Button("Confirm API Key", elem_id="confirm_api_key_btn_id")
214
-
215
- api_key_status_output = gr.Markdown(visible=False, value="API Key Set!", elem_id="api_key_status_id")
216
-
217
- pdf_input = gr.File(label="Upload PDF Book", file_types=['.pdf'])
218
-
219
- processing_mode_input = gr.Radio(
220
- label="Processing Mode",
221
- choices=["Full Book", "Page Range"],
222
- value="Full Book",
223
- interactive=True
224
  )
225
-
226
- with gr.Row(visible=False) as page_range_ui_row:
227
- start_page_input = gr.Number(label="Start Page", minimum=1, precision=0, interactive=True)
228
- end_page_input = gr.Number(label="End Page", minimum=1, precision=0, interactive=True)
229
-
230
- process_btn = gr.Button("Process PDF (Replaces Current Book)")
231
-
232
- query_input = gr.Textbox(label="Ask a question about the current book", lines=3)
233
- submit_btn = gr.Button("Submit Question")
234
-
235
  with gr.Column(scale=1):
236
- status_output = gr.Textbox(label="Processing Status", interactive=False, lines=5)
237
- response_output = gr.Textbox(label="Response (Answer)", interactive=False, lines=10)
238
- sources_output = gr.Textbox(label="Sources (Context Chunks)", interactive=False, lines=10)
239
-
240
- # Logic to show/hide page range inputs
241
- def update_page_range_visibility(mode):
242
- return gr.Row(visible=(mode == "Page Range"))
243
-
244
- processing_mode_input.change(
245
- fn=update_page_range_visibility,
246
- inputs=processing_mode_input,
247
- outputs=page_range_ui_row
248
  )
249
-
250
- # Logic to handle API key confirmation
251
- def confirm_api_key(api_key):
252
- if api_key:
253
- return {
254
- api_key_state: api_key,
255
- api_key_group: gr.Group(visible=False),
256
- api_key_status_output: gr.Markdown(visible=True, value="API Key Set and Hidden.")
257
- }
258
- else:
259
- return {
260
- api_key_state: None,
261
- api_key_group: gr.Group(visible=True),
262
- api_key_status_output: gr.Markdown(visible=True, value="Please enter an API Key.")
263
- }
264
-
265
- confirm_api_key_btn.click(
266
- fn=confirm_api_key,
267
- inputs=[api_key_input],
268
- outputs=[api_key_state, api_key_group, api_key_status_output]
269
  )
270
-
271
- process_btn.click(
272
- process_pdf,
273
- inputs=[pdf_input, processing_mode_input, start_page_input, end_page_input],
274
- outputs=[status_output]
 
 
 
 
 
 
 
 
275
  )
276
-
 
 
 
 
 
 
277
  submit_btn.click(
278
- generate_response,
279
- inputs=[query_input, api_key_state],
280
- outputs=[response_output, sources_output]
 
 
 
 
 
 
 
 
 
 
 
281
  )
282
 
283
  if __name__ == "__main__":
284
- demo.launch(share=True)
 
2
  import tempfile
3
  import gradio as gr
4
  import numpy as np
5
+ import faiss
6
  from sentence_transformers import SentenceTransformer
7
 
8
  import google.generativeai as genai
9
  import fitz # PyMuPDF
10
+ import traceback
11
 
12
  # Initialize embedding model
13
  sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
14
 
15
  # Data storage
16
  chunks = []
17
+ faiss_index = None
18
+ embedding_dimension = 384 # all-MiniLM-L6-v2 embedding dimension
 
19
 
20
  def extract_text_from_pdf(pdf_file_path, start_page=None, end_page=None):
21
+ """Extract text from PDF file, optionally from a specific page range."""
 
 
 
22
  doc = fitz.open(pdf_file_path)
23
  text = ""
 
 
24
  num_pages_in_doc = doc.page_count
25
 
26
  if start_page is not None and end_page is not None:
27
  start_idx = start_page - 1
28
  end_idx = end_page - 1
 
29
  if 0 <= start_idx <= end_idx < num_pages_in_doc:
30
+ pages_to_process = range(start_idx, end_idx + 1)
 
31
  else:
32
+ pages_to_process = range(num_pages_in_doc)
 
33
  else:
34
+ pages_to_process = range(num_pages_in_doc)
35
 
36
+ for i in pages_to_process:
37
+ text += doc.load_page(i).get_text()
38
 
39
  doc.close()
40
  return text, num_pages_in_doc
41
 
 
42
  def chunk_text(text, chunk_size=1000, overlap=200):
43
  """Split text into overlapping chunks"""
44
  doc_chunks = []
45
  for i in range(0, len(text), chunk_size - overlap):
46
  chunk = text[i:i + chunk_size]
47
+ if len(chunk) > 100:
48
  doc_chunks.append(chunk)
49
  return doc_chunks
50
 
51
+ def create_faiss_index(embeddings):
52
+ """Create FAISS index for fast similarity search."""
53
+ global embedding_dimension
54
+
55
+ # Normalize embeddings for cosine similarity
56
+ faiss.normalize_L2(embeddings)
57
+
58
+ # Create index - using IndexFlatIP for cosine similarity
59
+ index = faiss.IndexFlatIP(embedding_dimension)
60
+ index.add(embeddings)
61
+
62
+ return index
63
+
64
+ def process_pdf(pdf_file_obj, api_key):
65
+ """Process PDF and create FAISS index."""
66
+ global chunks, faiss_index
67
+
68
+ if not api_key:
69
+ return None, [["System", "⚠️ Please set your Gemini API key first."]]
70
+
71
  if pdf_file_obj is None:
72
+ return None, [["System", "πŸ“„ Please upload a PDF file."]]
 
 
73
 
 
74
  try:
75
+ # Save uploaded file temporarily
76
  with open(pdf_file_obj.name, "rb") as f_in:
77
  pdf_bytes = f_in.read()
78
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
79
  tmp.write(pdf_bytes)
80
  tmp_path = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # Extract text
83
+ text, total_pages = extract_text_from_pdf(tmp_path)
84
+
85
  if not text.strip():
86
+ return None, [["System", "⚠️ No text found in the PDF. Please try a different file."]]
87
+
88
+ # Create chunks
89
+ current_chunks = chunk_text(text)
90
+ if not current_chunks:
91
+ return None, [["System", "⚠️ Could not create text chunks from the PDF."]]
92
+
93
+ # Generate embeddings
94
+ current_embeddings = sbert_model.encode(current_chunks)
95
+ current_embeddings = np.array(current_embeddings, dtype=np.float32)
96
+
97
+ # Create FAISS index
98
+ current_index = create_faiss_index(current_embeddings)
99
+
100
+ # Update global storage
101
+ chunks = current_chunks
102
+ faiss_index = current_index
103
+
104
+ pdf_name = os.path.basename(pdf_file_obj.name)
105
+ success_msg = f"βœ… Successfully processed '{pdf_name}' ({total_pages} pages, {len(chunks)} chunks). FAISS index created! You can now ask questions!"
106
+
107
+ # Clean up
108
+ if os.path.exists(tmp_path):
109
+ os.unlink(tmp_path)
110
+
111
+ return None, [["System", success_msg]]
112
 
 
 
113
  except Exception as e:
114
  chunks = []
115
+ faiss_index = None
116
+ error_msg = f"❌ Error processing PDF: {str(e)}"
117
+ return None, [["System", error_msg]]
 
 
 
 
 
118
 
119
+ def retrieve_relevant_chunks(query, top_k=3):
120
+ """Retrieve most relevant chunks using FAISS search."""
121
+ global chunks, faiss_index
122
 
123
+ if not chunks or faiss_index is None:
124
+ return []
125
 
126
+ try:
127
+ # Encode query
128
+ query_embedding = sbert_model.encode([query])
129
+ query_embedding = np.array(query_embedding, dtype=np.float32)
130
+
131
+ # Normalize for cosine similarity
132
+ faiss.normalize_L2(query_embedding)
133
+
134
+ # Search using FAISS
135
+ scores, indices = faiss_index.search(query_embedding, top_k)
136
+
137
+ # Get top chunks
138
+ top_chunks = []
139
+ for idx in indices[0]:
140
+ if idx < len(chunks): # Safety check
141
+ top_chunks.append(chunks[idx])
142
+
143
+ return top_chunks
144
+
145
+ except Exception as e:
146
+ print(f"Error in FAISS search: {str(e)}")
147
+ return []
148
 
149
+ def chat_fn(message, history, api_key):
150
+ """Handle chat interaction."""
151
+ if not message.strip():
152
+ return history, ""
153
 
154
+ # Add user message to history
155
+ history = history + [[message, None]]
 
156
 
157
+ if not api_key:
158
+ history[-1][1] = "⚠️ Please set your Gemini API key first."
159
+ return history, ""
160
 
161
+ if not chunks or faiss_index is None:
162
+ history[-1][1] = "πŸ“„ Please upload and process a PDF document first."
163
+ return history, ""
 
164
 
165
  try:
166
+ # Configure Gemini
167
+ genai.configure(api_key=api_key)
168
+
169
+ # Get relevant context using FAISS
170
+ context_chunks = retrieve_relevant_chunks(message, top_k=5)
171
+ if not context_chunks:
172
+ history[-1][1] = "❌ Could not find relevant information in the document."
173
+ return history, ""
174
 
175
+ # Generate response
176
+ context = "\n\n".join(context_chunks)
177
+ prompt = f"""Based on the following context from the document, answer the user's question.
178
 
179
+ Context:
180
+ {context}
181
 
182
+ Question: {message}
 
183
 
184
+ Please provide a clear, accurate answer based only on the information in the context. If the context doesn't contain enough information to answer the question, say so."""
 
185
 
186
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
187
+ response = model.generate_content(prompt)
188
+
189
+ history[-1][1] = response.text
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  except Exception as e:
192
+ history[-1][1] = f"❌ Error: {str(e)}"
193
+
194
+ return history, ""
195
+
196
+ # Custom CSS for better chat appearance
197
+ css = """
198
+ .gradio-container {
199
+ max-width: 800px !important;
200
+ margin: auto !important;
201
+ }
202
+ .chat-message {
203
+ padding: 10px !important;
204
+ margin: 5px 0 !important;
205
+ border-radius: 10px !important;
206
+ }
207
+ """
208
+
209
+ with gr.Blocks(css=css, title="πŸ“š Chat with Your PDF") as demo:
210
+ api_key_state = gr.State("")
211
+
212
+ gr.Markdown("""
213
+ # πŸ“š Chat with Your PDF (FAISS-Powered)
214
+ Upload a PDF document and chat with it naturally. Now with FAISS for faster vector search!
215
+ """)
216
+
217
  with gr.Row():
218
  with gr.Column(scale=2):
219
+ api_key_input = gr.Textbox(
220
+ label="πŸ”‘ Gemini API Key",
221
+ type="password",
222
+ placeholder="Enter your API key here..."
 
 
 
 
 
 
 
 
 
223
  )
 
 
 
 
 
 
 
 
 
 
224
  with gr.Column(scale=1):
225
+ pdf_input = gr.File(
226
+ label="πŸ“„ Upload PDF",
227
+ file_types=['.pdf']
228
+ )
229
+
230
+ # Chat interface
231
+ chatbot = gr.Chatbot(
232
+ label="πŸ’¬ Chat",
233
+ height=500,
234
+ show_label=False,
235
+ bubble_full_width=False
 
236
  )
237
+
238
+ msg_input = gr.Textbox(
239
+ label="Message",
240
+ placeholder="Ask anything about your PDF...",
241
+ show_label=False,
242
+ container=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  )
244
+
245
+ with gr.Row():
246
+ submit_btn = gr.Button("Send πŸ’¬", variant="primary")
247
+ clear_btn = gr.Button("Clear Chat πŸ—‘οΈ")
248
+
249
+ # Event handlers
250
+ def update_api_key(key):
251
+ return key
252
+
253
+ api_key_input.change(
254
+ fn=update_api_key,
255
+ inputs=api_key_input,
256
+ outputs=api_key_state
257
  )
258
+
259
+ pdf_input.upload(
260
+ fn=process_pdf,
261
+ inputs=[pdf_input, api_key_state],
262
+ outputs=[msg_input, chatbot]
263
+ )
264
+
265
  submit_btn.click(
266
+ fn=chat_fn,
267
+ inputs=[msg_input, chatbot, api_key_state],
268
+ outputs=[chatbot, msg_input]
269
+ )
270
+
271
+ msg_input.submit(
272
+ fn=chat_fn,
273
+ inputs=[msg_input, chatbot, api_key_state],
274
+ outputs=[chatbot, msg_input]
275
+ )
276
+
277
+ clear_btn.click(
278
+ fn=lambda: ([], ""),
279
+ outputs=[chatbot, msg_input]
280
  )
281
 
282
  if __name__ == "__main__":
283
+ demo.launch(share=True)