# RAG10/trace_evaluator.py
"""TRACE evaluation metrics for RAG systems.
TRACE Metrics:
- uTilization: How well the system uses retrieved documents
- Relevance: Relevance of retrieved documents to the query
- Adherence: How well the response adheres to the retrieved context
- Completeness: How complete the response is in answering the query
"""
import re
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
@dataclass
class TRACEScores:
"""Container for TRACE evaluation scores."""
utilization: float
relevance: float
adherence: float
completeness: float
def to_dict(self) -> Dict:
"""Convert to dictionary."""
return {
"utilization": self.utilization,
"relevance": self.relevance,
"adherence": self.adherence,
"completeness": self.completeness,
"average": self.average()
}
def average(self) -> float:
"""Calculate average score."""
return (self.utilization + self.relevance +
self.adherence + self.completeness) / 4
class TRACEEvaluator:
"""TRACE evaluation metrics for RAG systems."""
def __init__(self, llm_client=None):
"""Initialize TRACE evaluator.
Args:
llm_client: Optional LLM client for LLM-based evaluation
"""
self.llm_client = llm_client
def evaluate(
self,
query: str,
response: str,
retrieved_documents: List[str],
ground_truth: Optional[str] = None
) -> TRACEScores:
"""Evaluate a RAG response using TRACE metrics.
Args:
query: User query
response: Generated response
retrieved_documents: List of retrieved documents
ground_truth: Optional ground truth answer
Returns:
TRACEScores object
"""
utilization = self._compute_utilization(response, retrieved_documents)
relevance = self._compute_relevance(query, retrieved_documents)
adherence = self._compute_adherence(response, retrieved_documents)
completeness = self._compute_completeness(query, response, ground_truth)
return TRACEScores(
utilization=utilization,
relevance=relevance,
adherence=adherence,
completeness=completeness
)
def _compute_utilization(
self,
response: str,
retrieved_documents: List[str]
) -> float:
"""Compute utilization score.
Measures how well the system uses retrieved documents.
Score based on:
- Number of documents that contributed to the response
- Proportion of retrieved documents used
Args:
response: Generated response
retrieved_documents: List of retrieved documents
Returns:
Utilization score (0-1)
"""
if not retrieved_documents or not response:
return 0.0
response_lower = response.lower()
response_words = set(self._tokenize(response_lower))
# Count how many documents contributed
docs_used = 0
total_overlap = 0
for doc in retrieved_documents:
doc_lower = doc.lower()
doc_words = set(self._tokenize(doc_lower))
# Check for significant overlap
overlap = len(response_words & doc_words)
if overlap > 5: # Threshold for significant contribution
docs_used += 1
total_overlap += overlap
# Score based on proportion of documents used
proportion_used = docs_used / len(retrieved_documents)
        # Also consider depth of utilization (the empty-document case is handled above)
        avg_overlap = total_overlap / len(retrieved_documents)
        depth_score = min(avg_overlap / 20, 1.0)  # Normalize against ~20 overlapping words
        # Combined score: weight breadth of document use (0.6) over depth of overlap (0.4)
        utilization_score = 0.6 * proportion_used + 0.4 * depth_score
return min(utilization_score, 1.0)
def _compute_relevance(
self,
query: str,
retrieved_documents: List[str]
) -> float:
"""Compute relevance score.
Measures relevance of retrieved documents to the query.
Uses lexical overlap and keyword matching.
Args:
query: User query
retrieved_documents: List of retrieved documents
Returns:
Relevance score (0-1)
"""
if not retrieved_documents or not query:
return 0.0
query_lower = query.lower()
query_words = set(self._tokenize(query_lower))
query_keywords = self._extract_keywords(query_lower)
relevance_scores = []
for doc in retrieved_documents:
doc_lower = doc.lower()
doc_words = set(self._tokenize(doc_lower))
# Lexical overlap
overlap = len(query_words & doc_words)
overlap_score = overlap / len(query_words) if query_words else 0
# Keyword matching
keyword_matches = sum(1 for kw in query_keywords if kw in doc_lower)
keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0
# Combined relevance for this document
doc_relevance = 0.5 * overlap_score + 0.5 * keyword_score
relevance_scores.append(doc_relevance)
# Average relevance across documents
        return float(np.mean(relevance_scores))
def _compute_adherence(
self,
response: str,
retrieved_documents: List[str]
) -> float:
"""Compute adherence score.
Measures how well the response adheres to the retrieved context.
Higher score means response is grounded in the documents.
Args:
response: Generated response
retrieved_documents: List of retrieved documents
Returns:
Adherence score (0-1)
"""
if not retrieved_documents or not response:
return 0.0
# Combine all documents
combined_docs = " ".join(retrieved_documents).lower()
doc_words = set(self._tokenize(combined_docs))
# Analyze response
response_lower = response.lower()
response_sentences = self._split_sentences(response_lower)
adherence_scores = []
for sentence in response_sentences:
sentence_words = set(self._tokenize(sentence))
# Check what proportion of sentence words appear in documents
if sentence_words:
grounded_words = len(sentence_words & doc_words)
sentence_adherence = grounded_words / len(sentence_words)
adherence_scores.append(sentence_adherence)
# Average adherence across sentences
        return float(np.mean(adherence_scores)) if adherence_scores else 0.0
def _compute_completeness(
self,
query: str,
response: str,
ground_truth: Optional[str] = None
) -> float:
"""Compute completeness score.
Measures how complete the response is in answering the query.
Args:
query: User query
response: Generated response
ground_truth: Optional ground truth answer
Returns:
Completeness score (0-1)
"""
if not response or not query:
return 0.0
# Query analysis
query_lower = query.lower()
        # Identify question types that have a type-specific check below;
        # other question types fall back to the length and ground-truth checks
        is_when = "when" in query_lower
        is_where = "where" in query_lower
        is_who = "who" in query_lower
response_lower = response.lower()
# Basic completeness checks
completeness_factors = []
        # Length check: penalize responses shorter than ~50 characters
min_length = 50
length_score = min(len(response) / min_length, 1.0)
completeness_factors.append(length_score)
# Check for appropriate response type
if is_when and any(w in response_lower for w in ["year", "date", "time", "century"]):
completeness_factors.append(1.0)
elif is_where and any(w in response_lower for w in ["location", "place", "country", "city"]):
completeness_factors.append(1.0)
elif is_who and any(w in response_lower for w in ["person", "people", "name"]):
completeness_factors.append(1.0)
# If ground truth available, compare
if ground_truth:
gt_lower = ground_truth.lower()
gt_words = set(self._tokenize(gt_lower))
response_words = set(self._tokenize(response_lower))
# Check overlap with ground truth
overlap = len(gt_words & response_words)
gt_score = overlap / len(gt_words) if gt_words else 0
completeness_factors.append(gt_score)
# Average all factors
        return float(np.mean(completeness_factors)) if completeness_factors else 0.5
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text into words."""
# Remove punctuation and split
text = re.sub(r'[^\w\s]', ' ', text)
words = text.split()
# Filter out very short words and common stop words
stop_words = {"a", "an", "the", "is", "are", "was", "were", "in", "on", "at", "to", "for"}
return [w for w in words if len(w) > 2 and w not in stop_words]
    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text."""
        # Simple keyword extraction: unique non-stop-word tokens, ordered by
        # frequency. In production, use TF-IDF or similar.
        word_freq = Counter(self._tokenize(text))
        return [word for word, _ in word_freq.most_common()]
def _split_sentences(self, text: str) -> List[str]:
"""Split text into sentences."""
# Simple sentence splitting
sentences = re.split(r'[.!?]+', text)
return [s.strip() for s in sentences if s.strip()]
def evaluate_batch(
self,
test_data: List[Dict]
) -> Dict:
"""Evaluate multiple test cases.
Args:
test_data: List of test cases, each containing:
- query: User query
- response: Generated response
- retrieved_documents: Retrieved documents
- ground_truth: Ground truth answer (optional)
Returns:
Dictionary with aggregated scores
"""
all_scores = []
for i, test_case in enumerate(test_data):
print(f"Evaluating test case {i+1}/{len(test_data)}")
scores = self.evaluate(
query=test_case.get("query", ""),
response=test_case.get("response", ""),
retrieved_documents=test_case.get("retrieved_documents", []),
ground_truth=test_case.get("ground_truth")
)
all_scores.append(scores)
# Aggregate scores
avg_utilization = np.mean([s.utilization for s in all_scores])
avg_relevance = np.mean([s.relevance for s in all_scores])
avg_adherence = np.mean([s.adherence for s in all_scores])
avg_completeness = np.mean([s.completeness for s in all_scores])
return {
"utilization": float(avg_utilization),
"relevance": float(avg_relevance),
"adherence": float(avg_adherence),
"completeness": float(avg_completeness),
"average": float((avg_utilization + avg_relevance +
avg_adherence + avg_completeness) / 4),
"num_samples": len(test_data),
"individual_scores": [s.to_dict() for s in all_scores]
}
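
# --- Illustrative usage ------------------------------------------------------
# A minimal sketch of how the evaluator might be driven; the query, response,
# and documents below are made-up toy data, and no LLM client is configured,
# so all scores come from the lexical heuristics above.
if __name__ == "__main__":
    evaluator = TRACEEvaluator()

    # Single-example evaluation
    scores = evaluator.evaluate(
        query="When was the Eiffel Tower built?",
        response="The Eiffel Tower was completed in 1889 for the World's Fair in Paris.",
        retrieved_documents=[
            "The Eiffel Tower was completed in 1889 as the entrance to the World's Fair.",
            "Paris is the capital of France.",
        ],
        ground_truth="The Eiffel Tower was completed in 1889.",
    )
    print(scores.to_dict())

    # Batch evaluation over a list of test-case dictionaries
    results = evaluator.evaluate_batch([
        {
            "query": "When was the Eiffel Tower built?",
            "response": "It was completed in 1889 for the World's Fair.",
            "retrieved_documents": ["The Eiffel Tower was completed in 1889."],
            "ground_truth": "1889",
        },
    ])
    print(results["average"], results["num_samples"])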