| """ | |
| RAG Pipeline - Core RAG Functionality | |
| This module orchestrates the complete RAG (Retrieval-Augmented Generation) pipeline, | |
| combining semantic search, context management, and LLM generation. | |
| """ | |
| import logging | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
| from src.llm.context_manager import ContextConfig, ContextManager | |
| from src.llm.llm_service import LLMResponse, LLMService | |
| from src.llm.prompt_templates import PromptTemplates | |
| # Import our modules | |
| from src.search.search_service import SearchService | |
| logger = logging.getLogger(__name__) | |


@dataclass
class RAGConfig:
    """Configuration for RAG pipeline."""

    max_context_length: int = 3000
    search_top_k: int = 10
    search_threshold: float = 0.0  # No threshold filtering at search level
    min_similarity_for_answer: float = 0.2  # Threshold for normalized distance similarity
    max_response_length: int = 1000
    enable_citation_validation: bool = True
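

# Example (illustrative only; the values below are hypothetical tuning choices,
# not project defaults): a stricter configuration for a short-context model
# might shrink the context window and demand higher similarity before answering:
#   config = RAGConfig(max_context_length=2000, search_top_k=5,
#                      min_similarity_for_answer=0.35)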


@dataclass
class RAGResponse:
    """Response from RAG pipeline with metadata."""

    answer: str
    sources: List[Dict[str, Any]]
    confidence: float
    processing_time: float
    llm_provider: str
    llm_model: str
    context_length: int
    search_results_count: int
    success: bool
    error_message: Optional[str] = None


class RAGPipeline:
    """
    Complete RAG pipeline orchestrating retrieval and generation.

    Combines:
    - Semantic search for context retrieval
    - Context optimization and management
    - LLM-based response generation
    - Citation validation and formatting
    """

    def __init__(
        self,
        search_service: SearchService,
        llm_service: LLMService,
        config: Optional[RAGConfig] = None,
    ):
        """
        Initialize RAG pipeline with required services.

        Args:
            search_service: Configured SearchService instance
            llm_service: Configured LLMService instance
            config: RAG configuration, uses defaults if None
        """
        self.search_service = search_service
        self.llm_service = llm_service
        self.config = config or RAGConfig()

        # Initialize context manager with matching config
        context_config = ContextConfig(
            max_context_length=self.config.max_context_length,
            max_results=self.config.search_top_k,
            min_similarity=self.config.search_threshold,
        )
        self.context_manager = ContextManager(context_config)

        # Initialize prompt templates
        self.prompt_templates = PromptTemplates()

        logger.info("RAGPipeline initialized successfully")

    def generate_answer(self, question: str) -> RAGResponse:
        """
        Generate answer to question using RAG pipeline.

        Args:
            question: User's question about corporate policies

        Returns:
            RAGResponse with answer and metadata
        """
        start_time = time.time()

        try:
            # Step 1: Retrieve relevant context
            logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
            search_results = self._retrieve_context(question)

            if not search_results:
                return self._create_no_context_response(question, start_time)

            # Step 2: Prepare and optimize context
            context, filtered_results = self.context_manager.prepare_context(search_results, question)

            # Step 3: Check if we have sufficient context
            quality_metrics = self.context_manager.validate_context_quality(
                context, question, self.config.min_similarity_for_answer
            )
            if not quality_metrics["passes_validation"]:
                return self._create_insufficient_context_response(question, filtered_results, start_time)

            # Step 4: Generate response using LLM
            llm_response = self._generate_llm_response(question, context)

            if not llm_response.success:
                return self._create_llm_error_response(question, llm_response.error_message, start_time)

            # Step 5: Process and validate response
            processed_response = self._process_response(llm_response.content, filtered_results)

            processing_time = time.time() - start_time

            return RAGResponse(
                answer=processed_response,
                sources=self._format_sources(filtered_results),
                confidence=self._calculate_confidence(quality_metrics, llm_response),
                processing_time=processing_time,
                llm_provider=llm_response.provider,
                llm_model=llm_response.model,
                context_length=len(context),
                search_results_count=len(search_results),
                success=True,
            )

        except Exception as e:
            logger.error(f"RAG pipeline error: {e}")
            return RAGResponse(
                answer=(
                    "I apologize, but I encountered an error processing your question. "
                    "Please try again or contact support."
                ),
                sources=[],
                confidence=0.0,
                processing_time=time.time() - start_time,
                llm_provider="none",
                llm_model="none",
                context_length=0,
                search_results_count=0,
                success=False,
                error_message=str(e),
            )

    def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
        """Retrieve relevant context using search service."""
        try:
            results = self.search_service.search(
                query=question,
                top_k=self.config.search_top_k,
                threshold=self.config.search_threshold,
            )
            logger.debug(f"Retrieved {len(results)} search results")
            return results
        except Exception as e:
            logger.error(f"Context retrieval error: {e}")
            return []
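
    # NOTE (documentation aid): this module defines no schema for search results,
    # but _format_sources and _process_response below read these keys, so each
    # result dict is expected to look roughly like the following (values are
    # illustrative, not taken from the real index):
    #   {
    #       "chunk_id": "policy-handbook-003",
    #       "content": "Employees accrue PTO at a rate of ...",
    #       "similarity_score": 0.42,
    #       "metadata": {"filename": "pto_policy.pdf"},
    #   }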

    def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
        """Generate response using LLM with formatted prompt."""
        template = self.prompt_templates.get_policy_qa_template()

        # Format the prompt
        formatted_prompt = template.user_template.format(question=question, context=context)

        # Prepend the system prompt inline until the LLM service supports a separate system message
        full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"

        return self.llm_service.generate_response(full_prompt)

    def _process_response(self, raw_response: str, search_results: List[Dict[str, Any]]) -> str:
        """Process and validate LLM response."""
        # Ensure citations are present
        response_with_citations = self.prompt_templates.add_fallback_citations(raw_response, search_results)

        # Validate citations if enabled
        if self.config.enable_citation_validation:
            available_sources = [result.get("metadata", {}).get("filename", "") for result in search_results]
            citation_validation = self.prompt_templates.validate_citations(response_with_citations, available_sources)

            # Log any invalid citations
            invalid_citations = [citation for citation, valid in citation_validation.items() if not valid]
            if invalid_citations:
                logger.warning(f"Invalid citations detected: {invalid_citations}")

        # Truncate if too long
        if len(response_with_citations) > self.config.max_response_length:
            truncated = response_with_citations[: self.config.max_response_length - 3] + "..."
            logger.warning(
                f"Response truncated from {len(response_with_citations)} to {len(truncated)} characters"
            )
            return truncated

        return response_with_citations

    def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format search results for response metadata."""
        sources = []
        for result in search_results:
            metadata = result.get("metadata", {})
            sources.append(
                {
                    "document": metadata.get("filename", "unknown"),
                    "chunk_id": result.get("chunk_id", ""),
                    "relevance_score": result.get("similarity_score", 0.0),
                    "excerpt": (
                        result.get("content", "")[:200] + "..."
                        if len(result.get("content", "")) > 200
                        else result.get("content", "")
                    ),
                }
            )
        return sources

    def _calculate_confidence(self, quality_metrics: Dict[str, Any], llm_response: LLMResponse) -> float:
        """Calculate confidence score for the response."""
        # Base confidence on context quality
        context_confidence = quality_metrics.get("estimated_relevance", 0.0)

        # Adjust based on LLM response time (faster might indicate more confidence)
        time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))

        # Combine factors
        confidence = (context_confidence * 0.7) + (time_factor * 0.3)

        return min(1.0, max(0.0, confidence))
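
    # Worked example for _calculate_confidence (illustrative numbers): with
    # estimated_relevance = 0.6 and a 25-second LLM response,
    # time_factor = min(1.0, 10.0 / 25.0) = 0.4, so
    # confidence = 0.6 * 0.7 + 0.4 * 0.3 = 0.54.
    # Any response faster than 10 seconds clamps time_factor to 1.0, adding the
    # full 0.3 time component.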

    def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
        """Create response when no relevant context is found."""
        return RAGResponse(
            answer=(
                "I couldn't find any relevant information in our corporate policies "
                "to answer your question. Please contact HR or check other company "
                "resources for assistance."
            ),
            sources=[],
            confidence=0.0,
            processing_time=time.time() - start_time,
            llm_provider="none",
            llm_model="none",
            context_length=0,
            search_results_count=0,
            success=True,  # This is a valid "no answer" response
        )

    def _create_insufficient_context_response(
        self, question: str, results: List[Dict[str, Any]], start_time: float
    ) -> RAGResponse:
        """Create response when context quality is insufficient."""
        return RAGResponse(
            answer=(
                "I found some potentially relevant information, but it doesn't provide "
                "enough detail to fully answer your question. Please contact HR for "
                "more specific guidance or rephrase your question."
            ),
            sources=self._format_sources(results),
            confidence=0.2,
            processing_time=time.time() - start_time,
            llm_provider="none",
            llm_model="none",
            context_length=0,
            search_results_count=len(results),
            success=True,
        )

    def _create_llm_error_response(
        self, question: str, error_message: Optional[str], start_time: float
    ) -> RAGResponse:
        """Create response when LLM generation fails."""
        return RAGResponse(
            answer=(
                "I apologize, but I'm currently unable to generate a response. "
                "Please try again in a moment or contact support if the issue persists."
            ),
            sources=[],
            confidence=0.0,
            processing_time=time.time() - start_time,
            llm_provider="error",
            llm_model="error",
            context_length=0,
            search_results_count=0,
            success=False,
            error_message=error_message,
        )

    def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all pipeline components.

        Returns:
            Dictionary with component health status
        """
        health_status = {"pipeline": "healthy", "components": {}}

        try:
            # Check search service
            test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
            health_status["components"]["search_service"] = {
                "status": "healthy",
                "test_results_count": len(test_results),
            }
        except Exception as e:
            health_status["components"]["search_service"] = {
                "status": "unhealthy",
                "error": str(e),
            }
            health_status["pipeline"] = "degraded"

        try:
            # Check LLM service
            llm_health = self.llm_service.health_check()
            health_status["components"]["llm_service"] = llm_health

            # Pipeline is unhealthy if all LLM providers are down
            healthy_providers = sum(
                1 for provider_status in llm_health.values() if provider_status.get("status") == "healthy"
            )
            if healthy_providers == 0:
                health_status["pipeline"] = "unhealthy"
        except Exception as e:
            health_status["components"]["llm_service"] = {
                "status": "unhealthy",
                "error": str(e),
            }
            health_status["pipeline"] = "unhealthy"

        return health_status
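

# Minimal usage sketch (illustrative, not part of the pipeline). It assumes
# SearchService and LLMService can be constructed without arguments; the real
# constructors in src.search.search_service and src.llm.llm_service may require
# configuration (e.g., index paths or API keys), so adapt accordingly.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    pipeline = RAGPipeline(
        search_service=SearchService(),  # assumption: no-arg constructor
        llm_service=LLMService(),  # assumption: no-arg constructor
        config=RAGConfig(search_top_k=5),  # hypothetical override
    )

    response = pipeline.generate_answer("How many vacation days do new employees get?")
    print(f"Answer: {response.answer}")
    print(f"Confidence: {response.confidence:.2f}")
    print(f"Sources: {[s['document'] for s in response.sources]}")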