# mle-bench-tabular/tests/test_leaderboard.py
# Last change: "Update table view (#2)" by Sunmarinup, commit 1498561 (verified); 11.3 kB.
import re
from unittest.mock import patch

import pandas as pd
import pytest
import requests

from app import load_leaderboard, refresh_leaderboard
from src.leaderboard.columns import DisplayColumns, RequiredInputColumns
@pytest.fixture
def sample_csv_data():
    """Leaderboard CSV with the required input columns and three experiments.

    The rows are not listed in score order (exp_003 has the highest score),
    so tests can observe the loader's sorting. There is no trailing newline.
    """
    header = ",".join(
        (
            "experiment_id",
            "mean_normalized_score",
            "std_normalized_score",
            "mean_medal_pct",
            "sem_medal_pct",
            "Agent",
            "LLM(s) used",
            "Date",
        )
    )
    rows = (
        "exp_001,0.854321,0.012345,0.876543,0.009876,Agent A,GPT-4,2024-01-15",
        "exp_002,0.789012,0.023456,0.765432,0.012345,Agent B,Claude-3,2024-01-20",
        "exp_003,0.912345,0.008765,0.923456,0.007654,Agent C,GPT-4,2024-02-01",
    )
    return "\n".join((header, *rows))
@pytest.fixture
def sample_csv_with_extra_columns():
    """Two-row leaderboard CSV that carries an unexpected trailing ``extra_col``."""
    header = ",".join(
        [
            "experiment_id",
            "mean_normalized_score",
            "std_normalized_score",
            "mean_medal_pct",
            "sem_medal_pct",
            "Agent",
            "LLM(s) used",
            "Date",
            "extra_col",  # expected to be filtered out by the loader
        ]
    )
    body = [
        "exp_001,0.854321,0.012345,0.876543,0.009876,Agent A,GPT-4,2024-01-15,extra_value",
        "exp_002,0.789012,0.023456,0.765432,0.012345,Agent B,Claude-3,2024-01-20,extra_value",
    ]
    return "\n".join([header] + body)
@pytest.fixture
def sample_csv_missing_columns():
    """CSV with only id, normalized score and agent, so most required columns are absent."""
    return "\n".join(
        [
            "experiment_id,mean_normalized_score,Agent",
            "exp_001,0.854321,Agent A",
            "exp_002,0.789012,Agent B",
        ]
    )
