# uuuhjb's picture
# add submit function
# bb0a764
"""
Utility functions for AMA-Bench Leaderboard.
This module contains helper functions for:
- DataFrame building and manipulation
- Chart generation
- Data validation
"""
import pandas as pd
import plotly.graph_objects as go
from typing import List, Dict
# Metrics configuration
METRICS = ["Recall", "Causal Inference", "State Updating", "State Abstraction"]
ALL_METRICS = METRICS + ["Average"]
# Chart colors moved to visualization.py


def build_dataframe(data: Dict) -> pd.DataFrame:
    """
    Build a pandas DataFrame showing Accuracy for each metric.

    Rows are sorted by the raw "Average" accuracy (best first) and each
    metric value is rendered as a 4-decimal string for display.

    Args:
        data: Dictionary with 'entries' key containing list of results.
            Each entry needs a 'method' name and a 'scores' mapping with
            an 'accuracy' float per metric in ALL_METRICS; an optional
            truthy 'category' adds a Category column.

    Returns:
        DataFrame with Method and metric columns, best average first.
        An empty DataFrame with the expected columns if there are no
        entries.
    """
    rows = []
    for entry in data["entries"]:
        row = {"Method": entry["method"]}
        if entry.get("category"):
            row["Category"] = entry["category"]
        for m in ALL_METRICS:
            # Display value; the raw float is kept separately for sorting.
            row[m] = f"{entry['scores'][m]['accuracy']:.4f}"
        row["_sort_avg"] = entry["scores"]["Average"]["accuracy"]
        rows.append(row)
    # Guard the empty case: sort_values("_sort_avg") would raise KeyError
    # on a DataFrame built from zero rows (it has no columns at all).
    if not rows:
        return pd.DataFrame(columns=["Method"] + ALL_METRICS)
    df = pd.DataFrame(rows)
    df = df.sort_values("_sort_avg", ascending=False).reset_index(drop=True)
    return df.drop(columns=["_sort_avg"])
def add_medals(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add medal emojis to the top-3 Method names.

    Args:
        df: DataFrame with a 'Method' column, assumed sorted best-first.

    Returns:
        A copy of ``df`` with medals prepended to the first three Method
        values; the input DataFrame is not modified.
    """
    df = df.copy()
    medals = ["\U0001f947", "\U0001f948", "\U0001f949"]  # 🥇 🥈 🥉
    # Use positional access (iat) so the top rows are decorated even when
    # the index is not a default RangeIndex — df.loc[i, ...] keys on index
    # labels and would raise / hit the wrong row otherwise.
    method_col = df.columns.get_loc("Method")
    for i in range(min(3, len(df))):
        df.iat[i, method_col] = f"{medals[i]} {df.iat[i, method_col]}"
    return df
def _ingest_episode(groundtruth: Dict[str, Dict[str, str]], record: dict) -> None:
    """Fold one episode record's Q&A pairs into ``groundtruth`` in place.

    Keys are ``"{episode_id}_{question}"``; values carry the answer plus
    type / sub_type / domain metadata.
    """
    episode_id = record.get("episode_id", "")
    domain = record.get("domain", "")
    for qa in record.get("qa_pairs", []):
        key = f"{episode_id}_{qa.get('question', '')}"
        groundtruth[key] = {
            "answer": qa.get("answer", ""),
            "type": qa.get("type", ""),
            "sub_type": qa.get("sub_type", ""),
            "domain": domain,
        }


def load_groundtruth(dataset_name: str, token: str = None) -> Dict[str, Dict[str, str]]:
    """
    Load ground truth Q&A pairs from HuggingFace dataset.

    Expected schema in the dataset:
    {
        "episode_id": "string",
        "qa_pairs": [
            {
                "question": "string",
                "answer": "string",
                "type": "string",
                "sub_type": "string"
            }
        ]
    }

    Falls back to the local file ``test/open_end_qa_set.jsonl`` when the
    HuggingFace dataset cannot be loaded. All failures are reported as
    warnings and an empty mapping is returned rather than raising.

    Args:
        dataset_name: HuggingFace dataset name (e.g., "Pettingllms/AMA-bench")
        token: Optional HuggingFace token for private datasets

    Returns:
        Dictionary mapping "{episode_id}_{question}" keys to answer info
        (answer, type, sub_type, domain).
    """
    groundtruth: Dict[str, Dict[str, str]] = {}
    try:
        from datasets import load_dataset, VerificationMode
        # Try loading from HuggingFace dataset first.
        try:
            dataset = load_dataset(
                dataset_name,
                split="test",
                token=token,
                verification_mode=VerificationMode.NO_CHECKS,
                trust_remote_code=True
            )
            print(f"Loaded dataset from HuggingFace: {dataset_name}")
            for row in dataset:
                _ingest_episode(groundtruth, row)
        except Exception as hf_error:
            import json
            local_path = "test/open_end_qa_set.jsonl"
            print(f"Warning: Could not load from HuggingFace ({hf_error})")
            # Name the actual fallback file (was misreported as test/test.jsonl).
            print(f"Trying local file {local_path}...")
            try:
                with open(local_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            _ingest_episode(groundtruth, json.loads(line))
                print(f"Loaded from local file: {local_path}")
            except FileNotFoundError:
                print(f"Warning: Local ground truth file not found: {local_path}")
            except Exception as e:
                print(f"Warning: Error loading local ground truth: {e}")
    except ImportError:
        print("Warning: datasets library not available, cannot load ground truth")
    return groundtruth
def validate_submission_file(file_path: str) -> tuple:
    """
    Validate submission file format.

    The file must be JSONL: one JSON object per non-blank line, e.g.
    {"episode_id": "...", "question": "...", "answer": "...", ...}

    Each object needs the fields 'episode_id', 'question' and 'answer',
    and every (episode_id, question) pair may appear at most once.

    Args:
        file_path: Path to submission JSONL file

    Returns:
        Tuple of (is_valid, error_message, submissions_list). On failure
        is_valid is False, error_message names the offending line (1-based)
        and submissions_list is empty.
    """
    import json

    # Hoisted loop invariant: no need to rebuild this per line.
    required_fields = ("episode_id", "question", "answer")
    submissions = []
    seen_pairs = set()
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for ix, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue  # blank lines are tolerated
                try:
                    task = json.loads(line)
                except json.JSONDecodeError:
                    return False, f"Line {ix+1} is incorrectly formatted JSON.", []
                for field in required_fields:
                    if field not in task:
                        return False, f"Line {ix+1} is missing required field '{field}'.", []
                pair_key = (task["episode_id"], task["question"])
                if pair_key in seen_pairs:
                    return False, f"Line {ix+1} contains duplicate episode_id/question pair.", []
                seen_pairs.add(pair_key)
                submissions.append(task)
        if not submissions:
            return False, "No valid submissions found in the file.", []
        return True, "", submissions
    except FileNotFoundError:
        return False, "File not found.", []
    except Exception as e:
        return False, f"Error reading file: {str(e)}", []