Spaces:
Running
Running
| """ | |
| Utility functions for AMA-Bench Leaderboard. | |
| This module contains helper functions for: | |
| - DataFrame building and manipulation | |
| - Chart generation | |
| - Data validation | |
| """ | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from typing import List, Dict | |
# Metrics configuration
# Names of the individual metrics shown on the leaderboard.
METRICS = [
    "Recall",
    "Causal Inference",
    "State Updating",
    "State Abstraction",
]
# All leaderboard metric columns: the individual metrics plus "Average".
ALL_METRICS = [*METRICS, "Average"]
# Chart colors moved to visualization.py
def build_dataframe(data: Dict) -> pd.DataFrame:
    """
    Build a pandas DataFrame showing Accuracy for each metric.

    Rows are ordered by the raw "Average" accuracy, highest first. A stable
    sort is used so that entries with equal averages come out in a
    deterministic order. Metric values are rendered as strings with four
    decimal places.

    Args:
        data: Dictionary with 'entries' key containing list of results.
            Each entry must provide 'method' and
            'scores'[metric]['accuracy'] for every metric in ALL_METRICS,
            and may provide 'category'.

    Returns:
        DataFrame with a 'Method' column, a 'Category' column when at least
        one entry has a non-empty category, and one column per metric.
        When there are no entries, an empty DataFrame with the 'Method' and
        metric columns is returned.

    Raises:
        KeyError: If 'entries', 'method' or a required score is missing.
    """
    rows = []
    for entry in data["entries"]:
        row = {"Method": entry["method"]}
        if entry.get("category"):
            row["Category"] = entry["category"]
        scores = entry["scores"]
        for m in ALL_METRICS:
            row[m] = f"{scores[m]['accuracy']:.4f}"
        # Keep the raw (unformatted) average so sorting is numeric rather
        # than lexicographic on the formatted strings.
        row["_sort_avg"] = scores["Average"]["accuracy"]
        rows.append(row)

    if not rows:
        # An empty frame has no "_sort_avg" column, so sorting it would raise
        # KeyError; return a correctly shaped empty leaderboard instead.
        return pd.DataFrame(columns=["Method", *ALL_METRICS])

    df = pd.DataFrame(rows)
    # mergesort is stable, unlike the default quicksort.
    df = df.sort_values("_sort_avg", ascending=False, kind="mergesort")
    return df.drop(columns=["_sort_avg"]).reset_index(drop=True)