class TestDownloadLeaderboard:
    """Tests for ``load_leaderboard``: download, validation, formatting and sorting.

    Every test patches ``src.leaderboard.input.download_github_file_content`` so
    no network access happens. The mock either returns raw CSV text in place of
    the GitHub download or raises the error under test.
    """

    @patch("src.leaderboard.input.download_github_file_content")
    def test_successful_download(self, mock_download, sample_csv_data):
        """A well-formed CSV is parsed into a DataFrame with every display column."""
        mock_download.return_value = sample_csv_data

        df = load_leaderboard()

        assert isinstance(df, pd.DataFrame)
        assert len(df) == 3
        assert all(col in df.columns for col in DisplayColumns.values())
        mock_download.assert_called_once()

    @patch("src.leaderboard.input.download_github_file_content")
    def test_data_cleaning_rounding(self, mock_download, sample_csv_data):
        """Normalized scores are rendered as "mean ± std" strings rounded to 3 decimals."""
        mock_download.return_value = sample_csv_data

        df = load_leaderboard()

        # Rows come back sorted by score, highest first:
        # exp_003 (0.912), exp_001 (0.854), exp_002 (0.789).
        assert df.iloc[0][DisplayColumns.NORMALIZED_SCORE] == "0.912 ± 0.009"
        assert df.iloc[1][DisplayColumns.NORMALIZED_SCORE] == "0.854 ± 0.012"
        assert df.iloc[2][DisplayColumns.NORMALIZED_SCORE] == "0.789 ± 0.023"
        assert isinstance(df.iloc[0][DisplayColumns.NORMALIZED_SCORE], str)

    @patch("src.leaderboard.input.download_github_file_content")
    def test_percentage_conversion(self, mock_download, sample_csv_data):
        """Medal fractions are scaled to percentages and rendered as "mean ± sem" (1 decimal)."""
        mock_download.return_value = sample_csv_data

        df = load_leaderboard()

        # e.g. 0.876543 * 100 = 87.6543 -> "87.7"; sem 0.009876 * 100 -> "1.0".
        # Rows are ordered by normalized score: exp_003, exp_001, exp_002.
        assert df.iloc[0][DisplayColumns.ANY_MEDAL_SCORE] == "92.3 ± 0.8"  # exp_003
        assert df.iloc[1][DisplayColumns.ANY_MEDAL_SCORE] == "87.7 ± 1.0"  # exp_001
        assert df.iloc[2][DisplayColumns.ANY_MEDAL_SCORE] == "76.5 ± 1.2"  # exp_002

    @patch("src.leaderboard.input.download_github_file_content")
    def test_date_formatting(self, mock_download, sample_csv_data):
        """Dates are rendered as YYYY-MM-DD strings."""
        mock_download.return_value = sample_csv_data

        df = load_leaderboard()

        # Rows are ordered by score, not date:
        # exp_003 (2024-02-01), exp_001 (2024-01-15), exp_002 (2024-01-20).
        assert df.iloc[0][DisplayColumns.DATE] == "2024-02-01"
        assert df.iloc[1][DisplayColumns.DATE] == "2024-01-15"
        assert df.iloc[2][DisplayColumns.DATE] == "2024-01-20"

    @patch("src.leaderboard.input.download_github_file_content")
    def test_sorting(self, mock_download, sample_csv_data):
        """Rows are sorted by mean normalized score, highest first."""
        mock_download.return_value = sample_csv_data

        df = load_leaderboard()

        # The display column is a formatted "mean ± std" string, so pull the mean back out.
        scores = [float(score.split(" ± ")[0]) for score in df[DisplayColumns.NORMALIZED_SCORE]]
        assert scores == sorted(scores, reverse=True)
        assert df.iloc[0][DisplayColumns.EXPERIMENT_NAME] == "exp_003"  # highest score
        assert df.iloc[2][DisplayColumns.EXPERIMENT_NAME] == "exp_002"  # lowest score

    @patch("src.leaderboard.input.download_github_file_content")
    def test_extra_columns_filtered(self, mock_download, sample_csv_with_extra_columns):
        """Columns outside the expected schema (here ``extra_col``) are dropped."""
        mock_download.return_value = sample_csv_with_extra_columns

        df = load_leaderboard()

        assert len(df) == 2
        # Besides the display columns, only the raw mean score / medal columns remain.
        assert set(df.columns) == set(
            DisplayColumns.values()
            + [RequiredInputColumns.MEAN_NORMALIZED_SCORE, RequiredInputColumns.MEAN_MEDAL_PCT]
        )
        assert "extra_col" not in df.columns

    @patch("src.leaderboard.input.download_github_file_content")
    def test_missing_columns_error(self, mock_download, sample_csv_missing_columns):
        """A CSV lacking required columns raises ValueError."""
        mock_download.return_value = sample_csv_missing_columns

        with pytest.raises(ValueError, match="Leaderboard is missing expected columns"):
            load_leaderboard()

    @patch("src.leaderboard.input.download_github_file_content")
    def test_http_error(self, mock_download):
        """HTTP errors raised by the download propagate to the caller unchanged."""
        mock_download.side_effect = requests.HTTPError("404 Not Found")

        with pytest.raises(requests.HTTPError):
            load_leaderboard()

    @patch("src.leaderboard.input.download_github_file_content")
    def test_network_error(self, mock_download):
        """Connection errors raised by the download propagate to the caller unchanged."""
        mock_download.side_effect = requests.ConnectionError("Connection failed")

        with pytest.raises(requests.ConnectionError):
            load_leaderboard()

    @patch("src.leaderboard.input.download_github_file_content")
    def test_timeout_handling(self, mock_download):
        """The download is requested with a ``timeout=30`` keyword argument."""
        mock_download.return_value = (
            "experiment_id,mean_normalized_score,std_normalized_score,"
            "mean_medal_pct,sem_medal_pct,Agent,LLM(s) used,Date\n"
            "exp_001,0.85,0.01,0.87,0.01,Agent A,GPT-4,2024-01-15"
        )

        load_leaderboard()

        mock_download.assert_called_once()
        # .get() turns a missing kwarg into an assertion failure instead of a KeyError.
        assert mock_download.call_args.kwargs.get("timeout") == 30

    @patch("src.leaderboard.input.download_github_file_content")
    def test_empty_dataframe(self, mock_download):
        """A header-only CSV yields an empty DataFrame with the display columns."""
        # Header consists of exactly the required input columns, with no data rows.
        mock_download.return_value = (
            "experiment_id,mean_normalized_score,std_normalized_score,"
            "mean_medal_pct,sem_medal_pct,Agent,LLM(s) used,Date"
        )

        df = load_leaderboard()

        assert isinstance(df, pd.DataFrame)
        assert len(df) == 0
        # NOTE(review): non-empty results also keep the raw mean score / medal
        # columns (see test_extra_columns_filtered); confirm the empty case is
        # meant to expose only the display columns.
        assert list(df.columns) == DisplayColumns.values()

    @patch("src.leaderboard.input.download_github_file_content")
    def test_invalid_date_handling(self, mock_download):
        """An unparseable date does not fail the load; it comes back as a missing value."""
        mock_download.return_value = (
            "experiment_id,mean_normalized_score,std_normalized_score,"
            "mean_medal_pct,sem_medal_pct,Agent,LLM(s) used,Date\n"
            "exp_001,0.854321,0.012345,0.876543,0.009876,Agent A,GPT-4,invalid-date\n"
            "exp_002,0.789012,0.023456,0.765432,0.012345,Agent B,Claude-3,2024-01-20"
        )

        df = load_leaderboard()

        # Look rows up by experiment name rather than position, since order may vary.
        row_001 = df[df[DisplayColumns.EXPERIMENT_NAME] == "exp_001"].iloc[0]
        row_002 = df[df[DisplayColumns.EXPERIMENT_NAME] == "exp_002"].iloc[0]
        # The invalid date must be a true missing value (pd.isna), not e.g. a "nan" string;
        # the valid date keeps its YYYY-MM-DD rendering.
        assert pd.isna(row_001[DisplayColumns.DATE])
        assert row_002[DisplayColumns.DATE] == "2024-01-20"
