Spaces:
Running
Running
| """ | |
| RAG Agent - Advanced Retrieval-Augmented Generation Agent | |
| This module implements a RAG Agent that: | |
| - Accepts files uploaded from the frontend via FastAPI | |
| - Processes uploaded files dynamically (PDF, TXT, etc.) | |
| - Creates vector embeddings from uploaded content using Weaviate | |
| - Uses Query Decomposition for focused retrieval | |
| - Uses Reciprocal Rank Fusion (RRF) for intelligent result merging | |
| - Returns responses based on the uploaded file content | |
Requires a running Weaviate instance (port 8081 by default; see RAGAgent._connect_weaviate for the host/port actually used)
| """ | |
| import logging | |
| import os | |
| import json | |
| import tempfile | |
| import shutil | |
| import re | |
| from typing import Optional, List, Any, Dict | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from dotenv import load_dotenv, find_dotenv | |
| import weaviate | |
| from langchain_community.document_loaders import PyPDFLoader, TextLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_weaviate import WeaviateVectorStore | |
| from langchain_huggingface.embeddings import HuggingFaceEmbeddings | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.documents import Document | |
# Module-level logger; the host application is expected to configure handlers/levels.
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENROUTER_API_KEY) from the nearest .env file.
_ = load_dotenv(find_dotenv())
class AdvancedRAGSystem:
    """
    Production-ready RAG system with hybrid retrieval + RRF.

    Features:
    - Hybrid Retrieval: Original query + decomposed sub-queries
    - Reciprocal Rank Fusion (RRF): Intelligently merge results
    - Keyword Boosting: Prioritize documents with relevant terms
    - Cost-efficient: Only 2 LLM calls (decomposition + answer)
    - Fully scalable with configurable parameters
    """

    def __init__(
        self,
        vector_store,
        llm,
        retriever_k: int = 10,
        num_sub_queries: int = 2,
        rrf_k: int = 60,
        keyword_boost: float = 0.25,
        top_docs: int = 5
    ):
        """
        Initialize the Advanced RAG System.

        Args:
            vector_store: Weaviate/Pinecone/etc vector store
            llm: Language model for decomposition and answer generation
            retriever_k: Number of documents to retrieve per query
            num_sub_queries: Number of sub-queries to generate (lower = cheaper)
            rrf_k: RRF constant (higher = flatter ranking)
            keyword_boost: Boost factor per keyword match
            top_docs: Number of top documents for final context
        """
        self.vector_store = vector_store
        self.llm = llm
        self.retriever_k = retriever_k
        self.num_sub_queries = num_sub_queries
        self.rrf_k = rrf_k
        self.keyword_boost = keyword_boost
        self.top_docs = top_docs
        # Plain similarity retriever; cross-query merging is handled by RRF below.
        self.retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": retriever_k}
        )
        self._build_chains()
        logger.info(f"AdvancedRAGSystem initialized with k={retriever_k}, sub_queries={num_sub_queries}")

    def _build_chains(self):
        """Build the internal LangChain pipelines (query decomposer + answer prompt)."""
        # Query decomposition prompt.
        # NOTE: this is an f-string, so {self.num_sub_queries} is baked in here,
        # while the doubled {{question}} survives as the literal {question}
        # placeholder for ChatPromptTemplate to fill in later.
        decomposition_template = f"""Rewrite this question into {self.num_sub_queries} specific search queries.
RULES:
1. Include technical keywords that would appear in documentation
2. Focus on syntax, commands, and implementation details
3. Keep the core topic but make it more specific
Question: {{question}}
Write {self.num_sub_queries} search queries (one per line):"""
        self.decomposition_prompt = ChatPromptTemplate.from_template(decomposition_template)
        # Build decomposer chain: prompt -> LLM -> raw text -> list of sub-queries.
        # The trailing lambda splits on newlines, drops blank/trivial lines
        # (<= 5 chars), and caps the result at num_sub_queries entries.
        self.query_decomposer = (
            self.decomposition_prompt
            | self.llm
            | StrOutputParser()
            | (lambda x: [q.strip() for q in x.strip().split("\n") if q.strip() and len(q.strip()) > 5][:self.num_sub_queries])
        )
        # RAG answer prompt: instructs the model to stay grounded in the context.
        self.rag_prompt = ChatPromptTemplate.from_template("""Answer the question using ONLY the provided context.
Context:
{context}
Question: {question}
Instructions:
- Use only information from the context
- If the answer isn't in the context, say "I don't have enough information"
- Be specific and cite relevant details
- Format your answer clearly""")

    def _extract_keywords(self, question: str) -> List[str]:
        """Extract keywords from question for boosting.

        Lowercases the question, strips '?' and '.', then drops common stop
        words and any word of 2 characters or fewer.
        """
        stop_words = {'what', 'how', 'why', 'when', 'where', 'is', 'are', 'the',
                      'a', 'an', 'to', 'in', 'for', 'of', 'and', 'or', 'can', 'do',
                      'explain', 'describe', 'tell', 'me', 'about'}
        words = question.lower().replace('?', '').replace('.', '').split()
        keywords = [w for w in words if w not in stop_words and len(w) > 2]
        return keywords

    def _reciprocal_rank_fusion(self, results: List[List], keywords: Optional[List[str]] = None) -> List:
        """Apply RRF to merge multiple ranked document lists with keyword boosting.

        Args:
            results: One ranked document list per query (original + sub-queries).
            keywords: Optional keywords; each match in a document's content
                multiplies that document's score by (1 + keyword_boost).

        Returns:
            Deduplicated documents sorted by fused score, best first.
        """
        fused_scores = defaultdict(float)
        doc_map = {}
        for doc_list in results:
            for rank, doc in enumerate(doc_list):
                # Dedup key: content plus canonicalized metadata, so the same
                # chunk retrieved by several queries accumulates one score.
                doc_key = (doc.page_content, json.dumps(doc.metadata, sort_keys=True, default=str))
                # Base RRF score: 1 / (k + rank + 1)
                score = 1 / (self.rrf_k + rank + 1)
                # Apply keyword boost
                if keywords:
                    content_lower = doc.page_content.lower()
                    matches = sum(1 for kw in keywords if kw in content_lower)
                    score *= (1 + self.keyword_boost * matches)
                fused_scores[doc_key] += score
                if doc_key not in doc_map:
                    doc_map[doc_key] = doc
        # Sort by fused score (descending)
        reranked = sorted(
            [(doc_map[k], s) for k, s in fused_scores.items()],
            key=lambda x: x[1],
            reverse=True
        )
        return [doc for doc, _ in reranked]

    def _format_context(self, docs: List) -> str:
        """Format the top self.top_docs documents into a numbered context string."""
        return "\n\n".join(
            f"[Doc {i+1}] {doc.page_content}"
            for i, doc in enumerate(docs[:self.top_docs])
        )

    def retrieve(self, question: str) -> List:
        """
        Hybrid retrieval: original query + decomposed queries + RRF.

        Args:
            question: User's question

        Returns:
            List of relevant documents ranked by RRF score
        """
        keywords = self._extract_keywords(question)
        all_results = []
        # 1. ALWAYS include original query results
        original_docs = self.retriever.invoke(question)
        all_results.append(original_docs)
        # 2. Add decomposed sub-query results; decomposition is best-effort —
        #    on failure we still answer from the original-query results alone.
        try:
            sub_queries = self.query_decomposer.invoke({"question": question})
            for sq in sub_queries:
                docs = self.retriever.invoke(sq)
                all_results.append(docs)
        except Exception as e:
            logger.warning(f"Sub-query decomposition skipped: {str(e)[:50]}")
        # 3. Apply RRF with keyword boosting
        ranked_docs = self._reciprocal_rank_fusion(all_results, keywords)
        return ranked_docs

    def query(self, question: str) -> str:
        """
        Full RAG pipeline: retrieve + generate answer.

        Args:
            question: User's question

        Returns:
            Generated answer based on retrieved context
        """
        docs = self.retrieve(question)
        context = self._format_context(docs)
        chain = self.rag_prompt | self.llm | StrOutputParser()
        return chain.invoke({"context": context, "question": question})
