Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import List, Optional, Dict, Any | |
| import logging | |
| from enum import Enum | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.embeddings.base import Embeddings | |
| import PyPDF2 | |
| from huggingface_hub import InferenceClient | |
| import torch | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
# Embedding model used to index PDF chunks and embed queries.
# Previously tried alternative: "dunzhang/stella_en_1.5B_v5".
embed_model = HuggingFaceBgeEmbeddings(
    model_name="all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
    # HuggingFaceBgeEmbeddings prepends a BGE-specific instruction to every
    # query by default (documents get none). That prefix is meant for BGE
    # models, not all-MiniLM-L6-v2, so disable it to keep queries and
    # documents embedded the same way.
    query_instruction="",
)

# Chat model served through the Hugging Face Inference API.
# Previously tried alternatives: "google/gemma-2-2b-it",
# "prithivMLmods/Llama-3.2-3B-GGUF".
model_name = "meta-llama/Llama-3.2-3B-Instruct"
client = InferenceClient(model_name)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DocumentFormat(Enum):
    """File suffixes the RAG system accepts for ingestion.

    Each member's value is the lower-case file extension, including the dot.
    """

    # Only PDF is supported today; further formats can be added as members.
    PDF = ".pdf"
@dataclass
class RAGConfig:
    """Configuration for RAG system parameters.

    Attributes:
        chunk_size: Maximum chunk length (measured with ``len``) used by the
            text splitter.
        chunk_overlap: Length shared between consecutive chunks.
        retriever_k: Number of chunks retrieved per query.
        persist_directory: Directory where the Chroma index is persisted.
    """

    chunk_size: int = 100
    chunk_overlap: int = 10
    retriever_k: int = 3
    persist_directory: str = "./chroma_db"