class TestRefreshLeaderboard:
    """Tests for ``refresh_leaderboard`` with ``app.load_leaderboard`` patched out."""

    @patch("app.load_leaderboard")
    def test_refresh_leaderboard_success(self, mock_load):
        """refresh_leaderboard returns the loaded data plus a timestamped status message."""
        mock_load.return_value = pd.DataFrame(
            {
                DisplayColumns.EXPERIMENT_NAME: ["exp_001"],
                DisplayColumns.AGENT: ["Agent A"],
                DisplayColumns.LLM_USED: ["GPT-4"],
                RequiredInputColumns.MEAN_NORMALIZED_SCORE: [0.850],
                DisplayColumns.NORMALIZED_SCORE: ["0.850 ± 0.010"],
                RequiredInputColumns.MEAN_MEDAL_PCT: [0.850],
                DisplayColumns.ANY_MEDAL_SCORE: ["85.0 ± 1.0"],
                DisplayColumns.DATE: ["2024-01-15"],
            }
        )

        df, status = refresh_leaderboard()

        assert df is not None
        assert "Showing data from" in status
        assert "GitHub" in status
        assert "Last refreshed:" in status
        # The refresh time must be rendered as "YYYY-MM-DD HH:MM UTC".
        assert "UTC" in status
        assert re.search(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2} UTC", status) is not None
        mock_load.assert_called_once()

    @patch("app.load_leaderboard")
    def test_refresh_leaderboard_includes_url(self, mock_load):
        """The status message points at the upgini/mle-bench GitHub repository."""
        mock_load.return_value = pd.DataFrame()

        _, status = refresh_leaderboard()

        assert "github.com" in status.lower() or "GitHub" in status
        assert "upgini/mle-bench" in status

    @patch("app.load_leaderboard")
    def test_refresh_leaderboard_propagates_error(self, mock_load):
        """Errors raised by load_leaderboard propagate out of refresh_leaderboard."""
        mock_load.side_effect = requests.HTTPError("404 Not Found")

        with pytest.raises(requests.HTTPError):
            refresh_leaderboard()