"""RAG (Retrieval-Augmented Generation) system built on Google Gemini, LangChain, and ChromaDB."""
import os
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

# LangChain imports for RAG
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

# Google Gemini imports
from google import genai
class RAGSystem:
    """
    Complete RAG (Retrieval-Augmented Generation) system using Google Gemini.

    Handles document ingestion, chunking, embedding (ChromaDB), and
    question answering through a LangChain RetrievalQA chain.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize the RAG system with Google Gemini and ChromaDB.

        Args:
            persist_directory: Directory where ChromaDB persists its data.
        """
        self.persist_directory = persist_directory
        self.gemini_api_key = None

        # Gemini and vector-store components are created lazily on first use
        # so that constructing this object never requires GEMINI_API_KEY.
        self.embeddings = None
        self.llm = None
        self.vectorstore = None
        self.retriever = None
        self.qa_chain = None

        # Text splitter used to chunk documents before embedding.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

        # Bookkeeping of documents ingested during this process's lifetime.
        self.ingested_documents = []

    def _initialize_components(self):
        """Lazily initialize the Gemini LLM, embeddings, and vector store.

        The vector store is (re-)created whenever it is missing — including
        after clear_knowledge_base() — not only on the very first call.

        Raises:
            ValueError: If the GEMINI_API_KEY environment variable is not set.
        """
        if self.llm is None:
            self.gemini_api_key = os.getenv('GEMINI_API_KEY')
            if not self.gemini_api_key:
                raise ValueError("GEMINI_API_KEY environment variable must be set")

            # Google Gemini chat model used for answer generation.
            self.llm = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                temperature=0.1,
                max_tokens=2048,
                google_api_key=self.gemini_api_key,
            )
            # Google embeddings used for both ingestion and retrieval.
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/text-embedding-004",
                google_api_key=self.gemini_api_key,
            )

        # BUGFIX: originally guarded by `self.llm is None`, so after
        # clear_knowledge_base() the store was never rebuilt and the next
        # ingest crashed on a None vectorstore.
        if self.vectorstore is None:
            self._initialize_vectorstore()

    def _initialize_vectorstore(self):
        """Initialize the ChromaDB vector store and its retriever.

        Chroma loads existing data from persist_directory when present and
        starts an empty collection otherwise, so one constructor call covers
        both cases (the original had two identical if/else branches).

        Raises:
            RuntimeError: If the vector store cannot be initialized.
        """
        try:
            self.vectorstore = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
            )
            self.retriever = self.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 5},  # retrieve top 5 most similar chunks
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize vector store: {str(e)}") from e

    def ingest_document(self, text_content: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Ingest a document into the RAG system.

        Args:
            text_content: The full text content of the document.
            metadata: Document metadata (filename, type, etc.).

        Returns:
            Dict with status 'success' (chunks_created, document_info) or
            'error' (error message).
        """
        try:
            self._initialize_components()

            document = Document(page_content=text_content, metadata=metadata)

            # Split into overlapping chunks and tag each with its position.
            chunks = self.text_splitter.split_documents([document])
            for i, chunk in enumerate(chunks):
                chunk.metadata.update({
                    'chunk_id': i,
                    'total_chunks': len(chunks),
                })

            self.vectorstore.add_documents(chunks)
            # NOTE(review): newer langchain-chroma persists automatically and
            # dropped persist(); kept for the community Chroma class used here.
            self.vectorstore.persist()

            doc_info = {
                'filename': metadata.get('filename', 'Unknown'),
                'document_type': metadata.get('document_type', 'Unknown'),
                'chunks_created': len(chunks),
                'ingestion_timestamp': metadata.get('ingestion_timestamp', 'Unknown'),
            }
            self.ingested_documents.append(doc_info)

            return {
                'status': 'success',
                'chunks_created': len(chunks),
                'document_info': doc_info,
            }
        except Exception as e:
            return {
                'status': 'error',
                'error': str(e),
            }

    def query(self, question: str, return_source_docs: bool = True) -> Dict[str, Any]:
        """Query the RAG system with a question.

        Args:
            question: User's question.
            return_source_docs: Whether to include source documents in the response.

        Returns:
            Dict with the answer and (optionally) source information.
        """
        try:
            self._initialize_components()

            if not self.vectorstore:
                return {
                    'status': 'error',
                    'error': 'No documents have been ingested yet. Please upload and process some PDFs first.',
                }

            if not self.qa_chain:
                self._setup_qa_chain()

            # BUGFIX: return_source_documents is a chain-construction option
            # (set in _setup_qa_chain), not a chain input — passing it inside
            # the invoke() payload raises on unexpected input keys.
            result = self.qa_chain.invoke({"query": question})

            response = {
                'status': 'success',
                'answer': result.get('result', ''),
                'question': question,
            }

            if return_source_docs and 'source_documents' in result:
                response['sources'] = [
                    {
                        'content': doc.page_content[:200] + '...',  # preview only
                        'metadata': doc.metadata,
                    }
                    for doc in result['source_documents']
                ]

            return response
        except Exception as e:
            return {
                'status': 'error',
                'error': f"Query failed: {str(e)}",
            }

    def _setup_qa_chain(self):
        """Set up the question-answering chain with a custom prompt."""
        # Custom prompt template steering the model to answer only from context.
        prompt_template = """
You are an AI assistant that answers questions based on the provided document context.
Use the following context to answer the question accurately and comprehensively.
If the answer cannot be found in the context, say "I don't have enough information in the provided documents to answer this question."
Context:
{context}
Question: {question}
Answer:"""

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"],
        )

        self.qa_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            prompt=prompt,
            return_source_documents=True,
        )

    def get_document_list(self) -> List[Dict[str, Any]]:
        """Return a copy of the list of ingested documents."""
        return self.ingested_documents.copy()

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """Return statistics about the vector store (chunk/document counts)."""
        try:
            self._initialize_components()

            if not self.vectorstore:
                return {'total_chunks': 0, 'status': 'empty'}

            # NOTE(review): _collection is a private Chroma attribute; there is
            # no public count API on this wrapper version.
            collection = self.vectorstore._collection
            return {
                'total_chunks': collection.count(),
                'total_documents': len(self.ingested_documents),
                'status': 'active',
            }
        except Exception as e:
            return {
                'status': 'error',
                'error': str(e),
            }

    def clear_knowledge_base(self) -> Dict[str, Any]:
        """Clear all documents from the knowledge base.

        Deletes the persisted ChromaDB directory and resets every component
        that references it, so the next call lazily rebuilds a fresh store.
        """
        try:
            import shutil
            if os.path.exists(self.persist_directory):
                shutil.rmtree(self.persist_directory)

            # Reset everything that points at the deleted store, including the
            # retriever (originally left stale after a clear).
            self.vectorstore = None
            self.retriever = None
            self.qa_chain = None
            self.ingested_documents = []

            return {'status': 'success', 'message': 'Knowledge base cleared successfully'}
        except Exception as e:
            return {'status': 'error', 'error': str(e)}

    def search_similar_chunks(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for the k document chunks most similar to the query.

        Returns an empty list on any failure (best-effort semantics preserved
        from the original).
        """
        try:
            self._initialize_components()

            if not self.vectorstore:
                return []

            docs = self.vectorstore.similarity_search(query, k=k)
            return [
                {
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'preview': doc.page_content[:150] + '...',
                }
                for doc in docs
            ]
        except Exception as e:
            print(f"Search error: {e}")
            return []