""" Utility functions for AMA-Bench Leaderboard. This module contains helper functions for: - DataFrame building and manipulation - Chart generation - Data validation """ import pandas as pd import plotly.graph_objects as go from typing import List, Dict # Metrics configuration METRICS = ["Recall", "Causal Inference", "State Updating", "State Abstraction"] ALL_METRICS = METRICS + ["Average"] # Chart colors moved to visualization.py def build_dataframe(data: Dict) -> pd.DataFrame: """ Build a pandas DataFrame showing Accuracy for each metric. Args: data: Dictionary with 'entries' key containing list of results Returns: DataFrame with Method and metric columns """ rows = [] for entry in data["entries"]: row = {"Method": entry["method"]} if entry.get("category"): row["Category"] = entry["category"] for m in ALL_METRICS: accuracy = entry["scores"][m]["accuracy"] row[m] = f"{accuracy:.4f}" # Store raw average accuracy for sorting row["_sort_avg"] = entry["scores"]["Average"]["accuracy"] rows.append(row) df = pd.DataFrame(rows) df = df.sort_values("_sort_avg", ascending=False).reset_index(drop=True) df = df.drop(columns=["_sort_avg"]) return df def add_medals(df: pd.DataFrame) -> pd.DataFrame: """ Add medal emojis to the top-3 Method names. Args: df: DataFrame with 'Method' column Returns: DataFrame with medals added to top 3 methods """ df = df.copy() medals = ["\U0001f947", "\U0001f948", "\U0001f949"] # 🥇 🥈 🥉 for i in range(min(3, len(df))): df.loc[i, "Method"] = f"{medals[i]} {df.loc[i, 'Method']}" return df def load_groundtruth(dataset_name: str, token: str = None) -> Dict[str, str]: """ Load ground truth Q&A pairs from HuggingFace dataset. Expected schema in the dataset: { "episode_id": "string", "qa_pairs": [ { "question": "string", "answer": "string", "type": "string", "sub_type": "string" } ] } Args: dataset_name: HuggingFace dataset name (e.g., "Pettingllms/AMA-bench") token: Optional HuggingFace token for private datasets Returns: Dictionary mapping (episode_id, question) to answer info """ groundtruth = {} try: from datasets import load_dataset, VerificationMode # Try loading from HuggingFace dataset try: dataset = load_dataset( dataset_name, split="test", token=token, verification_mode=VerificationMode.NO_CHECKS, trust_remote_code=True ) print(f"Loaded dataset from HuggingFace: {dataset_name}") for row in dataset: episode_id = row.get("episode_id", "") domain = row.get("domain", "") qa_pairs = row.get("qa_pairs", []) for qa in qa_pairs: question = qa.get("question", "") answer = qa.get("answer", "") qa_type = qa.get("type", "") # Create unique key for this Q&A pair key = f"{episode_id}_{question}" groundtruth[key] = { "answer": answer, "type": qa_type, "sub_type": qa.get("sub_type", ""), "domain": domain, } except Exception as hf_error: print(f"Warning: Could not load from HuggingFace ({hf_error})") print("Trying local file test/test.jsonl...") # Fallback to local file import json local_path = "test/open_end_qa_set.jsonl" try: with open(local_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue data = json.loads(line) episode_id = data.get("episode_id", "") domain = data.get("domain", "") qa_pairs = data.get("qa_pairs", []) for qa in qa_pairs: question = qa.get("question", "") answer = qa.get("answer", "") qa_type = qa.get("type", "") # Create unique key for this Q&A pair key = f"{episode_id}_{question}" groundtruth[key] = { "answer": answer, "type": qa_type, "sub_type": qa.get("sub_type", ""), "domain": domain, } print(f"Loaded from local file: {local_path}") except FileNotFoundError: print(f"Warning: Local ground truth file not found: {local_path}") except Exception as e: print(f"Warning: Error loading local ground truth: {e}") except ImportError: print("Warning: datasets library not available, cannot load ground truth") return groundtruth def validate_submission_file(file_path: str) -> tuple: """ Validate submission file format. Expected format: {"episode_id": "...", "question": "...", "answer": "...", ...} Args: file_path: Path to submission JSONL file Returns: Tuple of (is_valid, error_message, submissions_list) """ import json submissions = [] seen_pairs = set() try: with open(file_path, 'r', encoding='utf-8') as f: for ix, line in enumerate(f): line = line.strip() if not line: continue try: task = json.loads(line) except json.JSONDecodeError: return False, f"Line {ix+1} is incorrectly formatted JSON.", [] # Check required fields required_fields = ["episode_id", "question", "answer"] for field in required_fields: if field not in task: return False, f"Line {ix+1} is missing required field '{field}'.", [] episode_id = task["episode_id"] question = task["question"] pair_key = (episode_id, question) if pair_key in seen_pairs: return False, f"Line {ix+1} contains duplicate episode_id/question pair.", [] seen_pairs.add(pair_key) submissions.append(task) if len(submissions) == 0: return False, "No valid submissions found in the file.", [] return True, "", submissions except FileNotFoundError: return False, "File not found.", [] except Exception as e: return False, f"Error reading file: {str(e)}", []