"""Attribution, citation, and retrieval quality metrics for RAG evaluation."""
| from typing import List, Dict, Any, Set | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class AttributionEvaluator:
    """Evaluates attribution, citation, and retrieval quality of RAG outputs.

    Metrics are computed either against gold supporting facts / relevant
    passages when available, or via embedding cosine similarity as a proxy.
    """

    def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
        """Load the sentence-embedding model used for similarity scoring.

        Args:
            embedding_model: Name or path of a SentenceTransformer model.
        """
        # NOTE: this downloads/loads the model eagerly; construction may be slow.
        self.embedding_model = SentenceTransformer(embedding_model)
| def evaluate_attribution(self, answers: List[str], | |
| retrieved_passages: List[List[Dict[str, Any]]], | |
| supporting_facts: List[List[str]] = None) -> Dict[str, float]: | |
| """Evaluate attribution quality""" | |
| if not answers or not retrieved_passages: | |
| return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0} | |
| precisions = [] | |
| recalls = [] | |
| f1_scores = [] | |
| for answer, passages, facts in zip(answers, retrieved_passages, supporting_facts or [[]] * len(answers)): | |
| if not passages: | |
| precisions.append(0.0) | |
| recalls.append(0.0) | |
| f1_scores.append(0.0) | |
| continue | |
| # Extract passage texts | |
| passage_texts = [p.get('text', '') for p in passages] | |
| # Calculate attribution metrics | |
| if facts: | |
| # Use provided supporting facts | |
| precision, recall, f1 = self._calculate_attribution_metrics( | |
| answer, passage_texts, facts | |
| ) | |
| else: | |
| # Use semantic similarity as proxy | |
| precision, recall, f1 = self._calculate_semantic_attribution( | |
| answer, passage_texts | |
| ) | |
| precisions.append(precision) | |
| recalls.append(recall) | |
| f1_scores.append(f1) | |
| return { | |
| 'precision': np.mean(precisions), | |
| 'recall': np.mean(recalls), | |
| 'f1': np.mean(f1_scores), | |
| 'precision_std': np.std(precisions), | |
| 'recall_std': np.std(recalls), | |
| 'f1_std': np.std(f1_scores) | |
| } | |
| def _calculate_attribution_metrics(self, answer: str, passages: List[str], | |
| supporting_facts: List[str]) -> tuple: | |
| """Calculate attribution metrics using supporting facts""" | |
| # Find which passages contain supporting facts | |
| relevant_passages = set() | |
| for fact in supporting_facts: | |
| for i, passage in enumerate(passages): | |
| if self._passage_contains_fact(passage, fact): | |
| relevant_passages.add(i) | |
| # Calculate metrics | |
| total_passages = len(passages) | |
| relevant_count = len(relevant_passages) | |
| if total_passages == 0: | |
| return 0.0, 0.0, 0.0 | |
| # Precision: relevant passages / total retrieved passages | |
| precision = relevant_count / total_passages | |
| # Recall: relevant passages / total supporting facts | |
| recall = relevant_count / len(supporting_facts) if supporting_facts else 0.0 | |
| # F1 score | |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 | |
| return precision, recall, f1 | |
| def _calculate_semantic_attribution(self, answer: str, passages: List[str]) -> tuple: | |
| """Calculate attribution using semantic similarity""" | |
| if not passages: | |
| return 0.0, 0.0, 0.0 | |
| # Encode answer and passages | |
| answer_embedding = self.embedding_model.encode([answer]) | |
| passage_embeddings = self.embedding_model.encode(passages) | |
| # Calculate similarities | |
| similarities = cosine_similarity(answer_embedding, passage_embeddings)[0] | |
| # Use threshold to determine relevant passages | |
| threshold = 0.3 | |
| relevant_passages = similarities >= threshold | |
| # Calculate metrics | |
| total_passages = len(passages) | |
| relevant_count = np.sum(relevant_passages) | |
| precision = relevant_count / total_passages | |
| recall = relevant_count / total_passages # Simplified for semantic method | |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 | |
| return precision, recall, f1 | |
| def _passage_contains_fact(self, passage: str, fact: str) -> bool: | |
| """Check if passage contains a supporting fact""" | |
| # Simple containment check | |
| fact_words = set(fact.lower().split()) | |
| passage_words = set(passage.lower().split()) | |
| # Check if most fact words are in passage | |
| overlap = len(fact_words & passage_words) | |
| return overlap >= len(fact_words) * 0.7 | |
| def evaluate_citation_quality(self, answers: List[str], | |
| citations: List[List[Dict[str, Any]]]) -> Dict[str, float]: | |
| """Evaluate citation quality in answers""" | |
| if not answers or not citations: | |
| return {'citation_coverage': 0.0, 'citation_accuracy': 0.0} | |
| coverage_scores = [] | |
| accuracy_scores = [] | |
| for answer, answer_citations in zip(answers, citations): | |
| # Citation coverage: percentage of answer that is cited | |
| coverage = self._calculate_citation_coverage(answer, answer_citations) | |
| coverage_scores.append(coverage) | |
| # Citation accuracy: percentage of citations that are relevant | |
| accuracy = self._calculate_citation_accuracy(answer, answer_citations) | |
| accuracy_scores.append(accuracy) | |
| return { | |
| 'citation_coverage': np.mean(coverage_scores), | |
| 'citation_accuracy': np.mean(accuracy_scores), | |
| 'citation_coverage_std': np.std(coverage_scores), | |
| 'citation_accuracy_std': np.std(accuracy_scores) | |
| } | |
| def _calculate_citation_coverage(self, answer: str, citations: List[Dict[str, Any]]) -> float: | |
| """Calculate what percentage of answer is covered by citations""" | |
| if not citations: | |
| return 0.0 | |
| # Simple heuristic: check if answer contains citation markers | |
| import re | |
| citation_markers = re.findall(r'\[\d+\]', answer) | |
| if not citation_markers: | |
| return 0.0 | |
| # Estimate coverage based on citation density | |
| answer_length = len(answer.split()) | |
| citation_density = len(citation_markers) / answer_length if answer_length > 0 else 0 | |
| return min(1.0, citation_density * 10) # Scale factor | |
| def _calculate_citation_accuracy(self, answer: str, citations: List[Dict[str, Any]]) -> float: | |
| """Calculate accuracy of citations""" | |
| if not citations: | |
| return 0.0 | |
| # Simple heuristic: check if cited passages are relevant to answer | |
| answer_words = set(answer.lower().split()) | |
| relevant_citations = 0 | |
| for citation in citations: | |
| citation_text = citation.get('text', '') | |
| citation_words = set(citation_text.lower().split()) | |
| # Check word overlap | |
| overlap = len(answer_words & citation_words) | |
| if overlap >= 3: # Threshold for relevance | |
| relevant_citations += 1 | |
| return relevant_citations / len(citations) | |
| def evaluate_retrieval_quality(self, queries: List[str], | |
| retrieved_passages: List[List[Dict[str, Any]]], | |
| relevant_passages: List[List[str]] = None) -> Dict[str, float]: | |
| """Evaluate retrieval quality""" | |
| if not queries or not retrieved_passages: | |
| return {'retrieval_precision': 0.0, 'retrieval_recall': 0.0, 'retrieval_f1': 0.0} | |
| precisions = [] | |
| recalls = [] | |
| f1_scores = [] | |
| for query, passages, relevant in zip(queries, retrieved_passages, relevant_passages or [[]] * len(queries)): | |
| if not passages: | |
| precisions.append(0.0) | |
| recalls.append(0.0) | |
| f1_scores.append(0.0) | |
| continue | |
| # Calculate retrieval metrics | |
| if relevant: | |
| precision, recall, f1 = self._calculate_retrieval_metrics(passages, relevant) | |
| else: | |
| # Use semantic similarity as proxy | |
| precision, recall, f1 = self._calculate_semantic_retrieval(query, passages) | |
| precisions.append(precision) | |
| recalls.append(recall) | |
| f1_scores.append(f1) | |
| return { | |
| 'retrieval_precision': np.mean(precisions), | |
| 'retrieval_recall': np.mean(recalls), | |
| 'retrieval_f1': np.mean(f1_scores), | |
| 'retrieval_precision_std': np.std(precisions), | |
| 'retrieval_recall_std': np.std(recalls), | |
| 'retrieval_f1_std': np.std(f1_scores) | |
| } | |
| def _calculate_retrieval_metrics(self, passages: List[Dict[str, Any]], | |
| relevant_passages: List[str]) -> tuple: | |
| """Calculate retrieval metrics using ground truth""" | |
| retrieved_texts = [p.get('text', '') for p in passages] | |
| # Find relevant retrieved passages | |
| relevant_retrieved = 0 | |
| for retrieved in retrieved_texts: | |
| for relevant in relevant_passages: | |
| if self._passage_contains_fact(retrieved, relevant): | |
| relevant_retrieved += 1 | |
| break | |
| total_retrieved = len(passages) | |
| total_relevant = len(relevant_passages) | |
| precision = relevant_retrieved / total_retrieved if total_retrieved > 0 else 0.0 | |
| recall = relevant_retrieved / total_relevant if total_relevant > 0 else 0.0 | |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 | |
| return precision, recall, f1 | |
| def _calculate_semantic_retrieval(self, query: str, passages: List[Dict[str, Any]]) -> tuple: | |
| """Calculate retrieval metrics using semantic similarity""" | |
| if not passages: | |
| return 0.0, 0.0, 0.0 | |
| # Encode query and passages | |
| query_embedding = self.embedding_model.encode([query]) | |
| passage_embeddings = self.embedding_model.encode([p.get('text', '') for p in passages]) | |
| # Calculate similarities | |
| similarities = cosine_similarity(query_embedding, passage_embeddings)[0] | |
| # Use threshold to determine relevant passages | |
| threshold = 0.3 | |
| relevant_count = np.sum(similarities >= threshold) | |
| total_retrieved = len(passages) | |
| precision = relevant_count / total_retrieved | |
| recall = relevant_count / total_retrieved # Simplified for semantic method | |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 | |
| return precision, recall, f1 | |