SimranShaikh committed on
Commit 998a186 · verified · 1 Parent(s): 57ecfe6
Files changed (1)
  1. src/streamlit_app.py +248 -54
src/streamlit_app.py CHANGED
@@ -1,4 +1,4 @@
- # Fixed SimplePDFRAG with better state management and PDF caching
  import streamlit as st
  import PyPDF2
  from sentence_transformers import SentenceTransformer
@@ -9,6 +9,7 @@ from sklearn.metrics.pairwise import cosine_similarity
  import logging
  import os
  import tempfile

  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -22,6 +23,7 @@ class SimplePDFRAG:
          self.granite_model = None
          self.tokenizer = None
          self.pdf_name = None

      def setup_cache_directory(self):
          try:
@@ -30,6 +32,7 @@ class SimplePDFRAG:
              os.environ['TRANSFORMERS_CACHE'] = cache_dir
              os.environ['SENTENCE_TRANSFORMERS_HOME'] = cache_dir
              st.info(f"Using cache directory: {cache_dir}")
              return cache_dir
          except Exception as e:
              st.error(f"Error setting up cache directory: {e}")
@@ -40,18 +43,46 @@ class SimplePDFRAG:
              cache_dir = self.setup_cache_directory()
              st.info("Loading embedding model...")
              self.embedding_model = SentenceTransformer(
-                 'all-MiniLM-L6-v2', cache_folder=cache_dir
              )
              st.info("Loading IBM Granite model...")
-             model_name = "ibm-granite/granite-3-2-2b-instruct"
-             self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-             self.granite_model = AutoModelForCausalLM.from_pretrained(
-                 model_name, cache_dir=cache_dir, torch_dtype=torch.float32
              )
              if self.tokenizer.pad_token is None:
                  self.tokenizer.pad_token = self.tokenizer.eos_token
              st.success("Models loaded successfully!")
              return True
          except Exception as e:
              st.error(f"Error loading models: {e}")
              logger.error(f"Model loading error: {e}")
@@ -63,6 +94,8 @@ class SimplePDFRAG:
              pdf_reader = PyPDF2.PdfReader(pdf_file)
              text = ""
              st.info(f"PDF has {len(pdf_reader.pages)} pages")
              for page_num, page in enumerate(pdf_reader.pages):
                  try:
                      page_text = page.extract_text()
@@ -73,42 +106,78 @@ class SimplePDFRAG:
                          st.warning(f"⚠️ No text found on page {page_num + 1}")
                  except Exception as page_error:
                      st.error(f"Error extracting page {page_num + 1}: {page_error}")
              if text.strip():
-                 st.success(f"Extracted {len(text)} characters")
                  st.write("📄 **Text Preview:**")
                  st.text(text[:500] + "..." if len(text) > 500 else text)
                  return text
              else:
                  st.error("No text could be extracted from the PDF")
                  return None
          except Exception as e:
              st.error(f"Error reading PDF file: {e}")
              logger.error(f"PDF extraction error: {e}")
              return None

-     def chunk_text(self, text, chunk_size=500):
          if not text or not text.strip():
              return []
          words = text.split()
-         return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

      def process_pdf(self, pdf_file, pdf_name):
          try:
              self.pdf_name = pdf_name
              st.info("🔍 Extracting text from PDF...")
              text = self.extract_pdf_text(pdf_file)
              if not text:
                  return False
-             st.info("✂️ Splitting text into chunks...")
              chunks = self.chunk_text(text)
              if not chunks:
                  return False
              st.info(f"🔄 Creating embeddings for {len(chunks)} chunks...")
-             embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
              self.documents = chunks
-             self.embeddings = embeddings
              st.success(f"✅ Successfully processed PDF: {len(chunks)} chunks created with embeddings")
              return True
          except Exception as e:
              st.error(f"❌ Error processing PDF: {e}")
              logger.error(f"PDF processing error: {e}")
@@ -118,12 +187,26 @@ class SimplePDFRAG:
          if not self.documents or len(self.embeddings) == 0:
              st.warning("No documents available for search")
              return []
          try:
              query_embedding = self.embedding_model.encode([query])
              similarities = cosine_similarity(query_embedding, self.embeddings)[0]
-             top_indices = np.argsort(similarities)[-top_k:][::-1]
              return [{'text': self.documents[i], 'score': similarities[i]}
-                     for i in top_indices if similarities[i] > 0.1]
          except Exception as e:
              st.error(f"Error searching documents: {e}")
              logger.error(f"Search error: {e}")
@@ -132,8 +215,13 @@ class SimplePDFRAG:
      def generate_answer(self, query, context_docs):
          if not self.granite_model or not context_docs:
              return "I don't have enough information to answer your question."
-         context = "\n\n".join([doc['text'][:200] for doc in context_docs])
-         prompt = f"""You are a helpful AI assistant. Based on the following context, provide a clear and accurate answer to the question.

  Context:
  {context}
@@ -141,39 +229,76 @@ Context:
  Question: {query}

  Answer:"""
          try:
-             inputs = self.tokenizer.encode(prompt, return_tensors='pt', max_length=512, truncation=True)
              with torch.no_grad():
                  outputs = self.granite_model.generate(
                      inputs,
-                     max_length=inputs.shape[1] + 100,
                      temperature=0.7,
                      do_sample=True,
                      pad_token_id=self.tokenizer.eos_token_id,
-                     eos_token_id=self.tokenizer.eos_token_id
                  )
-             response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
-             return response.strip() if len(response.strip()) >= 10 else context[:300] + "..."
          except Exception as e:
              logger.error(f"Generation error: {e}")
-             return context[:300] + "..."

      def answer_question(self, query):
          if not self.documents:
              return {'answer': "No PDF has been processed yet.", 'sources': []}
          relevant_docs = self.search_documents(query)
          if not relevant_docs:
-             return {'answer': "No relevant information found.", 'sources': []}
          return {
-             'answer': self.generate_answer(query, relevant_docs),
              'sources': relevant_docs
          }

  def main():
-     st.set_page_config(page_title="Simple PDF RAG with IBM Granite (Fixed)", page_icon="📄", layout="wide")
-     st.title("📄 Simple PDF RAG with IBM Granite (Fixed)")
-     st.write("Upload a PDF and ask questions about its content")

      if 'rag_system' not in st.session_state:
          st.session_state.rag_system = SimplePDFRAG()
      if 'models_loaded' not in st.session_state:
@@ -185,72 +310,141 @@ def main():
      if 'uploaded_file_path' not in st.session_state:
          st.session_state.uploaded_file_path = None

      col1, col2, col3 = st.columns(3)
      with col1:
-         st.success("🤖 Models: Loaded" if st.session_state.models_loaded else "🤖 Models: Not Loaded")
      with col2:
-         st.success(f"📄 PDF: {st.session_state.current_pdf_name}" if st.session_state.pdf_processed else "📄 PDF: Not Processed")
      with col3:
-         st.success("🟢 Ready" if st.session_state.models_loaded and st.session_state.pdf_processed else "🔴 Not Ready")

      if not st.session_state.models_loaded:
-         if st.button("🤖 Load Models"):
-             with st.spinner("Loading models..."):
                  success = st.session_state.rag_system.load_models()
                  st.session_state.models_loaded = success
              st.rerun()

      if st.session_state.models_loaded:
          st.markdown("---")
          st.subheader("📁 PDF Upload and Processing")
-         uploaded_file = st.file_uploader("Upload PDF", type=["pdf"], key="pdf_uploader")

          if uploaded_file:
              with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                  tmp.write(uploaded_file.read())
                  st.session_state.uploaded_file_path = tmp.name
                  st.session_state.uploaded_file_name = uploaded_file.name
                  st.session_state.pdf_processed = False
                  st.session_state.current_pdf_name = None
              st.success(f"📄 Uploaded: {uploaded_file.name}")

          if st.session_state.uploaded_file_path and not st.session_state.pdf_processed:
-             if st.button("📖 Process PDF"):
-                 with st.spinner("Processing PDF..."):
-                     with open(st.session_state.uploaded_file_path, "rb") as f:
-                         success = st.session_state.rag_system.process_pdf(f, st.session_state.uploaded_file_name)
-                     if success:
-                         st.session_state.pdf_processed = True
-                         st.session_state.current_pdf_name = st.session_state.uploaded_file_name
-                         st.success("✅ PDF processed!")
-                         st.rerun()

      if st.session_state.models_loaded and st.session_state.pdf_processed:
          st.markdown("---")
          st.subheader("❓ Ask Questions")
-         st.info(f"📚 Current document: {st.session_state.current_pdf_name}")
-         query = st.text_input("Ask a question:", placeholder="e.g., What is the main topic?")
-         if query and st.button("🔍 Get Answer"):
-             with st.spinner("Searching and generating answer..."):
                  result = st.session_state.rag_system.answer_question(query)
                  st.markdown("### 🤖 Answer:")
                  st.write(result['answer'])
                  if result.get('sources'):
                      st.markdown("### 📚 Sources:")
                      for i, src in enumerate(result['sources']):
-                         with st.expander(f"Source {i+1} (Score: {src['score']:.3f})"):
                              st.write(src['text'][:500] + "..." if len(src['text']) > 500 else src['text'])

      with st.sidebar:
-         st.header("📋 Instructions")
-         st.markdown("1. Load Models\n2. Upload PDF\n3. Process PDF\n4. Ask Questions")
-         st.header("🔧 Debug Info")
-         st.write("✅ Models loaded" if st.session_state.models_loaded else "❌ Models not loaded")
-         st.write(f"✅ PDF: {st.session_state.current_pdf_name}" if st.session_state.pdf_processed else "❌ No PDF processed")
-         if st.button("🔄 Reset All"):
              for key in list(st.session_state.keys()):
                  del st.session_state[key]
              st.rerun()

  if __name__ == "__main__":
-     main()
 
+ # Improved SimplePDFRAG with better error handling and model optimization
  import streamlit as st
  import PyPDF2
  from sentence_transformers import SentenceTransformer
 
  import logging
  import os
  import tempfile
+ import gc

  # Configure logging
  logging.basicConfig(level=logging.INFO)
 
          self.granite_model = None
          self.tokenizer = None
          self.pdf_name = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      def setup_cache_directory(self):
          try:
 
              os.environ['TRANSFORMERS_CACHE'] = cache_dir
              os.environ['SENTENCE_TRANSFORMERS_HOME'] = cache_dir
              st.info(f"Using cache directory: {cache_dir}")
+             st.info(f"Using device: {self.device}")
              return cache_dir
          except Exception as e:
              st.error(f"Error setting up cache directory: {e}")
 
              cache_dir = self.setup_cache_directory()
              st.info("Loading embedding model...")
              self.embedding_model = SentenceTransformer(
+                 'all-MiniLM-L6-v2', cache_folder=cache_dir, device=self.device
              )
+
              st.info("Loading IBM Granite model...")
+             # Alternative models you could try:
+             # model_name = "ibm-granite/granite-3-8b-instruct"  # Larger, better performance
+             # model_name = "microsoft/DialoGPT-medium"
+             # model_name = "google/flan-t5-base"
+             model_name = "ibm-granite/granite-3-2b-instruct"
+
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 model_name,
+                 cache_dir=cache_dir,
+                 trust_remote_code=True
              )
+
+             # Optimize model loading based on available resources
+             model_kwargs = {
+                 "cache_dir": cache_dir,
+                 "trust_remote_code": True,
+                 "low_cpu_mem_usage": True,
+             }
+
+             # Use appropriate dtype based on device
+             if self.device.type == "cuda":
+                 model_kwargs["torch_dtype"] = torch.float16
+             else:
+                 model_kwargs["torch_dtype"] = torch.float32
+
+             self.granite_model = AutoModelForCausalLM.from_pretrained(
+                 model_name, **model_kwargs
+             ).to(self.device)
+
+             # Set pad token if not available
              if self.tokenizer.pad_token is None:
                  self.tokenizer.pad_token = self.tokenizer.eos_token
+
              st.success("Models loaded successfully!")
              return True
+
          except Exception as e:
              st.error(f"Error loading models: {e}")
              logger.error(f"Model loading error: {e}")
 
              pdf_reader = PyPDF2.PdfReader(pdf_file)
              text = ""
              st.info(f"PDF has {len(pdf_reader.pages)} pages")
+
+             progress_bar = st.progress(0)
              for page_num, page in enumerate(pdf_reader.pages):
                  try:
                      page_text = page.extract_text()

                          st.warning(f"⚠️ No text found on page {page_num + 1}")
                  except Exception as page_error:
                      st.error(f"Error extracting page {page_num + 1}: {page_error}")
+
+                 # Update progress
+                 progress_bar.progress((page_num + 1) / len(pdf_reader.pages))
+
+             progress_bar.empty()
+
              if text.strip():
+                 st.success(f"Extracted {len(text)} characters from {len(pdf_reader.pages)} pages")
                  st.write("📄 **Text Preview:**")
                  st.text(text[:500] + "..." if len(text) > 500 else text)
                  return text
              else:
                  st.error("No text could be extracted from the PDF")
                  return None
+
          except Exception as e:
              st.error(f"Error reading PDF file: {e}")
              logger.error(f"PDF extraction error: {e}")
              return None
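For reference, the per-page extraction loop above reduces to this standalone sketch ("example.pdf" is a placeholder path, not from the commit):

import PyPDF2

with open("example.pdf", "rb") as f:  # placeholder path
    reader = PyPDF2.PdfReader(f)
    # extract_text() can return None or "" on image-only pages
    text = "".join(page.extract_text() or "" for page in reader.pages)
print(f"{len(reader.pages)} pages, {len(text)} characters")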

+     def chunk_text(self, text, chunk_size=400, overlap=50):
+         """Improved chunking with overlap for better context preservation"""
          if not text or not text.strip():
              return []
+
          words = text.split()
+         chunks = []
+
+         for i in range(0, len(words), chunk_size - overlap):
+             chunk = " ".join(words[i:i + chunk_size])
+             if chunk.strip():  # Only add non-empty chunks
+                 chunks.append(chunk)
+
+         return chunks
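To see how the new overlap parameter behaves, here is a toy run of the same loop (words and sizes are illustrative, not from the commit):

words = "one two three four five six seven eight nine ten".split()
chunk_size, overlap = 4, 1

chunks = []
# Step by chunk_size - overlap so each chunk repeats the last `overlap` words
for i in range(0, len(words), chunk_size - overlap):
    chunk = " ".join(words[i:i + chunk_size])
    if chunk.strip():
        chunks.append(chunk)

print(chunks)
# ['one two three four', 'four five six seven', 'seven eight nine ten', 'ten']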

      def process_pdf(self, pdf_file, pdf_name):
          try:
              self.pdf_name = pdf_name
              st.info("🔍 Extracting text from PDF...")
              text = self.extract_pdf_text(pdf_file)
+
              if not text:
                  return False
+
+             st.info("✂️ Splitting text into chunks with overlap...")
              chunks = self.chunk_text(text)
+
              if not chunks:
+                 st.error("No valid text chunks created")
                  return False
+
              st.info(f"🔄 Creating embeddings for {len(chunks)} chunks...")
+
+             # Create embeddings in batches to manage memory
+             batch_size = 32
+             embeddings = []
+
+             progress_bar = st.progress(0)
+             for i in range(0, len(chunks), batch_size):
+                 batch = chunks[i:i + batch_size]
+                 batch_embeddings = self.embedding_model.encode(batch, show_progress_bar=False)
+                 embeddings.extend(batch_embeddings)
+                 progress_bar.progress(min(i + batch_size, len(chunks)) / len(chunks))
+
+             progress_bar.empty()
+
              self.documents = chunks
+             self.embeddings = np.array(embeddings)
+
              st.success(f"✅ Successfully processed PDF: {len(chunks)} chunks created with embeddings")
              return True
+
          except Exception as e:
              st.error(f"❌ Error processing PDF: {e}")
              logger.error(f"PDF processing error: {e}")
 
          if not self.documents or len(self.embeddings) == 0:
              st.warning("No documents available for search")
              return []
+
          try:
              query_embedding = self.embedding_model.encode([query])
              similarities = cosine_similarity(query_embedding, self.embeddings)[0]
+
+             # Filter out very low similarity scores
+             min_threshold = 0.1
+             valid_indices = np.where(similarities > min_threshold)[0]
+
+             if len(valid_indices) == 0:
+                 return []
+
+             # Get top k from valid indices
+             valid_similarities = similarities[valid_indices]
+             top_valid_indices = np.argsort(valid_similarities)[-top_k:][::-1]
+             top_indices = valid_indices[top_valid_indices]
+
              return [{'text': self.documents[i], 'score': similarities[i]}
+                     for i in top_indices]
+
          except Exception as e:
              st.error(f"Error searching documents: {e}")
              logger.error(f"Search error: {e}")
 
      def generate_answer(self, query, context_docs):
          if not self.granite_model or not context_docs:
              return "I don't have enough information to answer your question."
+
+         # Create better context from top documents
+         context = "\n\n".join([f"Context {i+1}: {doc['text'][:300]}"
+                                for i, doc in enumerate(context_docs[:2])])  # Use top 2 docs
+
+         # Improved prompt formatting
+         prompt = f"""Based on the following context, provide a clear and accurate answer to the question. If the context doesn't contain enough information, say so.

  Context:
  {context}

  Question: {query}

  Answer:"""
+
          try:
+             # Tokenize with proper attention to length
+             inputs = self.tokenizer.encode(
+                 prompt,
+                 return_tensors='pt',
+                 max_length=1024,
+                 truncation=True
+             ).to(self.device)
+
              with torch.no_grad():
                  outputs = self.granite_model.generate(
                      inputs,
+                     max_new_tokens=150,  # Use max_new_tokens instead of max_length
                      temperature=0.7,
                      do_sample=True,
                      pad_token_id=self.tokenizer.eos_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     repetition_penalty=1.2,
+                     top_p=0.9
                  )
+
+             # Decode only the new tokens
+             response = self.tokenizer.decode(
+                 outputs[0][inputs.shape[1]:],
+                 skip_special_tokens=True
+             )
+
+             # Clean up the response
+             response = response.strip()
+             if len(response) < 10:
+                 return f"Based on the provided context: {context[:200]}..."
+
+             return response
+
          except Exception as e:
              logger.error(f"Generation error: {e}")
+             return f"Error generating response. Here's what I found: {context[:200]}..."
+         finally:
+             # Clean up GPU memory
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
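The switch to max_new_tokens matters because max_length counts the prompt too, so a long prompt could leave no budget for the answer. A runnable sketch of the same generate-then-slice pattern, using gpt2 as a lightweight stand-in (the stand-in choice is ours, not the commit's):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")          # stand-in for the Granite model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok.encode("Question: What is RAG?\nAnswer:", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        inputs,
        max_new_tokens=40,        # budget applies to new tokens only
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tok.eos_token_id,
    )
# Slice off the echoed prompt so only the generated continuation is decoded
print(tok.decode(out[0][inputs.shape[1]:], skip_special_tokens=True))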

      def answer_question(self, query):
          if not self.documents:
              return {'answer': "No PDF has been processed yet.", 'sources': []}
+
          relevant_docs = self.search_documents(query)
+
          if not relevant_docs:
+             return {'answer': "No relevant information found in the document for your question.", 'sources': []}
+
+         answer = self.generate_answer(query, relevant_docs)
+
          return {
+             'answer': answer,
              'sources': relevant_docs
          }
 
291
  def main():
292
+ st.set_page_config(
293
+ page_title="PDF RAG with IBM Granite",
294
+ page_icon="πŸ“„",
295
+ layout="wide"
296
+ )
297
+
298
+ st.title("πŸ“„ PDF RAG with IBM Granite")
299
+ st.write("Upload a PDF and ask questions about its content using AI")
300
 
301
+ # Initialize session state
302
  if 'rag_system' not in st.session_state:
303
  st.session_state.rag_system = SimplePDFRAG()
304
  if 'models_loaded' not in st.session_state:
 
310
  if 'uploaded_file_path' not in st.session_state:
311
  st.session_state.uploaded_file_path = None
312
 
313
+ # Status indicators
314
  col1, col2, col3 = st.columns(3)
315
  with col1:
316
+ if st.session_state.models_loaded:
317
+ st.success("πŸ€– Models: Loaded")
318
+ else:
319
+ st.error("πŸ€– Models: Not Loaded")
320
+
321
  with col2:
322
+ if st.session_state.pdf_processed:
323
+ st.success(f"πŸ“„ PDF: {st.session_state.current_pdf_name}")
324
+ else:
325
+ st.error("πŸ“„ PDF: Not Processed")
326
+
327
  with col3:
328
+ if st.session_state.models_loaded and st.session_state.pdf_processed:
329
+ st.success("🟒 Ready")
330
+ else:
331
+ st.error("πŸ”΄ Not Ready")
332
 
333
+ # Model loading section
334
  if not st.session_state.models_loaded:
335
+ st.markdown("---")
336
+ st.subheader("πŸ€– Model Loading")
337
+ st.info("Click below to load the AI models. This may take a few minutes.")
338
+
339
+ if st.button("πŸ€– Load Models", type="primary"):
340
+ with st.spinner("Loading models... This may take a few minutes."):
341
  success = st.session_state.rag_system.load_models()
342
  st.session_state.models_loaded = success
343
+ if success:
344
+ st.balloons()
345
  st.rerun()
346
 
347
+ # PDF processing section
348
  if st.session_state.models_loaded:
349
  st.markdown("---")
350
  st.subheader("πŸ“ PDF Upload and Processing")
351
+
352
+ uploaded_file = st.file_uploader(
353
+ "Upload PDF",
354
+ type=["pdf"],
355
+ key="pdf_uploader",
356
+ help="Upload a PDF file to analyze and ask questions about"
357
+ )
358
 
359
  if uploaded_file:
360
+ # Save uploaded file
361
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
362
  tmp.write(uploaded_file.read())
363
  st.session_state.uploaded_file_path = tmp.name
364
  st.session_state.uploaded_file_name = uploaded_file.name
365
  st.session_state.pdf_processed = False
366
  st.session_state.current_pdf_name = None
367
+
368
  st.success(f"πŸ“„ Uploaded: {uploaded_file.name}")
369
 
370
  if st.session_state.uploaded_file_path and not st.session_state.pdf_processed:
371
+ if st.button("πŸ“– Process PDF", type="primary"):
372
+ with st.spinner("Processing PDF... This may take a moment."):
373
+ try:
374
+ with open(st.session_state.uploaded_file_path, "rb") as f:
375
+ success = st.session_state.rag_system.process_pdf(
376
+ f, st.session_state.uploaded_file_name
377
+ )
378
+
379
+ if success:
380
+ st.session_state.pdf_processed = True
381
+ st.session_state.current_pdf_name = st.session_state.uploaded_file_name
382
+ st.success("βœ… PDF processed successfully!")
383
+ st.balloons()
384
+ st.rerun()
385
+ else:
386
+ st.error("❌ Failed to process PDF")
387
+
388
+ except Exception as e:
389
+ st.error(f"❌ Error processing PDF: {e}")
390
 
391
+ # Q&A section
392
  if st.session_state.models_loaded and st.session_state.pdf_processed:
393
  st.markdown("---")
394
  st.subheader("❓ Ask Questions")
395
+ st.info(f"πŸ“š Current document: **{st.session_state.current_pdf_name}**")
396
+
397
+ query = st.text_input(
398
+ "Ask a question about your PDF:",
399
+ placeholder="What is the main topic discussed in this document?",
400
+ help="Ask specific questions about the content in your PDF"
401
+ )
402
+
403
+ if query and st.button("πŸ” Get Answer", type="primary"):
404
+ with st.spinner("Searching document and generating answer..."):
405
  result = st.session_state.rag_system.answer_question(query)
406
+
407
  st.markdown("### πŸ€– Answer:")
408
  st.write(result['answer'])
409
+
410
  if result.get('sources'):
411
  st.markdown("### πŸ“š Sources:")
412
  for i, src in enumerate(result['sources']):
413
+ with st.expander(f"Source {i+1} (Relevance: {src['score']:.3f})"):
414
  st.write(src['text'][:500] + "..." if len(src['text']) > 500 else src['text'])
415
 
416
+ # Sidebar
417
  with st.sidebar:
418
+ st.header("πŸ“‹ How to Use")
419
+ st.markdown("""
420
+ 1. **Load Models** - Click to download and load AI models
421
+ 2. **Upload PDF** - Select your PDF file
422
+ 3. **Process PDF** - Extract and analyze the text
423
+ 4. **Ask Questions** - Query your document
424
+ """)
425
+
426
+ st.header("πŸ’‘ Tips")
427
+ st.markdown("""
428
+ - Ask specific questions for better results
429
+ - Try different phrasings if unsatisfied
430
+ - The AI uses context from your document
431
+ """)
432
+
433
+ st.header("πŸ”§ System Info")
434
+ device_info = "GPU" if torch.cuda.is_available() else "CPU"
435
+ st.write(f"**Device:** {device_info}")
436
+ st.write(f"**Models:** {'��� Loaded' if st.session_state.models_loaded else '❌ Not loaded'}")
437
+ st.write(f"**PDF:** {'βœ… Processed' if st.session_state.pdf_processed else '❌ Not processed'}")
438
+
439
+ if st.button("πŸ”„ Reset Everything"):
440
+ # Clear all session state
441
  for key in list(st.session_state.keys()):
442
  del st.session_state[key]
443
+ # Force garbage collection
444
+ gc.collect()
445
+ if torch.cuda.is_available():
446
+ torch.cuda.empty_cache()
447
  st.rerun()
448
 
449
  if __name__ == "__main__":
450
+ main()