def add_medals(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add medal emojis to the top-3 Method names.

    The "top" rows are the first three rows by position, so the input is
    expected to be sorted already. The input DataFrame is not modified.

    Args:
        df: DataFrame with 'Method' column

    Returns:
        Copy of the DataFrame with medals prefixed to the first three methods
        (fewer when the DataFrame has fewer than three rows)
    """
    df = df.copy()
    medals = ["\U0001f947", "\U0001f948", "\U0001f949"]  # 🥇 🥈 🥉
    top_n = min(len(medals), len(df))
    if top_n == 0:
        return df

    # Address rows by position rather than by label: label-based access with
    # 0..2 only works when the frame carries a default RangeIndex.
    method_col = df.columns.get_loc("Method")
    for pos in range(top_n):
        df.iloc[pos, method_col] = f"{medals[pos]} {df.iloc[pos, method_col]}"
    return df
def _add_qa_pairs(record: Dict, groundtruth: Dict[str, Dict[str, str]]) -> None:
    """Add every Q&A pair of one episode record to ``groundtruth`` in place."""
    episode_id = record.get("episode_id", "")
    domain = record.get("domain", "")
    # "qa_pairs" may be absent or null; treat both as "no questions".
    for qa in record.get("qa_pairs") or []:
        # Create unique key for this Q&A pair
        key = f"{episode_id}_{qa.get('question', '')}"
        groundtruth[key] = {
            "answer": qa.get("answer", ""),
            "type": qa.get("type", ""),
            "sub_type": qa.get("sub_type", ""),
            "domain": domain,
        }


def _load_groundtruth_from_hub(dataset_name: str, token: str = None):
    """Load ground truth from the HuggingFace dataset; return None on failure."""
    try:
        from datasets import load_dataset, VerificationMode
    except ImportError:
        print("Warning: datasets library not available, cannot load from HuggingFace")
        return None

    try:
        # NOTE(security): trust_remote_code=True allows code shipped with the
        # dataset repository to run locally; only use trusted dataset names.
        dataset = load_dataset(
            dataset_name,
            split="test",
            token=token,
            verification_mode=VerificationMode.NO_CHECKS,
            trust_remote_code=True,
        )
        print(f"Loaded dataset from HuggingFace: {dataset_name}")
        groundtruth = {}
        for row in dataset:
            _add_qa_pairs(row, groundtruth)
        return groundtruth
    except Exception as hf_error:
        # Best-effort: any hub failure (network, auth, schema) triggers the
        # local fallback. Partially read hub data is discarded so it is never
        # mixed with the local file's content.
        print(f"Warning: Could not load from HuggingFace ({hf_error})")
        return None


def _load_groundtruth_from_file(local_path: str) -> Dict[str, Dict[str, str]]:
    """Load ground truth from a local JSONL file with one episode per line."""
    import json

    groundtruth = {}
    try:
        with open(local_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                _add_qa_pairs(json.loads(line), groundtruth)
        print(f"Loaded from local file: {local_path}")
    except FileNotFoundError:
        print(f"Warning: Local ground truth file not found: {local_path}")
    except Exception as e:
        # Keep whatever was parsed before the failing line.
        print(f"Warning: Error loading local ground truth: {e}")
    return groundtruth


def load_groundtruth(dataset_name: str, token: str = None) -> Dict[str, Dict[str, str]]:
    """
    Load ground truth Q&A pairs from HuggingFace dataset.

    The "test" split of the HuggingFace dataset is tried first. If the
    `datasets` library is not installed or loading fails for any reason, the
    local file "test/open_end_qa_set.jsonl" is read instead. Failures are
    reported with printed warnings and never raised; when nothing can be
    loaded an empty dictionary is returned.

    Expected schema in the dataset:
    {
        "episode_id": "string",
        "qa_pairs": [
            {
                "question": "string",
                "answer": "string",
                "type": "string",
                "sub_type": "string"
            }
        ]
    }
    An optional top-level "domain" field is also read.

    Args:
        dataset_name: HuggingFace dataset name (e.g., "Pettingllms/AMA-bench")
        token: Optional HuggingFace token for private datasets

    Returns:
        Dictionary mapping the string "{episode_id}_{question}" to a dict
        with the keys "answer", "type", "sub_type" and "domain"
    """
    groundtruth = _load_groundtruth_from_hub(dataset_name, token)
    if groundtruth is not None:
        return groundtruth

    # Fallback to local file
    local_path = "test/open_end_qa_set.jsonl"
    print(f"Trying local file {local_path}...")
    return _load_groundtruth_from_file(local_path)
def validate_submission_file(file_path: str) -> tuple:
    """
    Validate submission file format.

    The file is JSONL: every non-blank line must be a JSON object with at
    least the required fields, and each (episode_id, question) pair may
    appear only once. Validation stops at the first problem found.

    Expected format:
        {"episode_id": "...", "question": "...", "answer": "...", ...}

    Args:
        file_path: Path to submission JSONL file

    Returns:
        Tuple of (is_valid, error_message, submissions_list). On success the
        message is "" and the list holds the parsed objects in file order;
        on failure the list is empty and the message describes the problem
        (with a 1-based line number where applicable).
    """
    import json

    required_fields = ("episode_id", "question", "answer")
    submissions = []
    seen_pairs = set()
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Line numbers count every physical line, including blank ones.
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    task = json.loads(line)
                except json.JSONDecodeError:
                    return False, f"Line {line_no} is incorrectly formatted JSON.", []
                # Valid JSON is not necessarily an object; numbers, strings
                # and arrays would otherwise break the field checks below.
                if not isinstance(task, dict):
                    return False, f"Line {line_no} is not a JSON object.", []
                # Check required fields
                for field in required_fields:
                    if field not in task:
                        return False, f"Line {line_no} is missing required field '{field}'.", []
                pair_key = (task["episode_id"], task["question"])
                try:
                    is_duplicate = pair_key in seen_pairs
                except TypeError:
                    # Arrays/objects are unhashable and cannot identify a pair.
                    return False, f"Line {line_no} has an episode_id or question that is not a scalar value.", []
                if is_duplicate:
                    return False, f"Line {line_no} contains duplicate episode_id/question pair.", []
                seen_pairs.add(pair_key)
                submissions.append(task)
        if not submissions:
            return False, "No valid submissions found in the file.", []
        return True, "", submissions
    except FileNotFoundError:
        return False, "File not found.", []
    except Exception as e:
        return False, f"Error reading file: {str(e)}", []