import logging
import re
import string
from collections import Counter
from typing import Any, Dict, List

import numpy as np

try:
    from evaluate import load
except ImportError:
    # Deferred: the token-level metrics below are pure Python, so the module
    # stays importable; QAEvaluator.__init__ raises if the package is missing.
    load = None

logger = logging.getLogger(__name__)

# Compiled once at import time instead of on every normalization call.
_ARTICLES_RE = re.compile(r'\b(a|an|the)\b')
# Translation table that deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


class QAEvaluator:
    """Compute question-answering metrics for predicted vs. reference answers.

    Combines locally implemented exact-match / token-F1 scoring with the
    "squad" and "rouge" metrics loaded through ``evaluate.load``.
    """

    def __init__(self):
        """Load the "squad" and "rouge" metrics.

        Raises:
            ImportError: If the ``evaluate`` package is not installed.
        """
        if load is None:
            raise ImportError(
                "The 'evaluate' package is required to construct QAEvaluator"
            )
        self.squad_metric = load("squad")
        self.rouge_metric = load("rouge")

    def exact_match(self, predictions: List[str], references: List[str]) -> float:
        """Return the fraction of predictions equal to their reference.

        Answers are compared after ``_normalize_answer``. Pairs are formed
        with ``zip``; the denominator is ``len(predictions)``.

        Returns:
            A value in [0, 1], or 0.0 when ``predictions`` is empty.
        """
        if not predictions:
            return 0.0
        matches = sum(
            self._normalize_answer(pred) == self._normalize_answer(ref)
            for pred, ref in zip(predictions, references)
        )
        return matches / len(predictions)

    def f1_score(self, predictions: List[str], references: List[str]) -> float:
        """Return the mean token-level F1 over prediction/reference pairs.

        Returns:
            The mean of ``_calculate_f1`` over the zipped pairs as a plain
            ``float``, or 0.0 when there are no pairs.
        """
        f1_scores = [
            self._calculate_f1(pred, ref)
            for pred, ref in zip(predictions, references)
        ]
        # float() so callers get a builtin float rather than np.float64.
        return float(np.mean(f1_scores)) if f1_scores else 0.0

    def rouge_score(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Return the rouge1, rouge2 and rougeL values from the ROUGE metric.

        Returns:
            A dict with keys ``rouge1``, ``rouge2`` and ``rougeL``; all 0.0
            when either input list is empty.
        """
        if not predictions or not references:
            return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

        results = self.rouge_metric.compute(
            predictions=predictions,
            references=references,
        )
        return {
            'rouge1': results['rouge1'],
            'rouge2': results['rouge2'],
            'rougeL': results['rougeL'],
        }

    def squad_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Return ``exact_match`` and ``f1`` as computed by the SQuAD metric.

        Each prediction/reference is wrapped in the record layout passed to
        the metric, using the list index as the shared ``id``.

        Returns:
            A dict with keys ``exact_match`` and ``f1``; both 0.0 when either
            input list is empty.
        """
        if not predictions or not references:
            return {'exact_match': 0.0, 'f1': 0.0}

        # Format for SQuAD metric
        formatted_predictions = [
            {"prediction_text": pred, "id": str(i)}
            for i, pred in enumerate(predictions)
        ]
        # answer_start is a placeholder value of 0 for every reference.
        formatted_references = [
            {"answers": {"text": [ref], "answer_start": [0]}, "id": str(i)}
            for i, ref in enumerate(references)
        ]

        results = self.squad_metric.compute(
            predictions=formatted_predictions,
            references=formatted_references,
        )
        return {
            'exact_match': results['exact_match'],
            'f1': results['f1'],
        }

    def evaluate_batch(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Evaluate a batch of predictions with every available metric.

        Returns:
            A dict with keys ``exact_match``, ``f1``, ``rouge1``, ``rouge2``
            and ``rougeL``.
        """
        metrics = {}

        # Basic metrics
        metrics['exact_match'] = self.exact_match(predictions, references)
        metrics['f1'] = self.f1_score(predictions, references)

        # ROUGE metrics
        metrics.update(self.rouge_score(predictions, references))

        # SQuAD metrics
        # NOTE(review): squad_metrics returns the same 'exact_match' and 'f1'
        # keys, so this update replaces the basic values computed above.
        # Kept as-is because callers already receive the SQuAD values; the
        # SQuAD metric presumably reports percentages (0-100) rather than
        # fractions -- confirm before relying on the scale.
        metrics.update(self.squad_metrics(predictions, references))

        return metrics

    def _normalize_answer(self, answer: str) -> str:
        """Normalize an answer string for comparison.

        Applies, in order: lowercasing, removal of ASCII punctuation, removal
        of the articles "a", "an" and "the", and whitespace collapsing.
        """
        text = answer.lower().translate(_PUNCTUATION_TABLE)
        text = _ARTICLES_RE.sub(' ', text)
        return ' '.join(text.split())

    def _calculate_f1(self, prediction: str, reference: str) -> float:
        """Return the token-level F1 between one prediction and one reference.

        Both strings are normalized and split on whitespace. Overlap is
        counted as a multiset intersection, so repeated tokens are credited
        once per matching occurrence.

        Returns:
            1.0 when both token lists are empty, 0.0 when exactly one is empty
            or there is no overlap, otherwise the harmonic mean of precision
            and recall.
        """
        pred_tokens = self._normalize_answer(prediction).split()
        ref_tokens = self._normalize_answer(reference).split()

        if not pred_tokens or not ref_tokens:
            return 1.0 if pred_tokens == ref_tokens else 0.0

        # Counter & Counter keeps the minimum count of each shared token.
        shared = Counter(pred_tokens) & Counter(ref_tokens)
        num_shared = sum(shared.values())
        if num_shared == 0:
            return 0.0

        precision = num_shared / len(pred_tokens)
        recall = num_shared / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    def evaluate_with_context(self, predictions: List[str], references: List[str],
                              contexts: List[str]) -> Dict[str, float]:
        """Evaluate a batch and add a context-support score.

        ``context_support`` is the mean, over zipped prediction/context pairs,
        of the fraction of the prediction's distinct lowercased whitespace
        tokens that also appear in the context.

        Returns:
            The ``evaluate_batch`` metrics plus ``context_support`` (0.0 when
            there are no prediction/context pairs).
        """
        metrics = self.evaluate_batch(predictions, references)

        # Context-based metrics
        context_scores = []
        for pred, context in zip(predictions, contexts):
            # Check if prediction is supported by context
            pred_words = set(pred.lower().split())
            context_words = set(context.lower().split())
            overlap = (
                len(pred_words & context_words) / len(pred_words)
                if pred_words else 0.0
            )
            context_scores.append(overlap)

        # np.mean([]) is nan with a RuntimeWarning; report 0.0 instead.
        metrics['context_support'] = (
            float(np.mean(context_scores)) if context_scores else 0.0
        )

        return metrics