Spaces:
Running
Running
import logging
import re
import string
from typing import Any, Dict, List

import numpy as np

from evaluate import load
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
| class QAEvaluator: | |
| def __init__(self): | |
| self.squad_metric = load("squad") | |
| self.rouge_metric = load("rouge") | |
| def exact_match(self, predictions: List[str], references: List[str]) -> float: | |
| """Calculate exact match score""" | |
| matches = 0 | |
| for pred, ref in zip(predictions, references): | |
| if self._normalize_answer(pred) == self._normalize_answer(ref): | |
| matches += 1 | |
| return matches / len(predictions) if predictions else 0.0 | |
| def f1_score(self, predictions: List[str], references: List[str]) -> float: | |
| """Calculate F1 score""" | |
| f1_scores = [] | |
| for pred, ref in zip(predictions, references): | |
| f1 = self._calculate_f1(pred, ref) | |
| f1_scores.append(f1) | |
| return np.mean(f1_scores) if f1_scores else 0.0 | |
| def rouge_score(self, predictions: List[str], references: List[str]) -> Dict[str, float]: | |
| """Calculate ROUGE scores""" | |
| if not predictions or not references: | |
| return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0} | |
| results = self.rouge_metric.compute( | |
| predictions=predictions, | |
| references=references | |
| ) | |
| return { | |
| 'rouge1': results['rouge1'], | |
| 'rouge2': results['rouge2'], | |
| 'rougeL': results['rougeL'] | |
| } | |
| def squad_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]: | |
| """Calculate SQuAD-style metrics""" | |
| if not predictions or not references: | |
| return {'exact_match': 0.0, 'f1': 0.0} | |
| # Format for SQuAD metric | |
| formatted_predictions = [{"prediction_text": pred, "id": str(i)} | |
| for i, pred in enumerate(predictions)] | |
| formatted_references = [{"answers": {"text": [ref], "answer_start": [0]}, "id": str(i)} | |
| for i, ref in enumerate(references)] | |
| results = self.squad_metric.compute( | |
| predictions=formatted_predictions, | |
| references=formatted_references | |
| ) | |
| return { | |
| 'exact_match': results['exact_match'], | |
| 'f1': results['f1'] | |
| } | |
| def evaluate_batch(self, predictions: List[str], references: List[str]) -> Dict[str, float]: | |
| """Evaluate a batch of predictions""" | |
| metrics = {} | |
| # Basic metrics | |
| metrics['exact_match'] = self.exact_match(predictions, references) | |
| metrics['f1'] = self.f1_score(predictions, references) | |
| # ROUGE metrics | |
| rouge_scores = self.rouge_score(predictions, references) | |
| metrics.update(rouge_scores) | |
| # SQuAD metrics | |
| squad_scores = self.squad_metrics(predictions, references) | |
| metrics.update(squad_scores) | |
| return metrics | |
| def _normalize_answer(self, answer: str) -> str: | |
| """Normalize answer for comparison""" | |
| def remove_articles(text): | |
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |
| def white_space_fix(text): | |
| return ' '.join(text.split()) | |
| def remove_punc(text): | |
| exclude = set(string.punctuation) | |
| return ''.join(ch for ch in text if ch not in exclude) | |
| def lower(text): | |
| return text.lower() | |
| return white_space_fix(remove_articles(remove_punc(lower(answer)))) | |
| def _calculate_f1(self, prediction: str, reference: str) -> float: | |
| """Calculate F1 score between prediction and reference""" | |
| pred_tokens = self._normalize_answer(prediction).split() | |
| ref_tokens = self._normalize_answer(reference).split() | |
| if len(ref_tokens) == 0: | |
| return 1.0 if len(pred_tokens) == 0 else 0.0 | |
| common = set(pred_tokens) & set(ref_tokens) | |
| if len(common) == 0: | |
| return 0.0 | |
| precision = len(common) / len(pred_tokens) | |
| recall = len(common) / len(ref_tokens) | |
| f1 = 2 * precision * recall / (precision + recall) | |
| return f1 | |
| def evaluate_with_context(self, predictions: List[str], references: List[str], | |
| contexts: List[str]) -> Dict[str, float]: | |
| """Evaluate with context awareness""" | |
| metrics = self.evaluate_batch(predictions, references) | |
| # Context-based metrics | |
| context_scores = [] | |
| for pred, context in zip(predictions, contexts): | |
| # Check if prediction is supported by context | |
| pred_words = set(pred.lower().split()) | |
| context_words = set(context.lower().split()) | |
| overlap = len(pred_words & context_words) / len(pred_words) if pred_words else 0 | |
| context_scores.append(overlap) | |
| metrics['context_support'] = np.mean(context_scores) | |
| return metrics | |