from typing import Any, Dict, List, Optional, Tuple

import logging
import re

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger(__name__)


class AttributionEvaluator:
    """Evaluates attribution, citation, and retrieval quality for RAG-style outputs."""

    def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
        self.embedding_model = SentenceTransformer(embedding_model)

    def evaluate_attribution(
        self,
        answers: List[str],
        retrieved_passages: List[List[Dict[str, Any]]],
        supporting_facts: Optional[List[List[str]]] = None,
    ) -> Dict[str, float]:
        """Evaluate attribution quality across a batch of answers."""
        if not answers or not retrieved_passages:
            return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

        precisions: List[float] = []
        recalls: List[float] = []
        f1_scores: List[float] = []

        for answer, passages, facts in zip(
            answers, retrieved_passages, supporting_facts or [[]] * len(answers)
        ):
            if not passages:
                precisions.append(0.0)
                recalls.append(0.0)
                f1_scores.append(0.0)
                continue

            # Extract passage texts.
            passage_texts = [p.get('text', '') for p in passages]

            if facts:
                # Use the provided supporting facts as ground truth.
                precision, recall, f1 = self._calculate_attribution_metrics(
                    answer, passage_texts, facts
                )
            else:
                # Fall back to semantic similarity as a proxy.
                precision, recall, f1 = self._calculate_semantic_attribution(
                    answer, passage_texts
                )

            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)

        return {
            'precision': float(np.mean(precisions)),
            'recall': float(np.mean(recalls)),
            'f1': float(np.mean(f1_scores)),
            'precision_std': float(np.std(precisions)),
            'recall_std': float(np.std(recalls)),
            'f1_std': float(np.std(f1_scores)),
        }

    def _calculate_attribution_metrics(
        self, answer: str, passages: List[str], supporting_facts: List[str]
    ) -> Tuple[float, float, float]:
        """Calculate attribution metrics using supporting facts."""
        if not passages:
            return 0.0, 0.0, 0.0

        # Find which passages contain supporting facts, and which facts are
        # covered by at least one passage.
        relevant_passages = set()
        covered_facts = 0
        for fact in supporting_facts:
            fact_covered = False
            for i, passage in enumerate(passages):
                if self._passage_contains_fact(passage, fact):
                    relevant_passages.add(i)
                    fact_covered = True
            if fact_covered:
                covered_facts += 1

        # Precision: relevant passages / total retrieved passages.
        precision = len(relevant_passages) / len(passages)

        # Recall: covered supporting facts / total supporting facts.
        # (Dividing a passage count by the fact count here could push
        # recall above 1.0.)
        recall = covered_facts / len(supporting_facts) if supporting_facts else 0.0

        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1
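
    # Worked example for _calculate_attribution_metrics (hypothetical numbers,
    # not from a real run): with 4 retrieved passages of which 2 contain
    # supporting facts, and 3 supporting facts of which 2 appear in some
    # passage, precision = 2 / 4 = 0.5, recall = 2 / 3 ~= 0.667, and
    # f1 = 2 * 0.5 * 0.667 / (0.5 + 0.667) ~= 0.571.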

    def _calculate_semantic_attribution(
        self, answer: str, passages: List[str]
    ) -> Tuple[float, float, float]:
        """Calculate attribution using semantic similarity as a proxy."""
        if not passages:
            return 0.0, 0.0, 0.0

        # Encode the answer and the passages.
        answer_embedding = self.embedding_model.encode([answer])
        passage_embeddings = self.embedding_model.encode(passages)

        # Cosine similarity between the answer and each passage.
        similarities = cosine_similarity(answer_embedding, passage_embeddings)[0]

        # Passages above the threshold count as relevant.
        threshold = 0.3
        relevant_count = int(np.sum(similarities >= threshold))

        precision = relevant_count / len(passages)
        # Without ground truth this method cannot distinguish recall from
        # precision, so both are the fraction of passages above the threshold.
        recall = precision
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1

    def _passage_contains_fact(self, passage: str, fact: str) -> bool:
        """Check if a passage contains a supporting fact via word overlap."""
        fact_words = set(fact.lower().split())
        if not fact_words:
            return False
        passage_words = set(passage.lower().split())

        # Treat the fact as contained if most of its words appear in the passage.
        overlap = len(fact_words & passage_words)
        return overlap >= len(fact_words) * 0.7

    def evaluate_citation_quality(
        self, answers: List[str], citations: List[List[Dict[str, Any]]]
    ) -> Dict[str, float]:
        """Evaluate citation quality in answers."""
        if not answers or not citations:
            return {'citation_coverage': 0.0, 'citation_accuracy': 0.0}

        coverage_scores = []
        accuracy_scores = []

        for answer, answer_citations in zip(answers, citations):
            # Citation coverage: what share of the answer is cited.
            coverage_scores.append(self._calculate_citation_coverage(answer, answer_citations))
            # Citation accuracy: what share of citations is relevant.
            accuracy_scores.append(self._calculate_citation_accuracy(answer, answer_citations))

        return {
            'citation_coverage': float(np.mean(coverage_scores)),
            'citation_accuracy': float(np.mean(accuracy_scores)),
            'citation_coverage_std': float(np.std(coverage_scores)),
            'citation_accuracy_std': float(np.std(accuracy_scores)),
        }

    def _calculate_citation_coverage(
        self, answer: str, citations: List[Dict[str, Any]]
    ) -> float:
        """Estimate what percentage of the answer is covered by citations."""
        if not citations:
            return 0.0

        # Simple heuristic: count bracketed citation markers such as [1].
        citation_markers = re.findall(r'\[\d+\]', answer)
        if not citation_markers:
            return 0.0

        # Estimate coverage from citation density; e.g. a 50-word answer with
        # 2 markers gives density 0.04 and coverage min(1.0, 0.4) = 0.4.
        answer_length = len(answer.split())
        citation_density = len(citation_markers) / answer_length if answer_length > 0 else 0.0
        return min(1.0, citation_density * 10)  # Scale factor.

    def _calculate_citation_accuracy(
        self, answer: str, citations: List[Dict[str, Any]]
    ) -> float:
        """Estimate the fraction of citations that are relevant to the answer."""
        if not citations:
            return 0.0

        # Simple heuristic: a citation is relevant if its text shares enough
        # words with the answer.
        answer_words = set(answer.lower().split())
        relevant_citations = 0
        for citation in citations:
            citation_words = set(citation.get('text', '').lower().split())
            if len(answer_words & citation_words) >= 3:  # Relevance threshold.
                relevant_citations += 1

        return relevant_citations / len(citations)

    def evaluate_retrieval_quality(
        self,
        queries: List[str],
        retrieved_passages: List[List[Dict[str, Any]]],
        relevant_passages: Optional[List[List[str]]] = None,
    ) -> Dict[str, float]:
        """Evaluate retrieval quality across a batch of queries."""
        if not queries or not retrieved_passages:
            return {'retrieval_precision': 0.0, 'retrieval_recall': 0.0, 'retrieval_f1': 0.0}

        precisions = []
        recalls = []
        f1_scores = []

        for query, passages, relevant in zip(
            queries, retrieved_passages, relevant_passages or [[]] * len(queries)
        ):
            if not passages:
                precisions.append(0.0)
                recalls.append(0.0)
                f1_scores.append(0.0)
                continue

            if relevant:
                # Use the provided relevance judgments as ground truth.
                precision, recall, f1 = self._calculate_retrieval_metrics(passages, relevant)
            else:
                # Fall back to semantic similarity as a proxy.
                precision, recall, f1 = self._calculate_semantic_retrieval(query, passages)

            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)

        return {
            'retrieval_precision': float(np.mean(precisions)),
            'retrieval_recall': float(np.mean(recalls)),
            'retrieval_f1': float(np.mean(f1_scores)),
            'retrieval_precision_std': float(np.std(precisions)),
            'retrieval_recall_std': float(np.std(recalls)),
            'retrieval_f1_std': float(np.std(f1_scores)),
        }

    def _calculate_retrieval_metrics(
        self, passages: List[Dict[str, Any]], relevant_passages: List[str]
    ) -> Tuple[float, float, float]:
        """Calculate retrieval metrics against ground-truth relevant passages."""
        retrieved_texts = [p.get('text', '') for p in passages]

        # Count retrieved passages that match at least one relevant passage.
        relevant_retrieved = 0
        for retrieved in retrieved_texts:
            for relevant in relevant_passages:
                if self._passage_contains_fact(retrieved, relevant):
                    relevant_retrieved += 1
                    break

        total_retrieved = len(passages)
        total_relevant = len(relevant_passages)

        precision = relevant_retrieved / total_retrieved if total_retrieved > 0 else 0.0
        recall = relevant_retrieved / total_relevant if total_relevant > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1

    def _calculate_semantic_retrieval(
        self, query: str, passages: List[Dict[str, Any]]
    ) -> Tuple[float, float, float]:
        """Calculate retrieval metrics using semantic similarity as a proxy."""
        if not passages:
            return 0.0, 0.0, 0.0

        # Encode the query and the passages.
        query_embedding = self.embedding_model.encode([query])
        passage_embeddings = self.embedding_model.encode(
            [p.get('text', '') for p in passages]
        )

        # Cosine similarity between the query and each passage.
        similarities = cosine_similarity(query_embedding, passage_embeddings)[0]

        # Passages above the threshold count as relevant.
        threshold = 0.3
        relevant_count = int(np.sum(similarities >= threshold))

        precision = relevant_count / len(passages)
        # As with semantic attribution, recall is indistinguishable from
        # precision without ground truth.
        recall = precision
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1
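

if __name__ == "__main__":
    # Minimal usage sketch with made-up sample data; passage and citation
    # dicts only need a 'text' field, matching what the methods above read
    # via p.get('text', ''). Constructing the evaluator downloads the
    # embedding model, though the fact-based paths below never encode.
    evaluator = AttributionEvaluator()

    answers = ["Marie Curie won two Nobel Prizes, in physics and in chemistry. [1]"]
    passages = [[
        {'text': "Marie Curie won the Nobel Prize in Physics in 1903 "
                 "and the Nobel Prize in Chemistry in 1911."},
        {'text': "Warsaw is the capital and largest city of Poland."},
    ]]
    facts = [["Marie Curie won the Nobel Prize in Physics",
              "Marie Curie won the Nobel Prize in Chemistry"]]

    # Only the first passage covers the facts: precision 0.5, recall 1.0.
    print(evaluator.evaluate_attribution(answers, passages, facts))
    print(evaluator.evaluate_citation_quality(answers, [[passages[0][0]]]))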