class AdvancedRAGSystem:
    """Retrieval-augmented question answering over uploaded PDF files.

    PDFs are read with PyPDF2, split into chunks, embedded into a persisted
    Chroma store, and the top-k chunks for each question are sent as context
    to a chat-completion client.
    """

    def __init__(self, embed_model, llm, config=None):
        """Initialize the RAG system with required models and optional configuration.

        Args:
            embed_model: Embeddings object handed to Chroma for indexing and
                retrieval.
            llm: Client exposing ``chat.completions.create`` (in this file an
                ``InferenceClient``).
            config: Optional ``RAGConfig``; a default one is used when falsy.
        """
        self.embed_model = embed_model
        self.llm = llm
        self.config = config or RAGConfig()
        self.vector_store: Optional[Chroma] = None
        # Context string and chunk count from the most recent retrieval;
        # ``last_context`` and ``context`` always hold the same value.
        self.last_context: Optional[str] = None
        self.context = None
        self.source_documents = 0

    @staticmethod
    def _resolve_upload_path(upload: Any) -> Path:
        """Return the filesystem path of one uploaded file.

        Accepts plain ``str``/``Path`` values as well as upload objects that
        expose their temp-file path via ``.name``. ``str``/``Path`` are checked
        first because ``Path.name`` is only the basename, not the full path.
        """
        if isinstance(upload, (str, Path)):
            return Path(upload)
        return Path(upload.name)

    def _validate_file(self, file_path: Path) -> bool:
        """Return True if *file_path* has a supported suffix and exists."""
        return file_path.suffix.lower() == DocumentFormat.PDF.value and file_path.exists()

    def _extract_text_from_pdf(self, pdf_path: Path) -> str:
        """Extract the text of every page of *pdf_path*, joined by newlines.

        Raises:
            ValueError: if the file cannot be opened or parsed.
        """
        try:
            with open(pdf_path, "rb") as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # Defensive: coerce a falsy page result (e.g. None from an
                # image-only page in some PyPDF2 versions) so join() only
                # ever sees strings.
                return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {str(e)}")
            raise ValueError(f"Failed to process PDF {pdf_path}: {str(e)}") from e

    def _create_document_chunks(self, texts: List[str]) -> List[Any]:
        """Split *texts* into chunk documents using the configured sizes."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            length_function=len,
            add_start_index=True,
        )
        return text_splitter.create_documents(texts)

    def process_pdfs(self, pdf_files: List[Any]) -> str:
        """Extract, chunk and index the given PDF uploads, replacing the previous index.

        Args:
            pdf_files: Uploaded files, given as path strings, ``Path`` objects,
                or objects exposing the path via ``.name``.

        Returns:
            A human-readable success message.

        Raises:
            RuntimeError: wrapping any validation, extraction or indexing error.
        """
        try:
            pdf_paths = [self._resolve_upload_path(pdf) for pdf in pdf_files]
            invalid_files = [f for f in pdf_paths if not self._validate_file(f)]
            if invalid_files:
                raise ValueError(f"Invalid or missing files: {invalid_files}")

            documents = [self._extract_text_from_pdf(pdf_path) for pdf_path in pdf_paths]
            doc_chunks = self._create_document_chunks(documents)
            if not doc_chunks:
                # Chroma cannot build an index from zero documents; fail with
                # a clear message instead of a backend error.
                raise ValueError("No extractable text found in the uploaded PDFs")

            # Chroma.from_documents writes into the same persisted collection
            # each time, so without this a re-run (e.g. with new chunk
            # settings) would mix old and new chunks. NOTE(review): data
            # persisted by an earlier process is not cleared here.
            if self.vector_store is not None:
                self.vector_store.delete_collection()
                self.vector_store = None

            self.vector_store = Chroma.from_documents(
                documents=doc_chunks,
                embedding=self.embed_model,
                persist_directory=self.config.persist_directory,
            )
            message = (
                f"Successfully processed {len(doc_chunks)} chunks "
                f"from {len(pdf_files)} PDF files"
            )
            logger.info(message)
            return message
        except Exception as e:
            error_msg = f"Error during PDF processing: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def get_retriever(self):
        """Return a retriever yielding the top ``config.retriever_k`` chunks.

        Raises:
            RuntimeError: if no documents have been processed yet.
        """
        if self.vector_store is None:
            raise RuntimeError("Vector store not initialized. Please process documents first.")
        return self.vector_store.as_retriever(search_kwargs={"k": self.config.retriever_k})

    def _format_context(self, documents: List[Any]) -> str:
        """Join the ``page_content`` of *documents* with blank lines."""
        return "\n\n".join(doc.page_content for doc in documents)

    def query(self, question: str) -> str:
        """Answer *question* using chunks retrieved from the indexed PDFs.

        Returns:
            The model's answer. On failure the error is logged and returned
            as a message string instead of being raised, so the UI can show it.
        """
        if not question or not question.strip():
            return "Please enter a question."
        try:
            if self.vector_store is None:
                raise RuntimeError("Please process PDF documents first before querying")

            retriever = self.get_retriever()
            retrieved_docs = retriever.get_relevant_documents(question)
            context = self._format_context(retrieved_docs)
            self.last_context = context
            self.context = context
            self.source_documents = len(retrieved_docs)

            system_prompt = (
                "You are a helpful assistant. Use the following pieces of context "
                "to answer the question at the end.\n"
                "If you don't know the answer, just say that you don't know, "
                "don't try to make up an answer.\n"
                "Context:\n"
                f"{context}\n"
            )
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ]
            # ``model_name`` is the module-level model id the client was built with.
            response = self.llm.chat.completions.create(
                model=model_name,
                messages=messages,
                max_tokens=500,
            )
            return response.choices[0].message.content
        except Exception as e:
            error_msg = f"Error during query processing: {str(e)}"
            logger.error(error_msg)
            return error_msg
def create_gradio_interface(rag_system: AdvancedRAGSystem):
    """Build the Gradio Blocks UI for uploading PDFs and querying the RAG system.

    Args:
        rag_system: The system whose config the UI updates and whose
            ``process_pdfs`` / ``query`` methods the callbacks invoke.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """

    def process_files(files: List[Any], chunk_size: int, overlap: int) -> str:
        """Apply the slider settings and index the uploaded PDFs; return a status message."""
        if not files:
            return "Please upload PDF files"
        # Slider values may arrive as floats; chunk sizes are whole lengths.
        chunk_size, overlap = int(chunk_size), int(overlap)
        # The sliders allow overlap >= chunk size. Reject that here, before
        # mutating the shared config, instead of letting indexing fail later.
        if overlap >= chunk_size:
            return (
                f"Error: Chunk overlap ({overlap}) must be smaller than "
                f"chunk size ({chunk_size})"
            )
        rag_system.config.chunk_size = chunk_size
        rag_system.config.chunk_overlap = overlap
        try:
            return rag_system.process_pdfs(files)
        except Exception as e:
            return f"Error: {str(e)}"

    def answer_question(question: str) -> str:
        """Return the RAG answer for *question*, or an error message."""
        try:
            return rag_system.query(question)
        except Exception as e:
            return f"Error: {str(e)}"

    def show_context(question: str) -> str:
        """Describe the context used by the most recent query (question is unused)."""
        if rag_system.context is None:
            return "No context retrieved yet."
        return (
            f"Last context used ({rag_system.source_documents} documents):\n\n"
            f"{rag_system.context}"
        )

    with gr.Blocks(title="Advanced RAG System") as demo:
        gr.Markdown("# Advanced RAG System with PDF Processing")

        with gr.Tab("Upload & Process PDFs"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(
                        file_count="multiple",
                        label="Upload PDF Documents",
                        file_types=[".pdf"],
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=10000,
                        value=100,
                        step=100,
                        label="Chunk Size",
                    )
                    overlap = gr.Slider(
                        minimum=10,
                        maximum=5000,
                        value=10,
                        step=10,
                        label="Chunk Overlap",
                    )
                    process_button = gr.Button("Process PDFs", variant="primary")
                    process_output = gr.Textbox(label="Processing Status")

        with gr.Tab("Query System"):
            with gr.Row():
                with gr.Column(scale=2):
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Enter your question here...",
                        lines=3,
                    )
                    query_button = gr.Button("Get Answer", variant="primary")
                    answer_output = gr.Textbox(label="Answer", lines=10)
                with gr.Column(scale=1):
                    history_output = gr.Textbox(label="Retrieved Context", lines=15)

        # Event handlers are registered inside the Blocks context.
        process_button.click(
            fn=process_files,
            inputs=[file_input, chunk_size, overlap],
            outputs=[process_output],
        )
        # Refresh the context pane only after the answer has been produced,
        # since the query is what updates rag_system.context.
        query_button.click(
            fn=answer_question,
            inputs=[question_input],
            outputs=[answer_output],
        ).then(
            fn=show_context,
            inputs=[question_input],
            outputs=[history_output],
        )

    return demo
# Build the RAG system and its UI at module level so `demo` exists for any
# importer; the server is only started when this file is run as a script.
rag_system = AdvancedRAGSystem(embed_model=embed_model, llm=client)
demo = create_gradio_interface(rag_system)

if __name__ == "__main__":
    demo.launch()