""" Self-Reflection RAG - Advanced RAG Pattern RAG system with self-reflection and correction capabilities. """ import asyncio import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple import time from ..config.pipeline_configs.rag_pipeline import RAGPipeline, RAGResponse logger = logging.getLogger(__name__) @dataclass class ReflectionResult: """Result from self-reflection process.""" needs_correction: bool confidence_improvement: float corrected_answer: Optional[str] = None reasoning: Optional[str] = None issues_found: List[str] = field(default_factory=list) @dataclass class ReflectionRound: """Single round of reflection.""" round_number: int original_query: str original_answer: str original_sources: List[Dict[str, Any]] reflection_result: ReflectionResult timestamp: float = field(default_factory=time.time) class SelfReflectionRAG: """RAG system with self-reflection and correction capabilities.""" def __init__(self, base_pipeline: RAGPipeline, config: Optional[Dict[str, Any]] = None): self.pipeline = base_pipeline self.config = config or {} # Reflection settings self.max_reflection_rounds = self.config.get("max_reflection_rounds", 2) self.confidence_threshold = self.config.get("confidence_threshold", 0.7) self.enable_fact_checking = self.config.get("enable_fact_checking", True) self.enable_coherence_checking = self.config.get("enable_coherence_checking", True) self.enable_source_verification = self.config.get("enable_source_verification", True) # LLM settings for reflection self.reflection_model = self.config.get("reflection_model", "gpt-4") self.correction_model = self.config.get("correction_model", "gpt-4") async def query_with_reflection( self, query: str, max_rounds: Optional[int] = None ) -> Dict[str, Any]: """Execute query with self-reflection and correction.""" start_time = time.time() # Initial query reflection_rounds = [] current_query = query current_answer = None current_sources = None total_confidence_improvement = 0.0 max_rounds = min(max_rounds or self.max_reflection_rounds, self.max_reflection_rounds) for round_num in range(max_rounds): logger.info(f"Reflection round {round_num + 1}/{max_rounds}") # Execute query response = await self.pipeline.query( query=current_query, top_k=5, include_sources=True, include_confidence=True ) current_answer = response.answer current_sources = response.sources current_confidence = response.confidence # Perform self-reflection if round_num < max_rounds - 1: # Don't reflect on final round reflection_result = await self._reflect_on_answer( query, current_answer, current_sources, reflection_rounds ) # Decide if correction is needed if reflection_result.needs_correction and reflection_result.corrected_answer: current_query = reflection_result.corrected_answer total_confidence_improvement += reflection_result.confidence_improvement # Create reflection round record reflection_round = ReflectionRound( round_number=round_num + 1, original_query=query, original_answer=current_answer, original_sources=current_sources, reflection_result=reflection_result, ) reflection_rounds.append(reflection_round) else: # No correction needed, this is our final answer reflection_round = ReflectionRound( round_number=round_num + 1, original_query=query, original_answer=current_answer, original_sources=current_sources, reflection_result=reflection_result, ) reflection_rounds.append(reflection_round) break else: # Final round reflection_round = ReflectionRound( round_number=round_num + 1, original_query=query, original_answer=current_answer, original_sources=current_sources, reflection_result=ReflectionResult(needs_correction=False), ) reflection_rounds.append(reflection_round) total_time = (time.time() - start_time) * 1000 return { "original_query": query, "final_answer": current_answer, "final_sources": current_sources, "final_confidence": current_confidence, "reflection_rounds": reflection_rounds, "total_rounds": len(reflection_rounds), "total_confidence_improvement": total_confidence_improvement, "total_time_ms": total_time, "self_corrected": total_confidence_improvement > 0, "metadata": { "max_reflection_rounds": max_rounds, "reflection_threshold": self.confidence_threshold, }, } async def _reflect_on_answer( self, query: str, answer: str, sources: List[Dict[str, Any]], previous_rounds: List[ReflectionRound], ) -> ReflectionResult: """Perform self-reflection on the answer.""" try: # Analyze different aspects of the answer issues_found = [] needs_correction = False corrected_answer = None # 1. Confidence analysis confidence_issue = await self._analyze_confidence(answer, sources) if confidence_issue: issues_found.extend(confidence_issue) # 2. Fact checking if self.enable_fact_checking: fact_issues = await self._check_factual_accuracy(answer, sources) issues_found.extend(fact_issues) # 3. Coherence analysis if self.enable_coherence_checking: coherence_issues = await self._check_coherence(query, answer) issues_found.extend(coherence_issues) # 4. Source verification if self.enable_source_verification: source_issues = await self._verify_sources(answer, sources) issues_found.extend(source_issues) # Determine if correction is needed if issues_found and self.confidence_threshold > 0.0: avg_confidence = await self._estimate_confidence(answer, sources) if avg_confidence < self.confidence_threshold: needs_correction = True corrected_answer = await self._generate_correction(query, answer, issues_found) reasoning = self._generate_reflection_reasoning(issues_found, needs_correction) confidence_improvement = 0.0 if corrected_answer: confidence_improvement = await self._estimate_confidence_improvement( answer, corrected_answer ) return ReflectionResult( needs_correction=needs_correction, confidence_improvement=confidence_improvement, corrected_answer=corrected_answer, reasoning=reasoning, issues_found=issues_found, ) except Exception as e: logger.error(f"Error in self-reflection: {e}") return ReflectionResult( needs_correction=False, confidence_improvement=0.0, reasoning=f"Reflection failed: {str(e)}", ) async def _analyze_confidence(self, answer: str, sources: List[Dict[str, Any]]) -> List[str]: """Analyze confidence of the answer.""" issues = [] # Check for hedge words hedge_phrases = [ "might be", "could be", "possibly", "probably", "seems like", "I think", "it appears", "roughly", "approximately", ] lower_answer = answer.lower() for phrase in hedge_phrases: if phrase in lower_answer: issues.append(f"Contains hedge phrase: '{phrase}'") # Check for uncertainty indicators uncertainty_phrases = [ "I'm not sure", "I cannot confirm", "there is insufficient information", "based on limited data", "this is speculation", ] for phrase in uncertainty_phrases: if phrase in lower_answer: issues.append(f"Contains uncertainty: '{phrase}'") # Check source quality impact on confidence if sources: source_scores = [source.get("score", 0.0) for source in sources] avg_source_score = sum(source_scores) / len(source_scores) if avg_source_score < 0.6: issues.append(f"Low source relevance: {avg_source_score:.2f}") return issues async def _check_factual_accuracy( self, answer: str, sources: List[Dict[str, Any]] ) -> List[str]: """Check factual accuracy against sources.""" issues = [] if not sources: return ["No sources provided for fact-checking"] # Extract key claims from answer claims = self._extract_key_claims(answer) # Check each claim against sources for claim in claims: is_supported = await self._verify_claim_against_sources(claim, sources) if not is_supported: issues.append(f"Unsupported claim: {claim[:100]}...") return issues async def _check_coherence(self, query: str, answer: str) -> List[str]: """Check answer coherence.""" issues = [] # Check for contradictions within the answer sentences = answer.split(".") for i, sentence in enumerate(sentences): sentence = sentence.strip() if len(sentence) < 10: continue # Check for contradictions with previous sentences for j, prev_sentence in enumerate(sentences[:i]): prev_sentence = prev_sentence.strip() if len(prev_sentence) < 10: continue contradiction = await self._detect_contradiction(prev_sentence, sentence) if contradiction: issues.append( f"Contradiction: '{prev_sentence[:50]}...' vs '{sentence[:50]}...'" ) # Check answer relevance to query query_words = set(query.lower().split()) answer_words = set(answer.lower().split()) overlap = len(query_words & answer_words) / len(query_words) if query_words else 0 if overlap < 0.3: # Less than 30% word overlap issues.append(f"Low query relevance: {overlap:.1%}") return issues async def _verify_sources(self, answer: str, sources: List[Dict[str, Any]]) -> List[str]: """Verify source quality and relevance.""" issues = [] # Check source diversity source_ids = set(source.get("document_id", "") for source in sources) if len(source_ids) < 2 and len(sources) > 1: issues.append("Low source diversity") # Check source scores for source in sources: score = source.get("score", 0.0) if score < 0.3: issues.append(f"Low relevance source: {source.get('title', 'Unknown')}") # Check for recent sources # (This would require timestamp information in sources) return issues async def _generate_correction( self, query: str, original_answer: str, issues: List[str] ) -> str: """Generate corrected answer.""" try: # Create correction prompt issues_text = "\n".join(f"- {issue}" for issue in issues) correction_prompt = f"""The following answer has identified issues: Original Query: {query} Original Answer: {original_answer} Issues Found: {issues_text} Please provide a corrected, more accurate and confident answer that addresses these issues. Be more specific, better supported by sources, and more confident in your response.""" from openai import OpenAI client = OpenAI() response = client.chat.completions.create( model=self.correction_model, messages=[ { "role": "system", "content": "You are an expert at correcting and improving AI-generated answers to be more accurate and confident.", }, {"role": "user", "content": correction_prompt}, ], temperature=0.1, max_tokens=800, ) corrected_answer = response.choices[0].message.content.strip() logger.info(f"Generated correction for answer") return corrected_answer except Exception as e: logger.error(f"Error generating correction: {e}") return original_answer def _extract_key_claims(self, text: str) -> List[str]: """Extract key claims from text.""" # Simple claim extraction - split by sentences and filter sentences = [s.strip() for s in text.split(".") if len(s.strip()) > 15] return sentences async def _verify_claim_against_sources( self, claim: str, sources: List[Dict[str, Any]] ) -> bool: """Verify if a claim is supported by sources.""" claim_words = set(claim.lower().split()) for source in sources: source_text = source.get("content", "").lower() source_words = set(source_text.split()) # Check for significant overlap overlap = len(claim_words & source_words) / len(claim_words) if claim_words else 0 if overlap >= 0.5: # 50% overlap threshold return True return False async def _detect_contradiction(self, sentence1: str, sentence2: str) -> bool: """Detect contradiction between two sentences.""" # Simple contradiction patterns contradictions = [ ("not", ""), ("never", "always"), ("no", "yes"), ("false", "true"), ("incorrect", "correct"), ("cannot", "can"), ("impossible", "possible"), ] words1 = set(sentence1.lower().split()) words2 = set(sentence2.lower().split()) for neg, pos in contradictions: if (neg in words1 and pos in words2) or (pos in words1 and neg in words2): return True return False async def _estimate_confidence(self, answer: str, sources: List[Dict[str, Any]]) -> float: """Estimate confidence in the answer.""" # Base confidence on source quality if sources: source_scores = [source.get("score", 0.0) for source in sources] source_confidence = sum(source_scores) / len(source_scores) else: source_confidence = 0.3 # Low confidence without sources # Adjust based on answer characteristics answer_length = len(answer.split()) # Long answers might be more comprehensive length_factor = min(answer_length / 100, 1.2) # Hedge words reduce confidence hedge_words = ["might", "could", "possibly", "probably"] hedge_count = sum(1 for word in hedge_words if word in answer.lower()) hedge_penalty = hedge_count * 0.1 estimated_confidence = source_confidence * length_factor - hedge_penalty return max(0.0, min(1.0, estimated_confidence)) async def _estimate_confidence_improvement( self, original_answer: str, corrected_answer: str ) -> float: """Estimate confidence improvement from correction.""" # Simple heuristic based on correction characteristics if corrected_answer == original_answer: return 0.0 # Corrections that add specificity and citations tend to improve confidence original_length = len(original_answer.split()) corrected_length = len(corrected_answer.split()) if corrected_length > original_length * 1.2: # Significantly longer return 0.3 elif corrected_length > original_length * 1.1: return 0.2 elif corrected_length > original_length: return 0.1 return 0.05 def _generate_reflection_reasoning( self, issues_found: List[str], needs_correction: bool ) -> str: """Generate reasoning for reflection decision.""" if not issues_found: return "No significant issues found in the answer." reasoning_parts = ["Analysis identified the following issues:"] reasoning_parts.extend(f"• {issue}" for issue in issues_found[:5]) if needs_correction: reasoning_parts.append("Correction is recommended to improve accuracy and confidence.") else: reasoning_parts.append("No correction needed at this time.") return " ".join(reasoning_parts) async def get_reflection_stats(self, session_id: Optional[str] = None) -> Dict[str, Any]: """Get statistics about reflection performance.""" # This would connect to a metrics system in a full implementation return { "session_id": session_id, "max_reflection_rounds": self.max_reflection_rounds, "confidence_threshold": self.confidence_threshold, "features_enabled": { "fact_checking": self.enable_fact_checking, "coherence_checking": self.enable_coherence_checking, "source_verification": self.enable_source_verification, }, }