# langgraph-rag-agent / src/reflection.py
# Author: Harsh-1132 (Hugging Face), revision a77376b
"""
Reflection module for evaluating answer quality and relevance.
Provides self-evaluation mechanisms for generated answers.
"""
from typing import Dict, Any, Optional, List
from llm_utils import LLMHandler
import re
class ReflectionEvaluator:
    """Evaluates the quality and relevance of generated answers.

    Two evaluation strategies are supported:

    * **LLM-based** -- prompts an LLM to grade the answer against the
      retrieved context (requires an ``LLMHandler``).
    * **Heuristic** -- cheap lexical checks: answer length, query-term
      coverage, context phrase overlap, and uncertainty phrasing.
    """

    # Common question/function words ignored when measuring how many of the
    # query's key terms appear in the answer. Class-level constant so it is
    # built once, not on every heuristic evaluation.
    _STOP_WORDS = frozenset({
        'what', 'is', 'are', 'the', 'a', 'an', 'how', 'why', 'when', 'where',
        'which', 'who', 'does', 'do', 'can', 'could', 'would', 'should',
        'about', 'in', 'on', 'for', 'to', 'of'
    })

    def __init__(
        self,
        llm_handler: Optional["LLMHandler"] = None,
        use_llm_reflection: bool = True
    ):
        """
        Initialize the reflection evaluator.

        Args:
            llm_handler: LLM handler for LLM-based reflection
            use_llm_reflection: Whether to use LLM or heuristic evaluation
        """
        self.llm_handler = llm_handler
        # LLM reflection is only usable when a handler was actually supplied;
        # otherwise silently fall back to the heuristic path.
        self.use_llm_reflection = use_llm_reflection and llm_handler is not None
        if self.use_llm_reflection:
            print("✓ Reflection evaluator initialized (LLM-based)")
        else:
            print("✓ Reflection evaluator initialized (Heuristic-based)")

    def evaluate(
        self,
        query: str,
        answer: str,
        context: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Evaluate the generated answer.

        Args:
            query: Original user query
            answer: Generated answer
            context: Retrieved context used for generation
            retrieved_chunks: List of retrieved document chunks

        Returns:
            Evaluation result dictionary with 'relevance', 'score',
            'reasoning', 'method', 'recommendation', and 'action' keys.
        """
        print("\n" + "="*60)
        print("🔍 REFLECTION: Evaluating Answer Quality")
        print("="*60 + "\n")

        if self.use_llm_reflection:
            result = self._llm_based_evaluation(query, answer, context)
        else:
            result = self._heuristic_evaluation(query, answer, retrieved_chunks)

        # Print evaluation results
        print(f"Relevance: {result['relevance']}")
        print(f"Score: {result['score']:.2f}/1.0")
        print(f"Reasoning: {result['reasoning']}")

        # Map the numeric score onto a coarse accept/partial/reject decision
        # that downstream graph nodes can branch on.
        if result['score'] >= 0.7:
            result['recommendation'] = "ACCEPT"
            result['action'] = "Answer is satisfactory"
        elif result['score'] >= 0.4:
            result['recommendation'] = "PARTIAL"
            result['action'] = "Answer is partially relevant, may need refinement"
        else:
            result['recommendation'] = "REJECT"
            result['action'] = "Answer is not relevant, should be regenerated"

        print(f"Recommendation: {result['recommendation']}")
        print(f"Action: {result['action']}")
        print("\n" + "="*60 + "\n")
        return result

    def _llm_based_evaluation(
        self,
        query: str,
        answer: str,
        context: str
    ) -> Dict[str, Any]:
        """
        Use the LLM to evaluate answer quality.

        Args:
            query: Original query
            answer: Generated answer
            context: Retrieved context

        Returns:
            Evaluation result dictionary ('relevance', 'score', 'reasoning',
            'method').
        """
        evaluation_prompt = f"""You are an expert evaluator assessing the quality of an AI-generated answer.
**Original Question:**
{query}
**Retrieved Context:**
{context}
**Generated Answer:**
{answer}
**Task:**
Evaluate the answer based on the following criteria:
1. Relevance: Does the answer address the question?
2. Accuracy: Is the answer consistent with the provided context?
3. Completeness: Does the answer fully address the question?
4. Clarity: Is the answer clear and well-structured?
Provide your evaluation in the following format:
RELEVANCE: [Relevant/Partially Relevant/Irrelevant]
SCORE: [0.0-1.0]
REASONING: [Your detailed reasoning]
Be concise but thorough in your reasoning."""
        system_message = "You are a critical evaluator of AI-generated answers. Be objective and precise."
        evaluation_response = self.llm_handler.generate(
            evaluation_prompt,
            system_message
        )

        # Parse the structured fields out of the free-form LLM response.
        relevance = self._extract_field(evaluation_response, "RELEVANCE", "Partially Relevant")
        score_str = self._extract_field(evaluation_response, "SCORE", "0.5")
        reasoning = self._extract_field(evaluation_response, "REASONING", evaluation_response)

        # Pull the first numeric token out of the score field so values like
        # "0.8/1.0" or "Score is 0.8 (good)" still parse instead of silently
        # falling back to the default.
        number_match = re.search(r'\d*\.?\d+', score_str)
        try:
            score = float(number_match.group()) if number_match else 0.5
            score = max(0.0, min(1.0, score))  # Clamp between 0 and 1
        except (TypeError, ValueError):
            score = 0.5  # Default score if parsing fails

        return {
            "relevance": relevance,
            "score": score,
            "reasoning": reasoning,
            "method": "llm"
        }

    def _heuristic_evaluation(
        self,
        query: str,
        answer: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Use heuristic methods to evaluate answer quality.

        Combines four weighted signals: answer length (0.2), query-term
        coverage (0.3), context phrase overlap (0.3), and absence of
        uncertainty phrases (0.2).

        Args:
            query: Original query
            answer: Generated answer
            retrieved_chunks: Retrieved document chunks

        Returns:
            Evaluation result dictionary ('relevance', 'score', 'reasoning',
            'score_breakdown', 'method').
        """
        score_components = []
        reasoning_parts = []

        # 1. Length check (answer should not be too short or empty)
        answer_length = len(answer.strip())
        if answer_length == 0:
            length_score = 0.0
            reasoning_parts.append("Answer is empty")
        elif answer_length < 20:
            length_score = 0.3
            reasoning_parts.append("Answer is very short")
        elif answer_length < 50:
            length_score = 0.6
            reasoning_parts.append("Answer is somewhat brief")
        else:
            length_score = 1.0
            reasoning_parts.append("Answer has adequate length")
        score_components.append(("length", length_score, 0.2))

        # 2. Query term coverage (check if key query terms appear in answer)
        query_terms = set(re.findall(r'\b\w+\b', query.lower())) - self._STOP_WORDS
        answer_lower = answer.lower()
        matched_terms = sum(1 for term in query_terms if term in answer_lower)
        if len(query_terms) > 0:
            term_coverage_score = matched_terms / len(query_terms)
            reasoning_parts.append(f"Query term coverage: {matched_terms}/{len(query_terms)} key terms")
        else:
            term_coverage_score = 0.5
            reasoning_parts.append("Unable to extract key terms from query")
        score_components.append(("term_coverage", term_coverage_score, 0.3))

        # 3. Context relevance (check if answer references context)
        if retrieved_chunks:
            # .get() guards against malformed chunks missing the 'content' key.
            context_snippets = [chunk.get('content', '')[:100].lower() for chunk in retrieved_chunks]
            context_overlap = 0
            for snippet in context_snippets:
                # Check for shared phrases (3+ words)
                snippet_words = snippet.split()
                for i in range(len(snippet_words) - 2):
                    phrase = ' '.join(snippet_words[i:i+3])
                    if phrase in answer_lower:
                        context_overlap += 1
            if context_overlap >= 3:
                context_score = 1.0
                reasoning_parts.append(f"Strong context alignment (overlap: {context_overlap})")
            elif context_overlap >= 1:
                context_score = 0.7
                reasoning_parts.append(f"Moderate context alignment (overlap: {context_overlap})")
            else:
                context_score = 0.4
                reasoning_parts.append(f"Weak context alignment (overlap: {context_overlap})")
        else:
            context_score = 0.3
            reasoning_parts.append("No context retrieved")
        score_components.append(("context_relevance", context_score, 0.3))

        # 4. Answer completeness (checks for phrases indicating incomplete answers)
        incomplete_phrases = [
            "i don't know", "cannot answer", "no information",
            "not sure", "unclear", "unable to determine"
        ]
        has_incomplete_phrase = any(phrase in answer_lower for phrase in incomplete_phrases)
        if has_incomplete_phrase:
            completeness_score = 0.3
            reasoning_parts.append("Answer contains phrases indicating uncertainty")
        else:
            completeness_score = 1.0
            reasoning_parts.append("Answer appears complete and confident")
        score_components.append(("completeness", completeness_score, 0.2))

        # Calculate weighted score
        total_score = sum(score * weight for _, score, weight in score_components)

        # Determine relevance category
        if total_score >= 0.7:
            relevance = "Relevant"
        elif total_score >= 0.4:
            relevance = "Partially Relevant"
        else:
            relevance = "Irrelevant"

        # Combine reasoning
        reasoning = "; ".join(reasoning_parts)
        return {
            "relevance": relevance,
            "score": total_score,
            "reasoning": reasoning,
            "score_breakdown": {name: score for name, score, _ in score_components},
            "method": "heuristic"
        }

    def _extract_field(
        self,
        text: str,
        field_name: str,
        default: str
    ) -> str:
        """
        Extract a field value from structured text.

        Matches lines of the form ``FIELD: value`` (case-insensitive) and
        returns the first match's value.

        Args:
            text: Source text
            field_name: Field name to extract
            default: Default value if not found

        Returns:
            Extracted field value, or ``default`` when absent.
        """
        pattern = rf"{field_name}:\s*(.+?)(?:\n|$)"
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).strip()
        return default
def create_reflection_evaluator(
    llm_handler: Optional[LLMHandler] = None,
    use_llm_reflection: bool = False
) -> ReflectionEvaluator:
    """
    Factory for :class:`ReflectionEvaluator`.

    Args:
        llm_handler: Optional LLM handler for LLM-based reflection
        use_llm_reflection: Whether to use LLM-based reflection

    Returns:
        A configured ReflectionEvaluator instance
    """
    evaluator = ReflectionEvaluator(
        llm_handler=llm_handler,
        use_llm_reflection=use_llm_reflection,
    )
    return evaluator