# safe_rag/eval/eval_attr.py
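"""Attribution, citation, and retrieval quality evaluation for SafeRAG.

Defines AttributionEvaluator, which scores how well generated answers are
grounded in retrieved passages: against gold supporting facts when they are
available, and via embedding cosine similarity as a proxy otherwise.
"""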
import re
from typing import Any, Dict, List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import logging
logger = logging.getLogger(__name__)
class AttributionEvaluator:
    def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
        # Sentence embedding model used for the semantic-similarity fallbacks
        self.embedding_model = SentenceTransformer(embedding_model)
    def evaluate_attribution(self, answers: List[str],
                             retrieved_passages: List[List[Dict[str, Any]]],
                             supporting_facts: Optional[List[List[str]]] = None) -> Dict[str, float]:
"""Evaluate attribution quality"""
if not answers or not retrieved_passages:
return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
precisions = []
recalls = []
f1_scores = []
for answer, passages, facts in zip(answers, retrieved_passages, supporting_facts or [[]] * len(answers)):
if not passages:
precisions.append(0.0)
recalls.append(0.0)
f1_scores.append(0.0)
continue
# Extract passage texts
passage_texts = [p.get('text', '') for p in passages]
# Calculate attribution metrics
if facts:
# Use provided supporting facts
precision, recall, f1 = self._calculate_attribution_metrics(
answer, passage_texts, facts
)
else:
# Use semantic similarity as proxy
precision, recall, f1 = self._calculate_semantic_attribution(
answer, passage_texts
)
precisions.append(precision)
recalls.append(recall)
f1_scores.append(f1)
return {
'precision': np.mean(precisions),
'recall': np.mean(recalls),
'f1': np.mean(f1_scores),
'precision_std': np.std(precisions),
'recall_std': np.std(recalls),
'f1_std': np.std(f1_scores)
}
    def _calculate_attribution_metrics(self, answer: str, passages: List[str],
                                       supporting_facts: List[str]) -> tuple:
        """Calculate attribution metrics against gold supporting facts"""
        if not passages:
            return 0.0, 0.0, 0.0
        # Retrieved passages that contain at least one supporting fact
        relevant_passages = set()
        # Supporting facts that appear in at least one retrieved passage
        covered_facts = set()
        for fact_idx, fact in enumerate(supporting_facts):
            for i, passage in enumerate(passages):
                if self._passage_contains_fact(passage, fact):
                    relevant_passages.add(i)
                    covered_facts.add(fact_idx)
        # Precision: fraction of retrieved passages that support the answer
        precision = len(relevant_passages) / len(passages)
        # Recall: fraction of supporting facts recovered by the retrieved passages
        recall = len(covered_facts) / len(supporting_facts) if supporting_facts else 0.0
        # F1: harmonic mean of precision and recall
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1
def _calculate_semantic_attribution(self, answer: str, passages: List[str]) -> tuple:
"""Calculate attribution using semantic similarity"""
if not passages:
return 0.0, 0.0, 0.0
# Encode answer and passages
answer_embedding = self.embedding_model.encode([answer])
passage_embeddings = self.embedding_model.encode(passages)
# Calculate similarities
similarities = cosine_similarity(answer_embedding, passage_embeddings)[0]
        # Treat passages above a similarity threshold as relevant to the answer
        threshold = 0.3
        relevant_count = float(np.sum(similarities >= threshold))
        total_passages = len(passages)
        precision = relevant_count / total_passages
        # Without gold supporting facts there is no separate relevance denominator,
        # so recall falls back to the same ratio as precision (a deliberate simplification)
        recall = relevant_count / total_passages
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
return precision, recall, f1
    def _passage_contains_fact(self, passage: str, fact: str) -> bool:
        """Check whether a passage contains a supporting fact via word overlap"""
        fact_words = set(fact.lower().split())
        passage_words = set(passage.lower().split())
        if not fact_words:
            return False
        # Require at least 70% of the fact's words to appear in the passage
        overlap = len(fact_words & passage_words)
        return overlap >= len(fact_words) * 0.7
def evaluate_citation_quality(self, answers: List[str],
citations: List[List[Dict[str, Any]]]) -> Dict[str, float]:
"""Evaluate citation quality in answers"""
if not answers or not citations:
return {'citation_coverage': 0.0, 'citation_accuracy': 0.0}
coverage_scores = []
accuracy_scores = []
for answer, answer_citations in zip(answers, citations):
# Citation coverage: percentage of answer that is cited
coverage = self._calculate_citation_coverage(answer, answer_citations)
coverage_scores.append(coverage)
# Citation accuracy: percentage of citations that are relevant
accuracy = self._calculate_citation_accuracy(answer, answer_citations)
accuracy_scores.append(accuracy)
return {
'citation_coverage': np.mean(coverage_scores),
'citation_accuracy': np.mean(accuracy_scores),
'citation_coverage_std': np.std(coverage_scores),
'citation_accuracy_std': np.std(accuracy_scores)
}
    def _calculate_citation_coverage(self, answer: str, citations: List[Dict[str, Any]]) -> float:
        """Estimate what fraction of the answer is covered by citations"""
        if not citations:
            return 0.0
        # Heuristic: count citation markers such as [1], [2] in the answer text
        citation_markers = re.findall(r'\[\d+\]', answer)
        if not citation_markers:
            return 0.0
        # Estimate coverage from citation density (markers per answer word),
        # scaled so roughly one marker per ten words counts as full coverage
        answer_length = len(answer.split())
        citation_density = len(citation_markers) / answer_length if answer_length > 0 else 0.0
        return min(1.0, citation_density * 10)
def _calculate_citation_accuracy(self, answer: str, citations: List[Dict[str, Any]]) -> float:
"""Calculate accuracy of citations"""
if not citations:
return 0.0
# Simple heuristic: check if cited passages are relevant to answer
answer_words = set(answer.lower().split())
relevant_citations = 0
for citation in citations:
citation_text = citation.get('text', '')
citation_words = set(citation_text.lower().split())
# Check word overlap
overlap = len(answer_words & citation_words)
if overlap >= 3: # Threshold for relevance
relevant_citations += 1
return relevant_citations / len(citations)
    def evaluate_retrieval_quality(self, queries: List[str],
                                   retrieved_passages: List[List[Dict[str, Any]]],
                                   relevant_passages: Optional[List[List[str]]] = None) -> Dict[str, float]:
"""Evaluate retrieval quality"""
if not queries or not retrieved_passages:
return {'retrieval_precision': 0.0, 'retrieval_recall': 0.0, 'retrieval_f1': 0.0}
precisions = []
recalls = []
f1_scores = []
for query, passages, relevant in zip(queries, retrieved_passages, relevant_passages or [[]] * len(queries)):
if not passages:
precisions.append(0.0)
recalls.append(0.0)
f1_scores.append(0.0)
continue
# Calculate retrieval metrics
if relevant:
precision, recall, f1 = self._calculate_retrieval_metrics(passages, relevant)
else:
# Use semantic similarity as proxy
precision, recall, f1 = self._calculate_semantic_retrieval(query, passages)
precisions.append(precision)
recalls.append(recall)
f1_scores.append(f1)
return {
'retrieval_precision': np.mean(precisions),
'retrieval_recall': np.mean(recalls),
'retrieval_f1': np.mean(f1_scores),
'retrieval_precision_std': np.std(precisions),
'retrieval_recall_std': np.std(recalls),
'retrieval_f1_std': np.std(f1_scores)
}
    def _calculate_retrieval_metrics(self, passages: List[Dict[str, Any]],
                                     relevant_passages: List[str]) -> tuple:
        """Calculate retrieval metrics using ground truth"""
        retrieved_texts = [p.get('text', '') for p in passages]
        # Retrieved passages that match at least one gold passage
        relevant_retrieved = 0
        # Gold passages matched by at least one retrieved passage
        covered_relevant = set()
        for retrieved in retrieved_texts:
            matched = False
            for j, relevant in enumerate(relevant_passages):
                if self._passage_contains_fact(retrieved, relevant):
                    covered_relevant.add(j)
                    matched = True
            if matched:
                relevant_retrieved += 1
        total_retrieved = len(passages)
        total_relevant = len(relevant_passages)
        # Precision: fraction of retrieved passages that are relevant
        precision = relevant_retrieved / total_retrieved if total_retrieved > 0 else 0.0
        # Recall: fraction of gold passages recovered by retrieval
        recall = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1
def _calculate_semantic_retrieval(self, query: str, passages: List[Dict[str, Any]]) -> tuple:
"""Calculate retrieval metrics using semantic similarity"""
if not passages:
return 0.0, 0.0, 0.0
# Encode query and passages
query_embedding = self.embedding_model.encode([query])
passage_embeddings = self.embedding_model.encode([p.get('text', '') for p in passages])
# Calculate similarities
similarities = cosine_similarity(query_embedding, passage_embeddings)[0]
        # Treat passages above a similarity threshold as relevant to the query
        threshold = 0.3
        relevant_count = float(np.sum(similarities >= threshold))
        total_retrieved = len(passages)
        precision = relevant_count / total_retrieved
        # No gold passages here, so recall falls back to the same ratio as precision
        recall = relevant_count / total_retrieved
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
return precision, recall, f1
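

# Minimal smoke-test sketch, not part of the original module: the toy answers,
# passages, and supporting facts below are assumptions chosen only to exercise
# the three public evaluate_* methods end to end. Running it loads the
# BAAI/bge-large-en-v1.5 embedding model, which is downloaded on first use.
if __name__ == "__main__":
    evaluator = AttributionEvaluator()
    answers = ["Paris is the capital of France [1]."]
    passages = [[{'text': 'Paris is the capital and largest city of France.'},
                 {'text': 'Berlin is the capital of Germany.'}]]
    facts = [["Paris is the capital of France"]]
    citations = [[{'text': 'Paris is the capital and largest city of France.'}]]
    queries = ["What is the capital of France?"]
    print(evaluator.evaluate_attribution(answers, passages, facts))
    print(evaluator.evaluate_citation_quality(answers, citations))
    print(evaluator.evaluate_retrieval_quality(queries, passages, facts))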