Spaces:
Running
Running
| """ | |
| Scoring functions for AMA-Bench submissions. | |
| This module implements evaluation logic for string answers, | |
| calculating accuracy by exact string match (case-insensitive). | |
| """ | |
| import re | |
| from typing import Union, List, Dict | |
def normalize_answer(text: str) -> str:
    """Canonicalize an answer string for comparison.

    Coerces the value to ``str``, trims surrounding whitespace, and
    lowercases it so comparisons ignore case and padding.
    """
    normalized = str(text)
    return normalized.strip().lower()
def string_exact_match(prediction: str, reference: str) -> float:
    """
    Score a predicted string against a reference by exact match.

    Both sides are normalized (lowercased, whitespace-stripped) before
    being compared, so the match is case- and padding-insensitive.

    Args:
        prediction: Model's predicted answer string
        reference: Ground truth reference answer string

    Returns:
        1.0 if normalized strings match exactly, 0.0 otherwise
    """
    matched = normalize_answer(prediction) == normalize_answer(reference)
    return float(matched)
def calculate_accuracy(scores: List[float]) -> Dict[str, float]:
    """
    Calculate accuracy metric from individual question scores.

    Args:
        scores: List of question scores (0.0 or 1.0)

    Returns:
        Dictionary with keys:
            "accuracy": mean of scores (0.0 for an empty list)
            "count": number of scores
            "correct": integer sum of scores
    """
    if not scores:
        # Empty submissions produce a well-defined zero result instead of
        # a ZeroDivisionError below.
        return {"accuracy": 0.0, "count": 0, "correct": 0}
    # A plain sum()/len() mean replaces the original's function-local
    # numpy import: for 0.0/1.0 scores the sum is an exact integer, so
    # the result is identical to np.mean while dropping the heavyweight
    # dependency and the mid-function import.
    correct = sum(scores)
    return {
        "accuracy": float(correct / len(scores)),
        "count": len(scores),
        "correct": int(correct),
    }
def score_submission(
    submissions: List[Dict],
    groundtruth: Dict[str, Dict],
    metrics_mapping: Union[Dict[str, str], None] = None,
) -> Dict:
    """
    Score a complete submission against ground truth.

    Args:
        submissions: List of submission dicts with keys ``episode_id``,
            ``question``, and ``answer``.
        groundtruth: Dict keyed by ``f"{episode_id}_{question}"`` mapping to
            ground-truth info dicts with an ``answer`` key and optional
            ``type`` / ``domain`` keys.
        metrics_mapping: Optional dict mapping question-type substrings to
            metric category names. Matching is case-insensitive substring
            matching in insertion order, so more specific keys must come
            before shorter ones. (Annotation fixed from the implicit
            ``Dict[str, str] = None`` to an explicit optional per PEP 484.)

    Returns:
        Dict with:
            "scores": per-metric accuracy dicts plus an "Average" entry
            "scored_submissions": the input submissions annotated with
                score, reference_answer, metric_category, and domain.
    """
    # Default metric mapping based on question type.
    if metrics_mapping is None:
        metrics_mapping = {
            "Recall": "Recall",
            "Causal": "Causal Inference",
            "State": "State Updating",
            "Abstraction": "State Abstraction",
            "A": "Recall",
            "B": "Causal Inference",
            "C": "State Updating",
            "D": "State Abstraction",
        }
    # Initialize score buckets, one per metric category.
    scores_by_metric = {
        "Recall": [],
        "Causal Inference": [],
        "State Updating": [],
        "State Abstraction": [],
    }
    all_scores = []
    scored_submissions = []
    for submission in submissions:
        episode_id = submission.get("episode_id", "")
        question = submission.get("question", "")
        answer = submission.get("answer", "")
        # Look up ground truth by the composite episode/question key.
        key = f"{episode_id}_{question}"
        gt_info = groundtruth.get(key)
        if gt_info is None:
            # Question not found in ground truth: counts as incorrect
            # (and, via the default below, lands in the Recall bucket).
            score = 0.0
            reference = ""
            qa_type = "Unknown"
        else:
            reference = gt_info["answer"]
            qa_type = gt_info.get("type", "Recall")
            # Calculate accuracy via normalized exact string match.
            score = string_exact_match(answer, reference)
        # Map question type to metric category. NOTE(review): this is
        # first-match substring search in mapping insertion order — e.g. a
        # qa_type of "State Abstraction" matches "State" before
        # "Abstraction" and is bucketed as State Updating. Preserved as-is;
        # reorder the mapping keys if that is not the intended behavior.
        metric_category = "Recall"  # default
        for key_term, metric in metrics_mapping.items():
            if key_term.lower() in qa_type.lower():
                metric_category = metric
                break
        # Add to the appropriate metric bucket (unknown categories from a
        # custom metrics_mapping still count toward the overall average).
        if metric_category in scores_by_metric:
            scores_by_metric[metric_category].append(score)
        all_scores.append(score)
        # Store the annotated, scored submission.
        scored_submissions.append({
            **submission,
            "score": score,
            "reference_answer": reference,
            "metric_category": metric_category,
            "domain": gt_info.get("domain", "") if gt_info else "",
        })
    # Calculate metrics for each category.
    results = {}
    for metric_name, metric_scores in scores_by_metric.items():
        results[metric_name] = calculate_accuracy(metric_scores)
    # Calculate overall average across every submission.
    results["Average"] = calculate_accuracy(all_scores)
    return {
        "scores": results,
        "scored_submissions": scored_submissions,
    }