"""
RAG Pipeline - Core RAG Functionality
This module orchestrates the complete RAG (Retrieval-Augmented Generation) pipeline,
combining semantic search, context management, and LLM generation.
"""
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# Import project modules
from src.llm.context_manager import ContextConfig, ContextManager
from src.llm.llm_service import LLMResponse, LLMService
from src.llm.prompt_templates import PromptTemplates
from src.search.search_service import SearchService
logger = logging.getLogger(__name__)
@dataclass
class RAGConfig:
"""Configuration for RAG pipeline."""
max_context_length: int = 3000
search_top_k: int = 10
search_threshold: float = 0.0 # No threshold filtering at search level
min_similarity_for_answer: float = 0.2 # Threshold for normalized distance similarity
max_response_length: int = 1000
enable_citation_validation: bool = True
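# Config override sketch (values here are illustrative, not recommended
# defaults): any field not passed keeps the default above.
#
#     config = RAGConfig(search_top_k=5, min_similarity_for_answer=0.4)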
@dataclass
class RAGResponse:
"""Response from RAG pipeline with metadata."""
answer: str
sources: List[Dict[str, Any]]
confidence: float
processing_time: float
llm_provider: str
llm_model: str
context_length: int
search_results_count: int
success: bool
error_message: Optional[str] = None
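# Note: success=True does not guarantee an answer was produced. The
# "no relevant context" and "insufficient context" replies below also set
# success=True, using low confidence (0.0-0.2) to signal a non-answer.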
class RAGPipeline:
"""
Complete RAG pipeline orchestrating retrieval and generation.
Combines:
- Semantic search for context retrieval
- Context optimization and management
- LLM-based response generation
- Citation validation and formatting
"""
def __init__(
self,
search_service: SearchService,
llm_service: LLMService,
config: Optional[RAGConfig] = None,
):
"""
Initialize RAG pipeline with required services.
Args:
search_service: Configured SearchService instance
llm_service: Configured LLMService instance
config: RAG configuration, uses defaults if None
"""
self.search_service = search_service
self.llm_service = llm_service
self.config = config or RAGConfig()
# Initialize context manager with matching config
context_config = ContextConfig(
max_context_length=self.config.max_context_length,
max_results=self.config.search_top_k,
min_similarity=self.config.search_threshold,
)
self.context_manager = ContextManager(context_config)
# Initialize prompt templates
self.prompt_templates = PromptTemplates()
logger.info("RAGPipeline initialized successfully")
def generate_answer(self, question: str) -> RAGResponse:
"""
Generate answer to question using RAG pipeline.
Args:
question: User's question about corporate policies
Returns:
RAGResponse with answer and metadata
"""
start_time = time.time()
try:
# Step 1: Retrieve relevant context
logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
search_results = self._retrieve_context(question)
if not search_results:
return self._create_no_context_response(question, start_time)
# Step 2: Prepare and optimize context
context, filtered_results = self.context_manager.prepare_context(search_results, question)
# Step 3: Check if we have sufficient context
quality_metrics = self.context_manager.validate_context_quality(
context, question, self.config.min_similarity_for_answer
)
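            # quality_metrics is assumed to expose at least "passes_validation"
            # (bool) and "estimated_relevance" (float); the latter is reused in
            # _calculate_confidence below.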
if not quality_metrics["passes_validation"]:
return self._create_insufficient_context_response(question, filtered_results, start_time)
# Step 4: Generate response using LLM
llm_response = self._generate_llm_response(question, context)
if not llm_response.success:
return self._create_llm_error_response(question, llm_response.error_message, start_time)
# Step 5: Process and validate response
processed_response = self._process_response(llm_response.content, filtered_results)
processing_time = time.time() - start_time
return RAGResponse(
answer=processed_response,
sources=self._format_sources(filtered_results),
confidence=self._calculate_confidence(quality_metrics, llm_response),
processing_time=processing_time,
llm_provider=llm_response.provider,
llm_model=llm_response.model,
context_length=len(context),
search_results_count=len(search_results),
success=True,
)
except Exception as e:
logger.error(f"RAG pipeline error: {e}")
return RAGResponse(
answer=(
"I apologize, but I encountered an error processing your question. "
"Please try again or contact support."
),
sources=[],
confidence=0.0,
processing_time=time.time() - start_time,
llm_provider="none",
llm_model="none",
context_length=0,
search_results_count=0,
success=False,
error_message=str(e),
)
def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
"""Retrieve relevant context using search service."""
try:
results = self.search_service.search(
query=question,
top_k=self.config.search_top_k,
threshold=self.config.search_threshold,
)
logger.debug(f"Retrieved {len(results)} search results")
return results
except Exception as e:
logger.error(f"Context retrieval error: {e}")
return []
def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
"""Generate response using LLM with formatted prompt."""
template = self.prompt_templates.get_policy_qa_template()
# Format the prompt
formatted_prompt = template.user_template.format(question=question, context=context)
        # Prepend the system prompt manually; the LLM service currently takes
        # a single prompt string rather than separate system/user messages.
        full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
return self.llm_service.generate_response(full_prompt)
def _process_response(self, raw_response: str, search_results: List[Dict[str, Any]]) -> str:
"""Process and validate LLM response."""
# Ensure citations are present
response_with_citations = self.prompt_templates.add_fallback_citations(raw_response, search_results)
# Validate citations if enabled
if self.config.enable_citation_validation:
available_sources = [result.get("metadata", {}).get("filename", "") for result in search_results]
citation_validation = self.prompt_templates.validate_citations(response_with_citations, available_sources)
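            # validate_citations is expected to return a mapping from each cited
            # source to a bool, e.g. {"vacation_policy.md": True,
            # "old_handbook.pdf": False} (filenames illustrative).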
# Log any invalid citations
invalid_citations = [citation for citation, valid in citation_validation.items() if not valid]
if invalid_citations:
logger.warning(f"Invalid citations detected: {invalid_citations}")
# Truncate if too long
if len(response_with_citations) > self.config.max_response_length:
truncated = response_with_citations[: self.config.max_response_length - 3] + "..."
logger.warning(f"Response truncated from {len(response_with_citations)} " f"to {len(truncated)} characters")
return truncated
return response_with_citations
def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Format search results for response metadata."""
sources = []
for result in search_results:
metadata = result.get("metadata", {})
sources.append(
{
"document": metadata.get("filename", "unknown"),
"chunk_id": result.get("chunk_id", ""),
"relevance_score": result.get("similarity_score", 0.0),
"excerpt": (
result.get("content", "")[:200] + "..."
if len(result.get("content", "")) > 200
else result.get("content", "")
),
}
)
return sources
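    # A formatted source entry looks like this (values illustrative):
    #   {"document": "vacation_policy.md", "chunk_id": "chunk_003",
    #    "relevance_score": 0.83, "excerpt": "Employees accrue..."}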
def _calculate_confidence(self, quality_metrics: Dict[str, Any], llm_response: LLMResponse) -> float:
"""Calculate confidence score for the response."""
# Base confidence on context quality
context_confidence = quality_metrics.get("estimated_relevance", 0.0)
        # Penalize slow LLM responses: the time factor is 1.0 for responses
        # completing within 10 seconds and decays as 10/response_time beyond that
        time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
# Combine factors
confidence = (context_confidence * 0.7) + (time_factor * 0.3)
return min(1.0, max(0.0, confidence))
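    # Worked example (illustrative numbers): estimated_relevance=0.5 with a
    # 20-second response gives time_factor = min(1.0, 10/20) = 0.5, so
    # confidence = 0.5 * 0.7 + 0.5 * 0.3 = 0.5. Any response under 10 seconds
    # saturates the time factor at 1.0.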
def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
"""Create response when no relevant context found."""
return RAGResponse(
answer=(
"I couldn't find any relevant information in our corporate policies "
"to answer your question. Please contact HR or check other company "
"resources for assistance."
),
sources=[],
confidence=0.0,
processing_time=time.time() - start_time,
llm_provider="none",
llm_model="none",
context_length=0,
search_results_count=0,
success=True, # This is a valid "no answer" response
)
def _create_insufficient_context_response(
self, question: str, results: List[Dict[str, Any]], start_time: float
) -> RAGResponse:
"""Create response when context quality is insufficient."""
return RAGResponse(
answer=(
"I found some potentially relevant information, but it doesn't provide "
"enough detail to fully answer your question. Please contact HR for "
"more specific guidance or rephrase your question."
),
sources=self._format_sources(results),
confidence=0.2,
processing_time=time.time() - start_time,
llm_provider="none",
llm_model="none",
context_length=0,
search_results_count=len(results),
success=True,
)
    def _create_llm_error_response(self, question: str, error_message: Optional[str], start_time: float) -> RAGResponse:
"""Create response when LLM generation fails."""
return RAGResponse(
answer=(
"I apologize, but I'm currently unable to generate a response. "
"Please try again in a moment or contact support if the issue persists."
),
sources=[],
confidence=0.0,
processing_time=time.time() - start_time,
llm_provider="error",
llm_model="error",
context_length=0,
search_results_count=0,
success=False,
error_message=error_message,
)
def health_check(self) -> Dict[str, Any]:
"""
Perform health check on all pipeline components.
Returns:
Dictionary with component health status
"""
health_status = {"pipeline": "healthy", "components": {}}
try:
# Check search service
test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
health_status["components"]["search_service"] = {
"status": "healthy",
"test_results_count": len(test_results),
}
except Exception as e:
health_status["components"]["search_service"] = {
"status": "unhealthy",
"error": str(e),
}
health_status["pipeline"] = "degraded"
try:
# Check LLM service
llm_health = self.llm_service.health_check()
health_status["components"]["llm_service"] = llm_health
# Pipeline is unhealthy if all LLM providers are down
healthy_providers = sum(
1 for provider_status in llm_health.values() if provider_status.get("status") == "healthy"
)
if healthy_providers == 0:
health_status["pipeline"] = "unhealthy"
except Exception as e:
health_status["components"]["llm_service"] = {
"status": "unhealthy",
"error": str(e),
}
health_status["pipeline"] = "unhealthy"
return health_status
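

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline itself. It assumes
    # SearchService and LLMService can be constructed without arguments;
    # adjust to the real constructors in src/search and src/llm.
    logging.basicConfig(level=logging.DEBUG)

    pipeline = RAGPipeline(
        search_service=SearchService(),
        llm_service=LLMService(),
        config=RAGConfig(search_top_k=5),
    )

    print(pipeline.health_check())

    response = pipeline.generate_answer("How many vacation days do employees get?")
    print(f"answer: {response.answer}")
    print(f"confidence: {response.confidence:.2f}")
    print(f"sources: {[s['document'] for s in response.sources]}")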