# Commit bb0a764 by uuuhjb: add submit function
"""
Scoring functions for AMA-Bench submissions.
This module implements evaluation logic for string answers,
calculating accuracy by exact string match (case-insensitive).
"""
import re
from typing import Dict, List, Optional, Union
def normalize_answer(text: str) -> str:
    """Canonicalize an answer for comparison: cast to str, trim, lowercase."""
    trimmed = str(text).strip()
    return trimmed.lower()
def string_exact_match(prediction: str, reference: str) -> float:
    """
    Score a predicted string against the reference by exact match.

    Both sides are normalized (str-cast, stripped, lowercased) before
    being compared, so the match is case- and whitespace-insensitive.

    Args:
        prediction: Model's predicted answer string
        reference: Ground truth reference answer string

    Returns:
        1.0 on an exact normalized match, 0.0 otherwise
    """
    # bool -> float gives exactly 1.0 / 0.0
    return float(normalize_answer(prediction) == normalize_answer(reference))
def calculate_accuracy(scores: List[float]) -> Dict[str, float]:
    """
    Aggregate per-question scores into an accuracy summary.

    Args:
        scores: List of question scores (each expected to be 0.0 or 1.0)

    Returns:
        Dictionary with:
            "accuracy": mean of the scores (0.0 for an empty list)
            "count":    number of scores
            "correct":  integer sum of the scores
    """
    if not scores:
        # Avoid ZeroDivisionError and report an explicit empty result.
        return {"accuracy": 0.0, "count": 0, "correct": 0}
    # A plain sum/len mean is sufficient here; pulling in numpy (and
    # importing it per-call) just for the mean of a small list was overkill.
    total = sum(scores)
    return {
        "accuracy": float(total / len(scores)),
        "count": len(scores),
        "correct": int(total),
    }
def _resolve_metric(qa_type: str, metrics_mapping: Dict[str, str]) -> str:
    """Map a question-type string to a metric category.

    Returns the value for the first mapping key that appears (case-
    insensitively) as a substring of qa_type; falls back to "Recall".
    """
    for key_term, metric in metrics_mapping.items():
        if key_term.lower() in qa_type.lower():
            return metric
    return "Recall"


def score_submission(
    submissions: List[Dict],
    groundtruth: Dict[str, Dict],
    metrics_mapping: Optional[Dict[str, str]] = None
) -> Dict:
    """
    Score a complete submission against ground truth.

    Args:
        submissions: List of submission dicts with episode_id, question, answer
        groundtruth: Dict mapping "{episode_id}_{question}" keys to ground
            truth info dicts (must contain "answer"; may contain "type",
            "domain")
        metrics_mapping: Optional dict mapping question-type substrings to
            metric categories; defaults to the AMA-Bench categories below

    Returns:
        Dictionary with:
            "scores": per-metric and "Average" accuracy summaries
            "scored_submissions": input submissions annotated with score,
                reference answer, metric category, and domain
    """
    # Default metric mapping based on question type (both long names and
    # single-letter codes are accepted as type substrings).
    if metrics_mapping is None:
        metrics_mapping = {
            "Recall": "Recall",
            "Causal": "Causal Inference",
            "State": "State Updating",
            "Abstraction": "State Abstraction",
            "A": "Recall",
            "B": "Causal Inference",
            "C": "State Updating",
            "D": "State Abstraction",
        }

    # One score bucket per metric category.
    scores_by_metric: Dict[str, List[float]] = {
        "Recall": [],
        "Causal Inference": [],
        "State Updating": [],
        "State Abstraction": [],
    }
    all_scores: List[float] = []
    scored_submissions: List[Dict] = []

    for submission in submissions:
        episode_id = submission.get("episode_id", "")
        question = submission.get("question", "")
        answer = submission.get("answer", "")

        # Ground-truth lookup key mirrors how `groundtruth` is keyed.
        key = f"{episode_id}_{question}"
        gt_info = groundtruth.get(key)
        if gt_info is None:
            # Question not found in ground truth: counts as incorrect.
            score = 0.0
            reference = ""
            qa_type = "Unknown"
        else:
            reference = gt_info["answer"]
            qa_type = gt_info.get("type", "Recall")
            # Accuracy via case-insensitive exact string match.
            score = string_exact_match(answer, reference)

        metric_category = _resolve_metric(qa_type, metrics_mapping)

        # Guard: a custom metrics_mapping may name categories we don't track.
        if metric_category in scores_by_metric:
            scores_by_metric[metric_category].append(score)
        all_scores.append(score)

        scored_submissions.append({
            **submission,
            "score": score,
            "reference_answer": reference,
            "metric_category": metric_category,
            "domain": gt_info.get("domain", "") if gt_info else "",
        })

    # Per-category accuracy plus the overall average.
    results = {
        metric_name: calculate_accuracy(metric_scores)
        for metric_name, metric_scores in scores_by_metric.items()
    }
    results["Average"] = calculate_accuracy(all_scores)

    return {
        "scores": results,
        "scored_submissions": scored_submissions,
    }