class RAGAgent:
    """
    RAG Agent - Handles document-based question answering with files from frontend.

    This agent:
    - Receives files uploaded from the frontend via FastAPI
    - Processes uploaded files (PDF, TXT, etc.)
    - Creates vector embeddings using Weaviate
    - Answers questions based on the uploaded file content
    """

    def __init__(
        self,
        weaviate_port: int = 8081,
        index_name: str = "UploadedDocuments",
        retriever_k: int = 10,
        num_sub_queries: int = 2,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        weaviate_host: Optional[str] = None,
    ):
        """
        Initialize the RAG Agent.

        Args:
            weaviate_port: Port where Weaviate is running
            index_name: Base name for the Weaviate index (a per-session suffix is appended)
            retriever_k: Documents to retrieve per query
            num_sub_queries: Sub-queries to generate
            chunk_size: Size of text chunks
            chunk_overlap: Overlap between chunks
            weaviate_host: Host where Weaviate is reachable. Defaults to the
                WEAVIATE_HOST environment variable, falling back to
                "192.168.1.5" (the previously hard-coded address) so existing
                deployments keep working.
        """
        self.weaviate_port = weaviate_port
        self.weaviate_host = weaviate_host or os.getenv("WEAVIATE_HOST", "192.168.1.5")
        self.index_name = index_name
        self.retriever_k = retriever_k
        self.num_sub_queries = num_sub_queries
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Set lazily: Weaviate on first connect, LLM/embeddings just below.
        self.weaviate_client = None
        self.llm = None
        self.embeddings = None
        # Per-conversation state ("session" here means a chat/conversation id)
        # session_id -> {"vector_store": ..., "rag_system": ..., "current_file_name": str, "index_name": str}
        self._sessions: Dict[str, Dict[str, Any]] = {}
        # Temp directory for uploaded files
        self.temp_dir = tempfile.mkdtemp(prefix="rag_uploads_")
        # Initialize embeddings and LLM eagerly so misconfiguration fails fast.
        self._init_embeddings()
        self._init_llm()
        logger.info("RAG Agent initialized - ready to receive files from frontend")

    def _normalize_session_id(self, session_id: Optional[str]) -> str:
        """Normalize a conversation/session id into a safe, stable identifier."""
        if not session_id:
            return "default"
        session_id = str(session_id).strip()
        if not session_id:
            return "default"
        # Allow only safe characters; cap length to avoid huge class names
        session_id = re.sub(r"[^a-zA-Z0-9_-]", "_", session_id)[:64]
        return session_id or "default"

    def _index_name_for_session(self, session_id: str) -> str:
        """Build a Weaviate index/class name for a session."""
        session_id = self._normalize_session_id(session_id)
        # Keep the base index name stable and ensure it starts with a letter (Weaviate class naming rules)
        base = re.sub(r"[^a-zA-Z0-9_]", "_", str(self.index_name)) or "UploadedDocuments"
        if not base[0].isalpha():
            base = f"C_{base}"
        return f"{base}_{session_id}"

    def _delete_index_best_effort(self, index_name: str) -> None:
        """Delete a Weaviate collection/index if it exists (best-effort)."""
        if self.weaviate_client is None:
            return
        try:
            # Weaviate client v4
            self.weaviate_client.collections.delete(index_name)
            logger.info(f"Deleted Weaviate index: {index_name}")
        except Exception:
            # Deliberately swallowed: the index may not exist, or the client
            # may not support deletion — either way this is non-fatal.
            pass

    def _get_session(self, session_id: Optional[str]) -> Dict[str, Any]:
        """Return the state dict for a session, or {} if none exists."""
        sid = self._normalize_session_id(session_id)
        return self._sessions.get(sid, {})

    def _init_embeddings(self):
        """Initialize embeddings model.

        Raises:
            Exception: Propagated if the model cannot be loaded.
        """
        try:
            logger.info("Loading embeddings model...")
            self.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2"
            )
            logger.info("✅ Embeddings model loaded")
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            raise

    def _init_llm(self):
        """Initialize LLM (OpenRouter via the OpenAI-compatible API).

        Raises:
            RuntimeError: If OPENROUTER_API_KEY is missing or a placeholder.
        """
        try:
            logger.info("Initializing LLM for RAG...")
            # Strip stray quotes some .env files leave around the value.
            openrouter_api_key = os.getenv("OPENROUTER_API_KEY", "").strip().strip('"').strip("'")
            if not openrouter_api_key or openrouter_api_key.startswith("your-"):
                raise RuntimeError("Missing or invalid OPENROUTER_API_KEY environment variable")
            self.llm = ChatOpenAI(
                model="xiaomi/mimo-v2-flash:free",
                temperature=0,
                openai_api_key=openrouter_api_key,
                openai_api_base="https://openrouter.ai/api/v1",
            )
            logger.info("✅ LLM initialized for RAG")
        except Exception as e:
            logger.error(f"Failed to initialize LLM: {e}")
            raise

    def _connect_weaviate(self):
        """Connect to Weaviate if not already connected.

        Raises:
            RuntimeError: If the Weaviate instance does not report ready.
        """
        if self.weaviate_client is None:
            logger.info(f"Connecting to Weaviate at {self.weaviate_host}:{self.weaviate_port}...")
            self.weaviate_client = weaviate.connect_to_local(host=self.weaviate_host, port=self.weaviate_port)
            if not self.weaviate_client.is_ready():
                # Report the actual target host (the old message claimed localhost
                # regardless of the address actually used).
                raise RuntimeError(f"Weaviate is not ready at {self.weaviate_host}:{self.weaviate_port}")
            logger.info("✅ Weaviate connected")

    def _load_file(self, file_path: str) -> List[Document]:
        """Load a file and return documents, choosing the loader by extension."""
        file_ext = Path(file_path).suffix.lower()
        if file_ext == ".pdf":
            loader = PyPDFLoader(file_path)
        elif file_ext in [".txt", ".md", ".py", ".js", ".json", ".csv"]:
            loader = TextLoader(file_path, encoding="utf-8")
        else:
            # Unknown extension: fall back to treating it as UTF-8 text
            loader = TextLoader(file_path, encoding="utf-8")
        return loader.load()

    def process_file_from_bytes(self, file_content: bytes, filename: str, session_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a file uploaded from the frontend (synchronous).

        Args:
            file_content: Raw bytes of the uploaded file
            filename: Original filename (used for display and suffix only,
                never as a filesystem path)
            session_id: Conversation id; each session gets its own Weaviate index

        Returns:
            Dict with status and info about the processed file
        """
        # Normalize before the try block so the except branch always has a
        # well-formed session_id to report.
        session_id = self._normalize_session_id(session_id)
        file_path = None
        try:
            logger.info(f"Processing uploaded file: {filename}")
            # Connect to Weaviate
            self._connect_weaviate()
            # Save file temporarily (avoid trusting user filename for paths)
            suffix = Path(filename).suffix if filename else ""
            with tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir, suffix=suffix, prefix="upload_") as tmp:
                tmp.write(file_content)
                file_path = tmp.name
            logger.info(f"File saved to: {file_path}")
            # Load documents from file
            documents = self._load_file(file_path)
            logger.info(f"✅ Loaded {len(documents)} pages/sections from {filename}")
            # Split into chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
            docs = text_splitter.split_documents(documents)
            logger.info(f"✅ Split into {len(docs)} chunks")
            # Use a per-session index so multiple conversations don't mix documents.
            session_index_name = self._index_name_for_session(session_id)
            # Replace any prior session index (ChatGPT-like behavior: latest upload becomes active)
            self._delete_index_best_effort(session_index_name)
            # Create vector store with Weaviate
            logger.info("Creating vector embeddings with Weaviate...")
            vector_store = WeaviateVectorStore.from_documents(
                documents=docs,
                embedding=self.embeddings,
                client=self.weaviate_client,
                index_name=session_index_name,
                text_key="text",
            )
            logger.info("✅ Vector store created with Weaviate")
            # Create RAG system
            rag_system = AdvancedRAGSystem(
                vector_store=vector_store,
                llm=self.llm,
                retriever_k=self.retriever_k,
                num_sub_queries=self.num_sub_queries,
            )
            # Persist per-session state
            self._sessions[session_id] = {
                "vector_store": vector_store,
                "rag_system": rag_system,
                "current_file_name": filename,
                "index_name": session_index_name,
            }
            logger.info(f"✅ RAG system ready for session={session_id}, file={filename}")
            return {
                "success": True,
                "filename": filename,
                "session_id": session_id,
                "pages": len(documents),
                "chunks": len(docs),
                "message": f"Successfully processed {filename}. Ready to answer questions."
            }
        except Exception as e:
            logger.error(f"Error processing file {filename}: {e}", exc_info=True)
            return {
                "success": False,
                "filename": filename,
                "session_id": session_id,
                "error": str(e),
                "message": f"Failed to process {filename}: {str(e)}"
            }
        finally:
            # Always remove the temp copy — previously it leaked on failure.
            if file_path:
                try:
                    os.remove(file_path)
                except OSError:
                    pass

    def initialize(self) -> bool:
        """Initialize RAG Agent - connect to Weaviate. Returns True on success."""
        try:
            self._connect_weaviate()
            logger.info("RAG Agent ready (using Weaviate)")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize RAG Agent: {e}")
            return False

    def retrieve_context(self, question: str, session_id: Optional[str] = None) -> str:
        """
        Retrieve relevant context from the uploaded file for a question.

        Args:
            question: User's question
            session_id: Conversation id whose uploaded file should be searched

        Returns:
            Retrieved context as a string; "" if no file is loaded or on error
        """
        rag_system = self._get_session(session_id).get("rag_system")
        if not rag_system:
            return ""
        try:
            docs = rag_system.retrieve(question)
            context = rag_system._format_context(docs)
            logger.info(f"Retrieved {len(docs)} relevant chunks for question")
            return context
        except Exception as e:
            logger.error(f"Error retrieving context: {e}")
            return ""

    def answer_question(self, question: str, session_id: Optional[str] = None) -> str:
        """
        Answer a question based on the uploaded file.

        Args:
            question: User's question about the uploaded file
            session_id: Conversation id whose uploaded file should be used

        Returns:
            Generated answer based on the file content, or an explanatory
            message if no file is loaded / the query fails
        """
        rag_system = self._get_session(session_id).get("rag_system")
        if not rag_system:
            return "No file has been uploaded yet. Please upload a file first before asking questions."
        try:
            logger.info(f"Processing RAG query: {question[:50]}...")
            answer = rag_system.query(question)
            logger.info("✅ RAG query processed successfully")
            return answer
        except Exception as e:
            logger.error(f"Error processing RAG query: {e}", exc_info=True)
            return f"Error processing query: {str(e)}"

    def has_file_loaded(self, session_id: Optional[str] = None) -> bool:
        """Check if a file has been processed and is ready for queries (per session)."""
        return bool(self._get_session(session_id).get("rag_system"))

    def get_current_file(self, session_id: Optional[str] = None) -> Optional[str]:
        """Get the name of the currently loaded file (per session)."""
        return self._get_session(session_id).get("current_file_name")

    def clear(self, session_id: Optional[str] = None):
        """Clear the current file and vector store for a session."""
        session_id = self._normalize_session_id(session_id)
        session = self._sessions.pop(session_id, None)
        if session and session.get("index_name"):
            self._delete_index_best_effort(session["index_name"])
        logger.info(f"RAG Agent cleared for session={session_id} - ready for new file")

    def close(self):
        """Close the Weaviate connection and clean up temp files / session state."""
        try:
            # Close Weaviate connection
            if self.weaviate_client is not None:
                self.weaviate_client.close()
                self.weaviate_client = None
                logger.info("✅ Weaviate connection closed")
            # Clean up temp directory
            if os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir, ignore_errors=True)
            self._sessions.clear()
            logger.info("✅ RAG Agent cleanup complete")
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")
# ============================================================================
# GLOBAL RAG AGENT INSTANCE
# ============================================================================
# Lazily-created singleton shared by the module-level helper functions below.
_rag_agent: Optional[RAGAgent] = None
def get_rag_agent() -> RAGAgent:
    """Return the shared RAG Agent, constructing it on first access."""
    global _rag_agent
    if _rag_agent is not None:
        return _rag_agent
    _rag_agent = RAGAgent()
    return _rag_agent
def process_uploaded_file(file_content: bytes, filename: str, session_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Process a file uploaded from the frontend.

    Called by the FastAPI layer when a file is uploaded; delegates to the
    global RAG Agent.

    Args:
        file_content: Raw bytes of the uploaded file
        filename: Original filename
        session_id: Optional conversation id used to isolate uploads

    Returns:
        Dict with status and info about the processed file
    """
    return get_rag_agent().process_file_from_bytes(
        file_content, filename, session_id=session_id
    )
def retrieve_context_for_query(question: str, session_id: Optional[str] = None) -> str:
    """
    Retrieve relevant context from the uploaded file for a query.

    Args:
        question: User's question
        session_id: Optional conversation id whose uploaded file is searched

    Returns:
        Retrieved context string (empty if nothing is loaded)
    """
    return get_rag_agent().retrieve_context(question, session_id=session_id)
async def answer_rag_question(question: str, session_id: Optional[str] = None) -> str:
    """
    Answer a question using the RAG Agent.

    The underlying call is synchronous; this coroutine exists so async
    callers (e.g. FastAPI handlers) can await it directly.

    Args:
        question: User's question
        session_id: Optional conversation id whose uploaded file is used

    Returns:
        RAG-generated answer
    """
    return get_rag_agent().answer_question(question, session_id=session_id)
def has_file_loaded(session_id: Optional[str] = None) -> bool:
    """Report whether the given session already has a processed file ready."""
    return get_rag_agent().has_file_loaded(session_id=session_id)
def cleanup_rag_agent():
    """Release the global RAG Agent and its resources, if one was created."""
    global _rag_agent
    if _rag_agent is None:
        return
    _rag_agent.close()
    _rag_agent = None
    logger.info("RAG Agent cleaned up")
# ============================================================================
# FOR TESTING
# ============================================================================
if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.INFO)

    async def test_rag_agent():
        """Smoke-test the RAG Agent end-to-end with an in-memory sample file."""
        banner = "=" * 80
        print(banner)
        print("RAG AGENT TEST")
        print(banner)

        session_id = "local_test"
        sample_content = b"""
Python is a high-level programming language.
It was created by Guido van Rossum in 1991.
Python is known for its simple syntax and readability.
It supports multiple programming paradigms including procedural, object-oriented, and functional programming.
Python has a large standard library and active community.
"""
        # Upload, then only ask a question if ingestion succeeded.
        result = process_uploaded_file(sample_content, "sample.txt", session_id=session_id)
        print(f"\nFile processing result: {result}")

        if result.get("success"):
            question = "Who created Python?"
            answer = await answer_rag_question(question, session_id=session_id)
            print(f"\nQ: {question}")
            print(f"A: {answer}")

        cleanup_rag_agent()

    asyncio.run(test_rag_agent())