| """RAG retrieval pipelines: Base-RAG and Hier-RAG.""" | |
| import time | |
| import logging | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from core.index import VectorStore, IndexManager | |
| from openai import OpenAI | |
| import openai | |
| import os | |
| from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
| logger = logging.getLogger(__name__) | |
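

# Shared retry wrapper for the chat completions call, so the generate() methods
# actually get the retry behaviour their docstrings promise. The attempt count
# and backoff window below are illustrative defaults, not values from the
# original pipeline; tune them to your rate limits. Authentication errors are
# deliberately not retried.
_RETRYABLE_ERRORS = (
    openai.RateLimitError,
    openai.APITimeoutError,
    openai.APIConnectionError,
)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(_RETRYABLE_ERRORS),
    reraise=True,
)
def _chat_completion_with_retry(client: OpenAI, **kwargs):
    """Call the chat completions endpoint, retrying transient failures."""
    return client.chat.completions.create(**kwargs)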


class BaseRAG:
    """Standard RAG pipeline without hierarchical filtering."""

    def __init__(
        self,
        vector_store: VectorStore,
        llm_model: str = "gpt-3.5-turbo",
        api_key: Optional[str] = None
    ):
        """
        Initialize Base RAG pipeline.

        Args:
            vector_store: Vector store instance
            llm_model: OpenAI model name
            api_key: OpenAI API key
        """
        self.vector_store = vector_store
        self.llm_model = llm_model

        # Set OpenAI API key
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=self.api_key)

    def retrieve(
        self,
        query: str,
        n_results: int = 5
    ) -> Tuple[List[Dict[str, Any]], float]:
        """
        Retrieve relevant documents.

        Args:
            query: Search query
            n_results: Number of results to retrieve

        Returns:
            Tuple of (results, retrieval_time)
        """
        start_time = time.time()
        results = self.vector_store.search(query, n_results=n_results)
        retrieval_time = time.time() - start_time

        logger.info(f"Retrieved {len(results)} documents in {retrieval_time:.3f}s")
        return results, retrieval_time

    def generate(
        self,
        query: str,
        contexts: List[str],
        max_tokens: int = 500
    ) -> Tuple[str, float]:
        """
        Generate an answer with the LLM, retrying transient API failures.

        Args:
            query: User query
            contexts: Retrieved context documents
            max_tokens: Maximum tokens in response

        Returns:
            Tuple of (answer, generation_time)
        """
        # Build the prompt from the retrieved contexts
        context_text = "\n\n".join(
            f"Context {i + 1}:\n{ctx}" for i, ctx in enumerate(contexts)
        )
        prompt = f"""Based on the following context documents, please answer the question.

{context_text}

Question: {query}

Answer:"""

        start_time = time.time()
        try:
            response = _chat_completion_with_retry(
                self.client,
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=0.3
            )
            answer = response.choices[0].message.content
            generation_time = time.time() - start_time

            logger.info(f"Generated answer in {generation_time:.3f}s")
            return answer, generation_time

        except openai.AuthenticationError as e:
            logger.error(f"Authentication failed: {str(e)}")
            return "❌ **Authentication Error**: Invalid OpenAI API key. Please check your credentials in Settings → Secrets.", 0.0

        except openai.RateLimitError as e:
            logger.error(f"Rate limit exceeded: {str(e)}")
            return "⚠️ **Rate Limit Exceeded**: Too many requests. Please wait a moment and try again.", 0.0

        except openai.APITimeoutError as e:
            logger.error(f"API timeout: {str(e)}")
            return "⏱️ **Timeout Error**: Request took too long. Please try again with a shorter query.", 0.0

        except openai.APIConnectionError as e:
            logger.error(f"Connection error: {str(e)}")
            return "🌐 **Connection Error**: Unable to reach OpenAI API. Please check your internet connection.", 0.0

        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}")
            return f"❌ **Error**: {str(e)}", 0.0

    def query(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 500
    ) -> Dict[str, Any]:
        """
        Complete RAG pipeline: retrieve + generate.

        Args:
            query: User query
            n_results: Number of documents to retrieve
            max_tokens: Maximum tokens in response

        Returns:
            Dictionary with answer, contexts, and timing info
        """
        # Retrieve
        results, retrieval_time = self.retrieve(query, n_results)

        # Extract contexts
        contexts = [r["document"] for r in results]

        # Generate
        answer, generation_time = self.generate(query, contexts, max_tokens)

        total_time = retrieval_time + generation_time
        logger.info(
            f"Base-RAG query completed in {total_time:.3f}s "
            f"(retrieval: {retrieval_time:.3f}s, generation: {generation_time:.3f}s)"
        )

        return {
            "query": query,
            "answer": answer,
            "contexts": results,
            "retrieval_time": retrieval_time,
            "generation_time": generation_time,
            "total_time": total_time,
            "pipeline": "Base-RAG"
        }


class HierarchicalRAG:
    """Hierarchical RAG pipeline with metadata filtering."""

    def __init__(
        self,
        vector_store: VectorStore,
        llm_model: str = "gpt-3.5-turbo",
        api_key: Optional[str] = None
    ):
        """
        Initialize Hierarchical RAG pipeline.

        Args:
            vector_store: Vector store instance
            llm_model: OpenAI model name
            api_key: OpenAI API key
        """
        self.vector_store = vector_store
        self.llm_model = llm_model

        # Set OpenAI API key
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=self.api_key)

    def infer_hierarchy_from_query(self, query: str) -> Dict[str, Optional[str]]:
        """
        Infer hierarchical filters from query using simple keyword matching.

        Args:
            query: User query

        Returns:
            Dictionary with level1, level2, level3, doc_type filters
        """
        query_lower = query.lower()

        # This is a simple heuristic - in production, use an LLM classifier
        filters = {
            "level1": None,
            "level2": None,
            "level3": None,
            "doc_type": None
        }

        # Simple keyword-based inference (can be enhanced with LLM)
        # Hospital domain keywords
        if any(kw in query_lower for kw in ["patient", "clinical", "medical", "treatment", "admission", "hospital", "nurse", "doctor"]):
            filters["level1"] = "Clinical Care"
        elif any(kw in query_lower for kw in ["policy", "compliance", "administrative", "staff"]):
            filters["level1"] = "Administrative"
        elif any(kw in query_lower for kw in ["infection", "safety", "quality", "incident", "error"]):
            filters["level1"] = "Quality & Safety"
        elif any(kw in query_lower for kw in ["training", "education", "course", "certification"]):
            filters["level1"] = "Education & Training"
        # Bank domain keywords
        elif any(kw in query_lower for kw in ["account", "loan", "banking", "retail", "customer", "deposit"]):
            filters["level1"] = "Retail Banking"
        elif any(kw in query_lower for kw in ["risk", "credit", "fraud", "default"]):
            filters["level1"] = "Risk Management"
        elif any(kw in query_lower for kw in ["compliance", "kyc", "aml", "regulatory", "legal"]):
            filters["level1"] = "Compliance & Legal"
        elif any(kw in query_lower for kw in ["corporate", "business", "commercial", "treasury"]):
            filters["level1"] = "Corporate Banking"
        # Fluid simulation keywords
        elif any(kw in query_lower for kw in ["turbulence", "flow", "simulation", "cfd", "solver", "algorithm"]):
            filters["level1"] = "Physical Models"
        elif any(kw in query_lower for kw in ["mesh", "grid", "discretization", "numerical", "finite"]):
            filters["level1"] = "Numerical Methods"
        elif any(kw in query_lower for kw in ["validation", "verification", "benchmark", "accuracy"]):
            filters["level1"] = "Validation & Verification"
        elif any(kw in query_lower for kw in ["software", "tool", "platform", "parallel", "computing"]):
            filters["level1"] = "Software & Tools"

        # Doc type inference
        if any(kw in query_lower for kw in ["policy", "policies"]):
            filters["doc_type"] = "policy"
        elif any(kw in query_lower for kw in ["manual", "guide", "handbook"]):
            filters["doc_type"] = "manual"
        elif any(kw in query_lower for kw in ["report", "analysis", "findings"]):
            filters["doc_type"] = "report"
        elif any(kw in query_lower for kw in ["protocol", "procedure", "standard"]):
            filters["doc_type"] = "protocol"
        elif any(kw in query_lower for kw in ["paper", "research", "study"]):
            filters["doc_type"] = "paper"

        logger.info(f"Inferred filters: {filters}")
        return filters
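
    def infer_hierarchy_with_llm(self, query: str) -> Dict[str, Optional[str]]:
        """
        Optional sketch of the LLM classifier mentioned in the heuristic above.

        This is an illustrative alternative to `infer_hierarchy_from_query`, not part
        of the original pipeline: the prompt wording, token budget, and the assumption
        that the model returns parseable JSON are guesses and may need tuning.
        Falls back to the keyword heuristic on any failure.
        """
        import json  # local import keeps this sketch self-contained

        prompt = (
            "Classify the following query into hierarchical filters. "
            'Respond with JSON only, using the keys "level1", "level2", "level3", '
            '"doc_type" and null for anything you cannot infer.\n'
            f"Query: {query}"
        )
        try:
            response = self.client.chat.completions.create(
                model=self.llm_model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=100,
                temperature=0.0
            )
            parsed = json.loads(response.choices[0].message.content)
            return {key: parsed.get(key) for key in ("level1", "level2", "level3", "doc_type")}
        except Exception as e:
            logger.warning(f"LLM filter inference failed ({e}); falling back to keyword matching")
            return self.infer_hierarchy_from_query(query)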

    def retrieve(
        self,
        query: str,
        n_results: int = 5,
        level1: Optional[str] = None,
        level2: Optional[str] = None,
        level3: Optional[str] = None,
        doc_type: Optional[str] = None,
        auto_infer: bool = True
    ) -> Tuple[List[Dict[str, Any]], float, Dict[str, Optional[str]]]:
        """
        Retrieve relevant documents with hierarchical filtering.

        Args:
            query: Search query
            n_results: Number of results to retrieve
            level1: Domain filter
            level2: Section filter
            level3: Topic filter
            doc_type: Document type filter
            auto_infer: Whether to auto-infer filters from query

        Returns:
            Tuple of (results, retrieval_time, applied_filters)
        """
        # Auto-infer filters if enabled and no explicit filters provided
        if auto_infer and not any([level1, level2, level3, doc_type]):
            inferred = self.infer_hierarchy_from_query(query)
            level1 = level1 or inferred["level1"]
            level2 = level2 or inferred["level2"]
            level3 = level3 or inferred["level3"]
            doc_type = doc_type or inferred["doc_type"]

        applied_filters = {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": doc_type
        }

        start_time = time.time()
        results = self.vector_store.search_with_hierarchy(
            query=query,
            n_results=n_results,
            level1=level1,
            level2=level2,
            level3=level3,
            doc_type=doc_type
        )
        retrieval_time = time.time() - start_time

        logger.info(f"Retrieved {len(results)} documents with filters in {retrieval_time:.3f}s. Filters: {applied_filters}")
        return results, retrieval_time, applied_filters

    def generate(
        self,
        query: str,
        contexts: List[str],
        max_tokens: int = 500
    ) -> Tuple[str, float]:
        """
        Generate an answer with the LLM, retrying transient API failures.

        Args:
            query: User query
            contexts: Retrieved context documents
            max_tokens: Maximum tokens in response

        Returns:
            Tuple of (answer, generation_time)
        """
        # Build the prompt from the retrieved contexts
        context_text = "\n\n".join(
            f"Context {i + 1}:\n{ctx}" for i, ctx in enumerate(contexts)
        )
        prompt = f"""Based on the following context documents, please answer the question.

{context_text}

Question: {query}

Answer:"""

        start_time = time.time()
        try:
            response = _chat_completion_with_retry(
                self.client,
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=0.3
            )
            answer = response.choices[0].message.content
            generation_time = time.time() - start_time

            logger.info(f"Generated answer in {generation_time:.3f}s")
            return answer, generation_time

        except openai.AuthenticationError as e:
            logger.error(f"Authentication failed: {str(e)}")
            return "❌ **Authentication Error**: Invalid OpenAI API key. Please check your credentials.", 0.0

        except openai.RateLimitError as e:
            logger.error(f"Rate limit exceeded: {str(e)}")
            return "⚠️ **Rate Limit Exceeded**: Too many requests. Please wait a moment and try again.", 0.0

        except openai.APITimeoutError as e:
            logger.error(f"API timeout: {str(e)}")
            return "⏱️ **Timeout Error**: Request took too long. Please try again.", 0.0

        except openai.APIConnectionError as e:
            logger.error(f"Connection error: {str(e)}")
            return "🌐 **Connection Error**: Unable to reach OpenAI API. Check your connection.", 0.0

        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            return f"❌ **Error**: {str(e)}", 0.0

    def query(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 500,
        level1: Optional[str] = None,
        level2: Optional[str] = None,
        level3: Optional[str] = None,
        doc_type: Optional[str] = None,
        auto_infer: bool = True
    ) -> Dict[str, Any]:
        """
        Complete Hierarchical RAG pipeline: filter + retrieve + generate.

        Args:
            query: User query
            n_results: Number of documents to retrieve
            max_tokens: Maximum tokens in response
            level1: Domain filter
            level2: Section filter
            level3: Topic filter
            doc_type: Document type filter
            auto_infer: Whether to auto-infer filters from query

        Returns:
            Dictionary with answer, contexts, filters, and timing info
        """
        # Retrieve with hierarchy
        results, retrieval_time, applied_filters = self.retrieve(
            query=query,
            n_results=n_results,
            level1=level1,
            level2=level2,
            level3=level3,
            doc_type=doc_type,
            auto_infer=auto_infer
        )

        # Extract contexts
        contexts = [r["document"] for r in results]

        # Generate
        answer, generation_time = self.generate(query, contexts, max_tokens)

        total_time = retrieval_time + generation_time
        logger.info(
            f"Hier-RAG query completed in {total_time:.3f}s "
            f"(retrieval: {retrieval_time:.3f}s, generation: {generation_time:.3f}s)"
        )

        return {
            "query": query,
            "answer": answer,
            "contexts": results,
            "applied_filters": applied_filters,
            "retrieval_time": retrieval_time,
            "generation_time": generation_time,
            "total_time": total_time,
            "pipeline": "Hier-RAG"
        }


class RAGComparator:
    """Compare Base-RAG and Hier-RAG side-by-side."""

    def __init__(
        self,
        vector_store: VectorStore,
        llm_model: str = "gpt-3.5-turbo",
        api_key: Optional[str] = None
    ):
        """
        Initialize RAG comparator.

        Args:
            vector_store: Vector store instance
            llm_model: OpenAI model name
            api_key: OpenAI API key
        """
        self.base_rag = BaseRAG(vector_store, llm_model, api_key)
        self.hier_rag = HierarchicalRAG(vector_store, llm_model, api_key)

    def compare(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 500,
        level1: Optional[str] = None,
        level2: Optional[str] = None,
        level3: Optional[str] = None,
        doc_type: Optional[str] = None,
        auto_infer: bool = True
    ) -> Dict[str, Any]:
        """
        Run both pipelines and compare results.

        Args:
            query: User query
            n_results: Number of documents to retrieve
            max_tokens: Maximum tokens in response
            level1: Domain filter (Hier-RAG only)
            level2: Section filter (Hier-RAG only)
            level3: Topic filter (Hier-RAG only)
            doc_type: Document type filter (Hier-RAG only)
            auto_infer: Whether to auto-infer filters (Hier-RAG only)

        Returns:
            Dictionary with results from both pipelines
        """
        logger.info(f"Comparing pipelines for query: {query}")

        # Run Base-RAG
        base_results = self.base_rag.query(query, n_results, max_tokens)

        # Run Hier-RAG
        hier_results = self.hier_rag.query(
            query=query,
            n_results=n_results,
            max_tokens=max_tokens,
            level1=level1,
            level2=level2,
            level3=level3,
            doc_type=doc_type,
            auto_infer=auto_infer
        )

        # Calculate speedup
        speedup = base_results["total_time"] / hier_results["total_time"] if hier_results["total_time"] > 0 else 0.0

        logger.info(f"Comparison complete. Speedup: {speedup:.2f}x")

        return {
            "query": query,
            "base_rag": base_results,
            "hier_rag": hier_results,
            "speedup": speedup
        }
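

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # core.index.VectorStore can be constructed with no arguments and has
    # already been populated with documents; adjust the wiring to the actual
    # core.index API before running.
    logging.basicConfig(level=logging.INFO)

    store = VectorStore()  # assumption: default constructor, pre-built index
    comparator = RAGComparator(store, llm_model="gpt-3.5-turbo")

    comparison = comparator.compare(
        query="What is the hand hygiene protocol for surgical staff?",
        n_results=5,
    )
    print("Base-RAG answer:\n", comparison["base_rag"]["answer"])
    print("Hier-RAG answer:\n", comparison["hier_rag"]["answer"])
    print(f"Speedup: {comparison['speedup']:.2f}x")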