# safe_rag/eval/eval_qa.py
# Author: Tairun Meng
# Initial commit: SafeRAG project ready for HF Spaces (commit db06013)
import logging
import re
import string
from typing import Any, Dict, List

import numpy as np
from evaluate import load

logger = logging.getLogger(__name__)
class QAEvaluator:
    """Compute QA evaluation metrics for lists of prediction/reference strings.

    Provides normalized exact match, token-overlap F1, ROUGE (via the HF
    ``evaluate`` rouge metric), SQuAD-style EM/F1 (via the HF squad metric),
    and a lexical-overlap "context support" score.
    """

    def __init__(self):
        # Load HF `evaluate` metrics once and reuse them across batches.
        self.squad_metric = load("squad")
        self.rouge_metric = load("rouge")

    def exact_match(self, predictions: List[str], references: List[str]) -> float:
        """Return the fraction of predictions that exactly match their reference.

        Strings are compared after normalization (lowercase, punctuation and
        article removal, whitespace collapse). Pairs are aligned positionally;
        if the lists differ in length the extra items are ignored by ``zip``,
        but the denominator is always ``len(predictions)``.
        """
        matches = sum(
            1
            for pred, ref in zip(predictions, references)
            if self._normalize_answer(pred) == self._normalize_answer(ref)
        )
        return matches / len(predictions) if predictions else 0.0

    def f1_score(self, predictions: List[str], references: List[str]) -> float:
        """Return the mean token-overlap F1 over all pairs (0.0 for empty input)."""
        f1_scores = [
            self._calculate_f1(pred, ref)
            for pred, ref in zip(predictions, references)
        ]
        # float() so callers get a plain Python float, not np.float64.
        return float(np.mean(f1_scores)) if f1_scores else 0.0

    def rouge_score(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Return ROUGE-1/ROUGE-2/ROUGE-L scores computed by the HF rouge metric.

        Returns all-zero scores when either input list is empty (the metric
        itself cannot handle empty batches).
        """
        if not predictions or not references:
            return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
        results = self.rouge_metric.compute(
            predictions=predictions,
            references=references
        )
        return {
            'rouge1': results['rouge1'],
            'rouge2': results['rouge2'],
            'rougeL': results['rougeL']
        }

    def squad_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Return SQuAD-style exact_match/F1 via the HF squad metric.

        Each pair is wrapped in the id/answers structure the squad metric
        expects; ``answer_start`` is a dummy 0 since only the text matters
        for EM/F1. Returns zeros when either input list is empty.
        """
        if not predictions or not references:
            return {'exact_match': 0.0, 'f1': 0.0}
        # Format for SQuAD metric: positional index doubles as the example id.
        formatted_predictions = [{"prediction_text": pred, "id": str(i)}
                                 for i, pred in enumerate(predictions)]
        formatted_references = [{"answers": {"text": [ref], "answer_start": [0]}, "id": str(i)}
                                for i, ref in enumerate(references)]
        results = self.squad_metric.compute(
            predictions=formatted_predictions,
            references=formatted_references
        )
        return {
            'exact_match': results['exact_match'],
            'f1': results['f1']
        }

    def evaluate_batch(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Return a combined dict of EM, F1, ROUGE, and SQuAD metrics for a batch."""
        metrics: Dict[str, float] = {}
        # Basic metrics (local normalization-based EM / token F1).
        metrics['exact_match'] = self.exact_match(predictions, references)
        metrics['f1'] = self.f1_score(predictions, references)
        # ROUGE metrics
        metrics.update(self.rouge_score(predictions, references))
        # SQuAD metrics (may overwrite 'exact_match'/'f1' with the HF
        # metric's values, which are on a 0-100 scale).
        metrics.update(self.squad_metrics(predictions, references))
        return metrics

    def _normalize_answer(self, answer: str) -> str:
        """Normalize an answer for comparison (SQuAD-style).

        Pipeline: lowercase -> strip punctuation -> drop articles
        (a/an/the) -> collapse whitespace.
        """
        def remove_articles(text: str) -> str:
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text: str) -> str:
            return ' '.join(text.split())

        def remove_punc(text: str) -> str:
            # Requires the stdlib `string` module (imported at file top).
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text: str) -> str:
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(answer))))

    def _calculate_f1(self, prediction: str, reference: str) -> float:
        """Return token-overlap F1 between one prediction and one reference.

        Uses set overlap of normalized tokens (duplicates count once).
        An empty reference scores 1.0 only against an empty prediction.
        """
        pred_tokens = self._normalize_answer(prediction).split()
        ref_tokens = self._normalize_answer(reference).split()
        if not ref_tokens:
            return 1.0 if not pred_tokens else 0.0
        common = set(pred_tokens) & set(ref_tokens)
        if not common:
            return 0.0
        precision = len(common) / len(pred_tokens)
        recall = len(common) / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    def evaluate_with_context(self, predictions: List[str], references: List[str],
                              contexts: List[str]) -> Dict[str, float]:
        """Return batch metrics plus 'context_support'.

        'context_support' is the mean fraction of each (lowercased,
        whitespace-split) prediction's distinct words that also appear in its
        paired context — a rough lexical groundedness signal.
        """
        metrics = self.evaluate_batch(predictions, references)
        context_scores = []
        for pred, context in zip(predictions, contexts):
            # Check if prediction is (lexically) supported by its context.
            pred_words = set(pred.lower().split())
            context_words = set(context.lower().split())
            overlap = len(pred_words & context_words) / len(pred_words) if pred_words else 0
            context_scores.append(overlap)
        # Guard against np.mean([]) -> nan when there are no pairs.
        metrics['context_support'] = float(np.mean(context_scores)) if context_scores else 0.0
        return metrics