msmaje committed on
Commit 1ae38bf · verified · 1 Parent(s): 37b704a

Update app.py

Files changed (1)
  1. app.py +386 -381
app.py CHANGED
@@ -1,30 +1,30 @@
 """
-Qwen2.5 PDF RAG System with ChromaDB, LangChain and Gradio
-This script implements a Retrieval-Augmented Generation system for PDF documents
-using Qwen2.5 models, ChromaDB for vector storage, and LangChain for the RAG pipeline.
-The user interface is built with Gradio.
 """

 import os
 import time
-import argparse
 import gradio as gr
 from typing import List, Dict, Any, Tuple

-# LangChain imports
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document

-# LMDeploy for Qwen2.5 models
-from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig

 class PDFRagSystem:
-    """PDF RAG System using Qwen2.5, ChromaDB, and LangChain"""

-    def __init__(self, model_name: str, persist_directory: str = "db"):
         """
         Initialize the RAG system

@@ -35,20 +35,48 @@ class PDFRagSystem:
         self.model_name = model_name
         self.persist_directory = persist_directory
         self.pipe = None
         self.vectorstore = None
         self.embeddings = None
-        self.top_sources = []  # Store top sources for each query

         # Initialize embedding model
-        self.embeddings = HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-MiniLM-L6-v2",
-            model_kwargs={"device": "cuda"},
-            encode_kwargs={"normalize_embeddings": True}
-        )

         # Load LLM
         self._load_llm()

     def change_model(self, model_name: str) -> str:
         """
         Change the LLM model
@@ -62,31 +90,75 @@ class PDFRagSystem:
         if self.model_name == model_name:
             return f"Already using model: {model_name}"

-        # Update model name
         self.model_name = model_name

-        # Reload LLM
         try:
             self._load_llm()
             return f"Successfully switched to model: {model_name}"
         except Exception as e:
             return f"Error switching model: {str(e)}"

     def _load_llm(self):
-        """Load the Qwen2.5 model with optimized settings"""
         print(f"\nLoading {self.model_name} model...")
         start_time = time.time()

-        # Configure engine for memory optimization
-        engine_config = TurbomindEngineConfig(
-            cache_max_entry_count=0.5,  # Use 50% of free GPU memory for KV cache
-            session_len=4096  # Reduce context length if memory is limited
-        )
-
-        # Create the pipeline
-        self.pipe = pipeline(self.model_name, backend_config=engine_config)
-        load_time = time.time() - start_time
-        print(f"Model loaded in {load_time:.2f} seconds")

     def process_pdf(self, pdf_file: str) -> List[Document]:
         """
@@ -98,28 +170,38 @@ class PDFRagSystem:
         Returns:
             List of document chunks
         """
-        # Load PDF
-        loader = PyPDFLoader(pdf_file)
-        documents = loader.load()
-
-        # Split documents into chunks
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200,
-            separators=["\n\n", "\n", " ", ""]
-        )
-
-        chunks = text_splitter.split_documents(documents)
-        return chunks

-    def create_vectorstore(self, pdf_files: List[str]) -> None:
         """
         Create or update the vector store with documents from PDF files

         Args:
             pdf_files: List of paths to PDF files
         """
         all_chunks = []

         for pdf_file in pdf_files:
             if not os.path.exists(pdf_file):
@@ -128,31 +210,42 @@ class PDFRagSystem:

             print(f"Processing {pdf_file}...")
             chunks = self.process_pdf(pdf_file)
-            print(f"Created {len(chunks)} chunks from {pdf_file}")
-            all_chunks.extend(chunks)
-
-        # Create or update vectorstore
-        if os.path.exists(self.persist_directory) and len(os.listdir(self.persist_directory)) > 0:
-            print("Loading existing vectorstore...")
-            self.vectorstore = Chroma(
-                persist_directory=self.persist_directory,
-                embedding_function=self.embeddings
-            )
-            print("Adding new documents to existing vectorstore...")
-            self.vectorstore.add_documents(all_chunks)
-        else:
-            print("Creating new vectorstore...")
-            self.vectorstore = Chroma.from_documents(
-                documents=all_chunks,
-                embedding=self.embeddings,
-                persist_directory=self.persist_directory
-            )

-        # Persist to disk
-        self.vectorstore.persist()
-        print(f"Vectorstore created with {len(all_chunks)} chunks and persisted to {self.persist_directory}")

-    def retrieve_context(self, query: str, k: int = 5) -> Tuple[str, List[Dict]]:
         """
         Retrieve relevant context for a query

@@ -166,55 +259,41 @@ class PDFRagSystem:
         if not self.vectorstore:
             return "", []

-        # Search for relevant documents
-        docs_with_scores = self.vectorstore.similarity_search_with_score(query, k=k)
-
-        # Format context
-        context_parts = []
-        sources = []
-
-        for i, (doc, score) in enumerate(docs_with_scores):
-            # Format document content with score
-            context_part = f"Document {i+1} [Relevance: {score:.2f}]:\n{doc.page_content}\n"
-            context_parts.append(context_part)

-            # Convert any complex objects to simple types for serialization
-            try:
-                # Make a clean copy of metadata with only string keys and simple values
                 clean_metadata = {}
                 for key, value in doc.metadata.items():
-                    # Convert key to string
                     str_key = str(key)
-                    # Convert value to a simple type
                     if isinstance(value, (str, int, float, bool, type(None))):
                         clean_metadata[str_key] = value
                     else:
                         clean_metadata[str_key] = str(value)

-                # Prepare source info with clean metadata
                 source_info = {
                     "content": str(doc.page_content),
                     "metadata": clean_metadata,
                     "score": float(score),
                     "source_id": i+1
                 }
-            except Exception as e:
-                # Fallback if there's an error creating the source info
-                source_info = {
-                    "content": str(doc.page_content)[:1000],  # Limit length if it's problematic
-                    "metadata": {"error": f"Error processing metadata: {str(e)}"},
-                    "score": float(score),
-                    "source_id": i+1
-                }

-            sources.append(source_info)
-
-        # Store sources for display in UI
-        self.top_sources = sources
-
-        # Combine all context
-        context = "\n".join(context_parts)
-        return context, sources

     def generate_response(self, query: str, system_prompt: str = "You are a helpful assistant that answers questions based on the provided documents.") -> str:
         """
@@ -234,180 +313,131 @@ class PDFRagSystem:
             return "No relevant documents found in the database. Please upload some PDF files first."

         # Create RAG prompt
-        rag_prompt = f"""Please answer the following question based only on the provided context. If the context doesn't contain relevant information, say that you don't know.

 Context:
 {context}

-Question: {query}"""
-
-        # Configure generation parameters
-        gen_config = GenerationConfig(
-            max_new_tokens=512,
-            temperature=0.7,
-            top_p=0.9,
-            top_k=40
-        )
-
-        # Format in chat-style
-        chat_prompt = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": rag_prompt}
-        ]
-
-        # Generate response
-        print(f"Running inference for query: {query}")
-        start_time = time.time()
-        response = self.pipe(chat_prompt, gen_config=gen_config)
-        inference_time = time.time() - start_time

-        # Extract text from response
-        if hasattr(response, 'text'):
-            result = response.text
-        else:
-            result = str(response)

-        print(f"Inference completed in {inference_time:.2f} seconds")
-        return result

     def get_top_sources(self) -> List[Dict]:
         """Get the top sources used for the last query"""
         return self.top_sources


-# Gradio UI Implementation
 class RagUI:
-    """Gradio UI for the PDF RAG System"""

     def __init__(self, rag_system: PDFRagSystem):
-        """
-        Initialize the UI
-
-        Args:
-            rag_system: The RAG system to use
-        """
         self.rag_system = rag_system
         self.interface = None

-        # Define model mapping
         self.models = {
-            "Qwen2.5-7B": "Qwen/Qwen2.5-7B-Instruct-1M",
-            "Qwen2.5-3B": "Qwen/Qwen2.5-3B-Instruct",
-            "Qwen2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct"
         }

-        # Get the current model's display name
-        self.current_model = next(
-            (k for k, v in self.models.items() if v == self.rag_system.model_name),
-            "Qwen2.5-3B"  # Default fallback
-        )

-    def _upload_files(self, files: List[str]) -> str:
-        """
-        Handle file upload
-
-        Args:
-            files: List of file paths
-
-        Returns:
-            Status message
-        """
         if not files:
             return "No files selected."

         try:
-            self.rag_system.create_vectorstore([f.name for f in files])
-            return f"Successfully processed {len(files)} PDFs."
         except Exception as e:
             return f"Error processing files: {str(e)}"

     def _switch_model(self, model_name: str) -> str:
-        """
-        Switch the model
-
-        Args:
-            model_name: Name of model to switch to (display name)
-
-        Returns:
-            Status message
-        """
         if model_name not in self.models:
             return f"Unknown model: {model_name}"

-        # Get the full model name
         full_model_name = self.models[model_name]
-
-        # Update the current model
         self.current_model = model_name

-        # Switch the model in the RAG system
         return self.rag_system.change_model(full_model_name)

-    def _query(self, query: str, system_prompt: str) -> Tuple[str, List[Dict]]:
-        """
-        Process a query
-
-        Args:
-            query: User question
-            system_prompt: System prompt to set assistant behavior
-
-        Returns:
-            Tuple of (response text, sources)
-        """
         if not query.strip():
-            return "Please enter a question.", []

         response = self.rag_system.generate_response(query, system_prompt)
         sources = self.rag_system.get_top_sources()

-        return response, sources

     def _format_source_display(self, sources: List[Dict]) -> str:
-        """
-        Format sources for display
-
-        Args:
-            sources: List of source documents
-
-        Returns:
-            Formatted HTML for display
-        """
         if not sources:
             return "<div class='source-container'>No sources available.</div>"

         html = "<div class='source-container'>"

-        # Make sure we're working with actual dictionaries
         for i, source in enumerate(sources):
             try:
-                # Handle case where source might not be properly formed
                 if not isinstance(source, dict):
                     continue

-                # Extract metadata safely
                 metadata = source.get("metadata", {})
                 if not isinstance(metadata, dict):
                     metadata = {}

                 page_num = metadata.get("page", "Unknown")
                 source_file = metadata.get("source", "Unknown")
-                content = source.get("content", "No content available")
                 score = source.get("score", 0.0)
                 source_id = source.get("source_id", i+1)

-                # Determine relevance class based on score
-                if score >= 0.8:
                     relevance_class = "relevance-high"
-                elif score >= 0.6:
                     relevance_class = "relevance-medium"
                 else:
                     relevance_class = "relevance-low"

-                # Format as a card with our CSS classes
                 html += f"""
                 <div class="source-card">
                     <div class="source-header">
-                        Source {source_id} (<span class="{relevance_class}">Relevance: {score:.2f}</span>)
                     </div>
                     <div class="source-meta">
                         <strong>File:</strong> {os.path.basename(str(source_file))}
@@ -421,7 +451,6 @@ class RagUI:
                     </div>
                 """
             except Exception as e:
-                # Handle any formatting errors
                 html += f'<div class="source-card">Error displaying source {i+1}: {str(e)}</div>'

         html += "</div>"
@@ -429,168 +458,169 @@ class RagUI:

     def build_interface(self):
         """Build the Gradio interface"""
-        with gr.Blocks(title="Qwen2.5 PDF RAG System") as interface:
-            gr.Markdown("# Qwen2.5 PDF RAG System")
-            gr.Markdown("Upload PDF files, then ask questions about their content.")
-
-            # Model selection section at the top
-            with gr.Row():
-                with gr.Column(scale=1):
-                    gr.Markdown("### Model Selection")
-                    model_dropdown = gr.Dropdown(
-                        choices=list(self.models.keys()),
-                        value=self.current_model,
-                        label="Select Qwen2.5 Model",
-                        info="Larger models are more accurate but slower"
-                    )
-                    model_status = gr.Textbox(
-                        label="Model Status",
-                        value=f"Currently using: {self.current_model}",
-                        interactive=False
-                    )
-                    model_switch_btn = gr.Button("Switch Model", variant="secondary")
-
-            with gr.Tab("Upload & Query"):
                 with gr.Row():
                     with gr.Column(scale=1):
-                        # File upload section
-                        gr.Markdown("### Upload PDFs")
                         file_input = gr.File(
                             file_count="multiple",
-                            label="Upload PDF Files"
                         )
-                        upload_button = gr.Button("Process PDFs", variant="primary")
-                        upload_status = gr.Textbox(label="Upload Status", interactive=False)
-
-                        # System prompt
-                        system_prompt = gr.Textbox(
-                            label="System Prompt",
-                            value="You are a helpful assistant that answers questions based only on the provided documents. You must cite your sources.",
-                            lines=2
                         )

                     with gr.Column(scale=2):
-                        # Query section
-                        gr.Markdown("### Ask Questions")
                         query_input = gr.Textbox(
                             label="Your Question",
-                            placeholder="Ask a question about the uploaded PDFs...",
                             lines=2
                         )
-                        query_button = gr.Button("Ask", variant="primary")
                         answer_output = gr.Textbox(
                             label="Answer",
                             interactive=False,
-                            lines=10
                         )

-            # Source Documents Tab
-            with gr.Tab("Reference Sources"):
-                gr.Markdown("### Sources Used for Answer")
-                gr.Markdown("This tab shows the top document chunks that were used to generate the answer.")

-                # Add some styling to make the display more user-friendly
-                gr.HTML("""
-                <style>
-                .source-container {
-                    max-height: 800px;
-                    overflow-y: auto;
-                    padding: 10px;
-                }
-                .source-card {
-                    margin-bottom: 20px;
-                    padding: 15px;
-                    border: 1px solid #ddd;
-                    border-radius: 5px;
-                    background-color: #fff;
-                    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
-                }
-                .source-header {
-                    font-size: 18px;
-                    font-weight: bold;
-                    margin-bottom: 10px;
-                    color: #333;
-                }
-                .source-meta {
-                    color: #666;
-                    margin-bottom: 8px;
-                }
-                .source-content {
-                    background-color: #f9f9f9;
-                    padding: 12px;
-                    border-radius: 4px;
-                    border-left: 3px solid #2c7be5;
-                    font-family: monospace;
-                    white-space: pre-wrap;
-                    overflow-x: auto;
-                }
-                .relevance-high {
-                    color: #1e7e34;
-                }
-                .relevance-medium {
-                    color: #1f75cb;
-                }
-                .relevance-low {
-                    color: #6c757d;
-                }
-                </style>
-                """)
-
-                sources_display = gr.HTML(label="Sources")

-            # System Info Tab
-            with gr.Tab("System Info"):
-                gr.Markdown("### System Information")
                 gr.Markdown("""
-                This PDF RAG (Retrieval-Augmented Generation) system uses:

-                - **Qwen2.5 Models** for text generation
-                - **ChromaDB** for vector storage and similarity search
-                - **LangChain** for the RAG pipeline

-                #### Available Models:

-                1. **Qwen2.5-1.5B** - Fastest, smallest model for simple queries (1.5 billion parameters)
-                2. **Qwen2.5-3B** - Good balance of speed and quality (3 billion parameters)
-                3. **Qwen2.5-7B** - Most accurate model for complex questions (7 billion parameters)

-                #### Memory Usage:

-                - The 1.5B model requires approximately 3GB of VRAM
-                - The 3B model requires approximately 6GB of VRAM
-                - The 7B model requires approximately 14GB of VRAM

-                Model switching happens in real-time and takes a few seconds.
                 """)

-            # Set up events
             upload_button.click(
                 fn=self._upload_files,
                 inputs=[file_input],
                 outputs=[upload_status]
             )

-            # Define a wrapper function that returns formatted HTML directly
-            def query_and_format(query, system_prompt):
-                response, sources = self._query(query, system_prompt)
-                sources_html = self._format_source_display(sources)
-                return response, sources_html
-
-            # Use the wrapper function for query events
             query_button.click(
-                fn=query_and_format,
                 inputs=[query_input, system_prompt],
                 outputs=[answer_output, sources_display]
             )

-            # Also trigger query on pressing Enter in the query input
             query_input.submit(
-                fn=query_and_format,
                 inputs=[query_input, system_prompt],
                 outputs=[answer_output, sources_display]
             )

-            # Model switching event
             model_switch_btn.click(
                 fn=self._switch_model,
                 inputs=[model_dropdown],
@@ -604,68 +634,43 @@ class RagUI:
         """Launch the Gradio interface"""
         if not self.interface:
             self.build_interface()
-
-        self.interface.launch(**kwargs)


 def main():
-    """Main function to run the application"""
-    # Set up argument parser
-    parser = argparse.ArgumentParser(description="Run Qwen2.5 PDF RAG System")
-
-    # Model selection argument
-    parser.add_argument(
-        "--model",
-        type=str,
-        choices=["7b", "3b", "1.5b"],
-        default="3b",
-        help="Model size to use: 7b, 3b, or 1.5b"
-    )
-
-    # Database directory
-    parser.add_argument(
-        "--db_dir",
-        type=str,
-        default="chroma_db",
-        help="Directory to store the vector database"
-    )

-    # Gradio server settings
-    parser.add_argument(
-        "--share",
-        action="store_true", default=True,
-        help="Create a shareable link"
-    )

-    parser.add_argument(
-        "--port",
-        type=int,
-        default=7860,
-        help="Port to run the Gradio server on"
-    )
-
-    # Parse arguments
-    args = parser.parse_args()
-
-    # Define model mapping
-    models = {
-        "7b": "Qwen/Qwen2.5-7B-Instruct-1M",
-        "3b": "Qwen/Qwen2.5-3B-Instruct",
-        "1.5b": "Qwen/Qwen2.5-1.5B-Instruct"
-    }
-
-    model_name = models[args.model]
-
-    print(f"Starting PDF RAG system with model: {model_name}")
-    print(f"Vector database directory: {args.db_dir}")
-
-    # Create the RAG system
-    rag_system = PDFRagSystem(model_name, args.db_dir)
-
-    # Create and launch the UI
-    ui = RagUI(rag_system)
-    ui.launch(share=args.share)
-

 if __name__ == "__main__":
-    main()
 
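For context on the central change in this commit: the removed file above drives generation through LMDeploy's Turbomind engine with chat-format messages, while the rewritten file below switches to a plain transformers text-generation pipeline over a raw prompt string. A minimal side-by-side sketch of the two call styles, using only APIs and parameters that appear in this diff (model IDs are the ones from app.py; assumes the respective packages are installed and, for the LMDeploy path, a CUDA GPU):

# Old backend (removed): LMDeploy Turbomind pipeline, chat-format input.
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig

pipe = pipeline("Qwen/Qwen2.5-3B-Instruct",
                backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5,
                                                     session_len=4096))
messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is RAG?"}]
response = pipe(messages, gen_config=GenerationConfig(max_new_tokens=512))
print(response.text if hasattr(response, "text") else response)  # mirrors app.py's extraction

# New backend (added): transformers text-generation pipeline, plain string input.
from transformers import pipeline as hf_pipeline

generator = hf_pipeline("text-generation", model="Qwen/Qwen2.5-1.5B-Instruct",
                        return_full_text=False)
print(generator("What is RAG?", max_new_tokens=300)[0]["generated_text"])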
 """
+Qwen2.5 PDF RAG System for Hugging Face Spaces
+Adapted for deployment on Hugging Face Spaces with optimizations for the cloud environment.
 """

 import os
 import time
 import gradio as gr
 from typing import List, Dict, Any, Tuple
+import torch

+# LangChain imports - updated to avoid deprecation warnings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.document_loaders import PyPDFLoader
 from langchain.schema import Document

+# Transformers for Qwen2.5 models (more compatible with HF Spaces)
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import warnings
+warnings.filterwarnings("ignore")

 class PDFRagSystem:
+    """PDF RAG System using Qwen2.5, ChromaDB, and LangChain - HF Spaces optimized"""

+    def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", persist_directory: str = "db"):
         """
         Initialize the RAG system

         self.model_name = model_name
         self.persist_directory = persist_directory
         self.pipe = None
+        self.tokenizer = None
+        self.model = None
         self.vectorstore = None
         self.embeddings = None
+        self.top_sources = []
+
+        # Check available device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")

         # Initialize embedding model
+        print("Loading embedding model...")
+        try:
+            self.embeddings = HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2",
+                model_kwargs={"device": self.device},
+                encode_kwargs={"normalize_embeddings": True}
+            )
+        except Exception as e:
+            print(f"Warning: Error loading HuggingFaceEmbeddings, trying alternative: {e}")
+            # Fallback to basic embeddings if HuggingFaceEmbeddings fails
+            from sentence_transformers import SentenceTransformer
+            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+            self.embeddings = self._create_custom_embeddings()

         # Load LLM
         self._load_llm()

+    def _create_custom_embeddings(self):
+        """Create custom embeddings wrapper if HuggingFaceEmbeddings fails"""
+        class CustomEmbeddings:
+            def __init__(self, model):
+                self.model = model
+
+            def embed_documents(self, texts):
+                return self.model.encode(texts).tolist()
+
+            def embed_query(self, text):
+                return self.model.encode([text])[0].tolist()
+
+        return CustomEmbeddings(self.embedding_model)
+
     def change_model(self, model_name: str) -> str:
         """
         Change the LLM model

         if self.model_name == model_name:
             return f"Already using model: {model_name}"

         self.model_name = model_name

         try:
+            # Clear GPU memory
+            if hasattr(self, 'model') and self.model is not None:
+                del self.model
+                del self.tokenizer
+                del self.pipe
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             self._load_llm()
             return f"Successfully switched to model: {model_name}"
         except Exception as e:
             return f"Error switching model: {str(e)}"

     def _load_llm(self):
+        """Load the Qwen2.5 model with optimized settings for HF Spaces"""
         print(f"\nLoading {self.model_name} model...")
         start_time = time.time()

+        try:
+            # Load tokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                trust_remote_code=True
+            )
+
+            # Configure model loading for limited resources
+            model_kwargs = {
+                "trust_remote_code": True,
+                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
+                "low_cpu_mem_usage": True,
+            }
+
+            if self.device == "cuda":
+                model_kwargs["device_map"] = "auto"
+
+            # Load model
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                **model_kwargs
+            )
+
+            if self.device == "cpu":
+                self.model = self.model.to(self.device)
+
+            # Create pipeline
+            self.pipe = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                device=0 if self.device == "cuda" else -1,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                return_full_text=False
+            )
+
+            load_time = time.time() - start_time
+            print(f"Model loaded in {load_time:.2f} seconds")
+
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            # Fallback to a smaller model if the requested one fails
+            if "1.5B" not in self.model_name:
+                print("Falling back to Qwen2.5-1.5B-Instruct...")
+                self.model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+                self._load_llm()
+            else:
+                raise e

     def process_pdf(self, pdf_file: str) -> List[Document]:
         """

         Returns:
             List of document chunks
         """
+        try:
+            loader = PyPDFLoader(pdf_file)
+            documents = loader.load()
+
+            # Split documents into chunks
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=800,  # Smaller chunks for better performance
+                chunk_overlap=150,
+                separators=["\n\n", "\n", ". ", " ", ""]
+            )
+
+            chunks = text_splitter.split_documents(documents)
+            return chunks
+        except Exception as e:
+            print(f"Error processing PDF {pdf_file}: {e}")
+            return []

+    def create_vectorstore(self, pdf_files: List[str]) -> str:
         """
         Create or update the vector store with documents from PDF files

         Args:
             pdf_files: List of paths to PDF files
+
+        Returns:
+            Status message
         """
+        if not pdf_files:
+            return "No files provided."
+
         all_chunks = []
+        processed_files = 0

         for pdf_file in pdf_files:
             if not os.path.exists(pdf_file):

             print(f"Processing {pdf_file}...")
             chunks = self.process_pdf(pdf_file)
+            if chunks:
+                print(f"Created {len(chunks)} chunks from {pdf_file}")
+                all_chunks.extend(chunks)
+                processed_files += 1
+            else:
+                print(f"Failed to process {pdf_file}")
+
+        if not all_chunks:
+            return "No documents were successfully processed."

+        try:
+            # Create or update vectorstore
+            if os.path.exists(self.persist_directory) and len(os.listdir(self.persist_directory)) > 0:
+                print("Loading existing vectorstore...")
+                self.vectorstore = Chroma(
+                    persist_directory=self.persist_directory,
+                    embedding_function=self.embeddings
+                )
+                print("Adding new documents to existing vectorstore...")
+                self.vectorstore.add_documents(all_chunks)
+            else:
+                print("Creating new vectorstore...")
+                self.vectorstore = Chroma.from_documents(
+                    documents=all_chunks,
+                    embedding=self.embeddings,
+                    persist_directory=self.persist_directory
+                )
+
+            # Persist to disk
+            self.vectorstore.persist()
+            return f"Successfully processed {processed_files} PDFs with {len(all_chunks)} chunks."
+
+        except Exception as e:
+            return f"Error creating vectorstore: {str(e)}"

+    def retrieve_context(self, query: str, k: int = 4) -> Tuple[str, List[Dict]]:
         """
         Retrieve relevant context for a query

         if not self.vectorstore:
             return "", []

+        try:
+            # Search for relevant documents
+            docs_with_scores = self.vectorstore.similarity_search_with_score(query, k=k)

+            context_parts = []
+            sources = []
+
+            for i, (doc, score) in enumerate(docs_with_scores):
+                context_part = f"Document {i+1}:\n{doc.page_content}\n"
+                context_parts.append(context_part)
+
+                # Clean metadata for serialization
                 clean_metadata = {}
                 for key, value in doc.metadata.items():
                     str_key = str(key)
                     if isinstance(value, (str, int, float, bool, type(None))):
                         clean_metadata[str_key] = value
                     else:
                         clean_metadata[str_key] = str(value)

                 source_info = {
                     "content": str(doc.page_content),
                     "metadata": clean_metadata,
                     "score": float(score),
                     "source_id": i+1
                 }
+                sources.append(source_info)

+            self.top_sources = sources
+            context = "\n".join(context_parts)
+            return context, sources
+
+        except Exception as e:
+            print(f"Error retrieving context: {e}")
+            return "", []

     def generate_response(self, query: str, system_prompt: str = "You are a helpful assistant that answers questions based on the provided documents.") -> str:
         """

             return "No relevant documents found in the database. Please upload some PDF files first."

         # Create RAG prompt
+        rag_prompt = f"""Based on the following context, please answer the question. If the answer is not in the context, say that you don't know.

 Context:
 {context}

+Question: {query}
+
+Answer:"""

+        try:
+            # Generate response
+            print(f"Running inference for query: {query}")
+            start_time = time.time()
+
+            # Use the pipeline for generation
+            response = self.pipe(
+                rag_prompt,
+                max_new_tokens=300,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+            inference_time = time.time() - start_time
+            print(f"Inference completed in {inference_time:.2f} seconds")

+            # Extract the generated text
+            if isinstance(response, list) and len(response) > 0:
+                result = response[0].get('generated_text', '').strip()
+            else:
+                result = str(response).strip()
+
+            return result if result else "I couldn't generate a response. Please try again."
+
+        except Exception as e:
+            print(f"Error generating response: {e}")
+            return f"Error generating response: {str(e)}"

     def get_top_sources(self) -> List[Dict]:
         """Get the top sources used for the last query"""
         return self.top_sources


 class RagUI:
+    """Gradio UI for the PDF RAG System - HF Spaces optimized"""

     def __init__(self, rag_system: PDFRagSystem):
         self.rag_system = rag_system
         self.interface = None

+        # Define available models (optimized for HF Spaces)
         self.models = {
+            "Qwen2.5-1.5B (Recommended)": "Qwen/Qwen2.5-1.5B-Instruct",
+            "Qwen2.5-3B": "Qwen/Qwen2.5-3B-Instruct"
         }

+        self.current_model = "Qwen2.5-1.5B (Recommended)"

+    def _upload_files(self, files) -> str:
+        """Handle file upload"""
         if not files:
             return "No files selected."

         try:
+            file_paths = [f.name for f in files]
+            return self.rag_system.create_vectorstore(file_paths)
         except Exception as e:
             return f"Error processing files: {str(e)}"

     def _switch_model(self, model_name: str) -> str:
+        """Switch the model"""
         if model_name not in self.models:
             return f"Unknown model: {model_name}"

         full_model_name = self.models[model_name]
         self.current_model = model_name

         return self.rag_system.change_model(full_model_name)

+    def _query(self, query: str, system_prompt: str) -> Tuple[str, str]:
+        """Process a query"""
         if not query.strip():
+            return "Please enter a question.", ""

         response = self.rag_system.generate_response(query, system_prompt)
         sources = self.rag_system.get_top_sources()
+        sources_html = self._format_source_display(sources)

+        return response, sources_html

     def _format_source_display(self, sources: List[Dict]) -> str:
+        """Format sources for display"""
         if not sources:
             return "<div class='source-container'>No sources available.</div>"

         html = "<div class='source-container'>"

         for i, source in enumerate(sources):
             try:
                 if not isinstance(source, dict):
                     continue

                 metadata = source.get("metadata", {})
                 if not isinstance(metadata, dict):
                     metadata = {}

                 page_num = metadata.get("page", "Unknown")
                 source_file = metadata.get("source", "Unknown")
+                content = source.get("content", "No content available")[:500] + "..."  # Limit content length
                 score = source.get("score", 0.0)
                 source_id = source.get("source_id", i+1)

+                # Determine relevance class
+                if score <= 0.5:  # Lower is better for distance-based similarity
                     relevance_class = "relevance-high"
+                elif score <= 0.8:
                     relevance_class = "relevance-medium"
                 else:
                     relevance_class = "relevance-low"

                 html += f"""
                 <div class="source-card">
                     <div class="source-header">
+                        Source {source_id} (<span class="{relevance_class}">Distance: {score:.2f}</span>)
                     </div>
                     <div class="source-meta">
                         <strong>File:</strong> {os.path.basename(str(source_file))}

                     </div>
                 """
             except Exception as e:
                 html += f'<div class="source-card">Error displaying source {i+1}: {str(e)}</div>'

         html += "</div>"

     def build_interface(self):
         """Build the Gradio interface"""
+        # Custom CSS for better appearance
+        css = """
+        .source-container {
+            max-height: 600px;
+            overflow-y: auto;
+            padding: 10px;
+        }
+        .source-card {
+            margin-bottom: 15px;
+            padding: 12px;
+            border: 1px solid #ddd;
+            border-radius: 6px;
+            background-color: #f8f9fa;
+            box-shadow: 0 1px 3px rgba(0,0,0,0.1);
+        }
+        .source-header {
+            font-size: 16px;
+            font-weight: bold;
+            margin-bottom: 8px;
+            color: #333;
+        }
+        .source-meta {
+            color: #666;
+            margin-bottom: 6px;
+            font-size: 14px;
+        }
+        .source-content {
+            background-color: #fff;
+            padding: 10px;
+            border-radius: 4px;
+            border-left: 3px solid #007bff;
+            font-family: 'Segoe UI', sans-serif;
+            line-height: 1.4;
+            font-size: 14px;
+        }
+        .relevance-high { color: #28a745; font-weight: bold; }
+        .relevance-medium { color: #ffc107; font-weight: bold; }
+        .relevance-low { color: #dc3545; font-weight: bold; }
+        """
+
+        with gr.Blocks(title="Qwen2.5 PDF RAG System", css=css) as interface:
+            gr.Markdown("""
+            # 🤖 Qwen2.5 PDF RAG System
+
+            Upload PDF documents and ask questions about their content using advanced AI.
+
+            **⚡ Powered by Qwen2.5 Language Models**
+            """)
+
+            with gr.Tab("📚 Main Interface"):
                 with gr.Row():
                     with gr.Column(scale=1):
+                        gr.Markdown("### 🔧 Settings")
+
+                        # Model selection
+                        model_dropdown = gr.Dropdown(
+                            choices=list(self.models.keys()),
+                            value=self.current_model,
+                            label="AI Model",
+                            info="1.5B model recommended for stability"
+                        )
+                        model_switch_btn = gr.Button("🔄 Switch Model", size="sm")
+                        model_status = gr.Textbox(
+                            label="Model Status",
+                            value=f"Using: {self.current_model}",
+                            interactive=False
+                        )
+
+                        gr.Markdown("### 📄 Upload Documents")
                         file_input = gr.File(
                             file_count="multiple",
+                            file_types=[".pdf"],
+                            label="PDF Files"
                         )
+                        upload_button = gr.Button("📤 Process PDFs", variant="primary")
+                        upload_status = gr.Textbox(
+                            label="Status",
+                            interactive=False,
+                            placeholder="Upload status will appear here..."
                         )

                     with gr.Column(scale=2):
+                        gr.Markdown("### 💬 Ask Questions")
+
+                        system_prompt = gr.Textbox(
+                            label="System Instructions",
+                            value="You are a helpful AI assistant. Answer questions based only on the provided documents. Be concise and cite relevant information.",
+                            lines=3
+                        )
+
                         query_input = gr.Textbox(
                             label="Your Question",
+                            placeholder="What would you like to know about your documents?",
                             lines=2
                         )
+                        query_button = gr.Button("🔍 Ask Question", variant="primary")
+
                         answer_output = gr.Textbox(
                             label="Answer",
                             interactive=False,
+                            lines=8,
+                            placeholder="Answers will appear here..."
                         )

+            with gr.Tab("📖 Sources"):
+                gr.Markdown("### 📚 Reference Documents")
+                gr.Markdown("View the source documents used to generate answers.")

+                sources_display = gr.HTML(
+                    label="Sources",
+                    value="<p>No sources available yet. Ask a question first!</p>"
+                )

+            with gr.Tab("ℹ️ Info"):
                 gr.Markdown("""
+                ### About This System
+
+                This is a **Retrieval-Augmented Generation (RAG)** system that:
+
+                - 📤 **Processes PDF documents** and stores them in a vector database
+                - 🔍 **Searches** for relevant content based on your questions
+                - 🤖 **Generates answers** using Qwen2.5 language models
+                - 📚 **Shows sources** so you can verify the information

+                ### Available Models

+                - **Qwen2.5-1.5B** ⚡ - Fast and efficient (Recommended for HF Spaces)
+                - **Qwen2.5-3B** 🧠 - More capable but slower

+                ### Tips for Best Results

+                1. 📄 Upload clear, text-based PDFs (not scanned images)
+                2. ❓ Ask specific questions rather than broad topics
+                3. 🔍 Check the "Sources" tab to see what documents were used
+                4. 🔄 Try rephrasing your question if you don't get good results

+                ### Technical Details

+                - **Vector Store**: ChromaDB with cosine similarity
+                - **Embeddings**: sentence-transformers/all-MiniLM-L6-v2
+                - **Chunk Size**: 800 tokens with 150 token overlap
+                - **Context Window**: Up to 4 most relevant document chunks
                 """)

+            # Event handlers
             upload_button.click(
                 fn=self._upload_files,
                 inputs=[file_input],
                 outputs=[upload_status]
             )

             query_button.click(
+                fn=self._query,
                 inputs=[query_input, system_prompt],
                 outputs=[answer_output, sources_display]
             )

             query_input.submit(
+                fn=self._query,
                 inputs=[query_input, system_prompt],
                 outputs=[answer_output, sources_display]
             )

             model_switch_btn.click(
                 fn=self._switch_model,
                 inputs=[model_dropdown],

         """Launch the Gradio interface"""
         if not self.interface:
             self.build_interface()
+        return self.interface.launch(**kwargs)


+# Initialize and launch the application
 def main():
+    """Main function optimized for Hugging Face Spaces"""
+    print("🚀 Starting Qwen2.5 PDF RAG System...")
+    print(f"📱 Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
+
+    # Use the lightweight model by default for HF Spaces
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+    # Create RAG system
+    try:
+        rag_system = PDFRagSystem(model_name, persist_directory="chroma_db")
+
+        # Create and launch UI
+        ui = RagUI(rag_system)
+        ui.launch(
+            share=True,
+            server_name="0.0.0.0",
+            server_port=7860,
+            show_error=True
+        )
+    except Exception as e:
+        print(f"❌ Error starting application: {e}")
+        # Create a simple error interface
+        def error_interface():
+            return "❌ Failed to initialize the RAG system. Please check the logs."
+
+        error_app = gr.Interface(
+            fn=error_interface,
+            inputs=[],
+            outputs="text",
+            title="Error - Qwen2.5 PDF RAG System"
+        )
+        error_app.launch()


 if __name__ == "__main__":
+    main()
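For anyone who wants to exercise the rewritten class without the Gradio UI, here is a minimal smoke-test sketch. It assumes the new file is importable as app.py with its dependencies installed; "sample.pdf" is a placeholder for any local text-based PDF:

from app import PDFRagSystem

# Defaults to Qwen/Qwen2.5-1.5B-Instruct and the "db" persist directory.
rag = PDFRagSystem()
print(rag.create_vectorstore(["sample.pdf"]))        # returns a status string
print(rag.generate_response("What is this document about?"))
for src in rag.get_top_sources():                    # populated by the last query
    print(src["source_id"], src["score"], src["metadata"].get("page"))

Note that the scores come from Chroma's similarity_search_with_score, which returns distances (lower means closer); that is why this commit also flips the source-card label from "Relevance" with >= thresholds to "Distance" with <= thresholds.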