refactor: use enum class for the task type
- app.py +9 -8
- src/benchmarks.py +8 -7
- src/loaders.py +22 -29
- src/models.py +13 -2
- src/utils.py +45 -54
- tests/test_utils.py +2 -2
app.py
CHANGED

@@ -35,6 +35,7 @@ from src.envs import (
     TOKEN,
 )
 from src.loaders import load_eval_results
+from src.models import TaskType
 from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file


@@ -75,7 +76,7 @@ def update_qa_metric(
    global datastore
    return update_metric(
        datastore,
-        "qa",
+        TaskType.qa,
        metric,
        domains,
        langs,
@@ -98,7 +99,7 @@ def update_doc_metric(
    global datastore
    return update_metric(
        datastore,
-        "long-doc",
+        TaskType.long_doc,
        metric,
        domains,
        langs,
@@ -181,7 +182,7 @@ with demo:
    )

    set_listeners(
-        "qa",
+        TaskType.qa,
        qa_df_elem_ret_rerank,
        qa_df_elem_ret_rerank_hidden,
        search_bar,
@@ -224,7 +225,7 @@ with demo:
    )

    set_listeners(
-        "qa",
+        TaskType.qa,
        qa_df_elem_ret,
        qa_df_elem_ret_hidden,
        search_bar_ret,
@@ -281,7 +282,7 @@ with demo:
    )

    set_listeners(
-        "qa",
+        TaskType.qa,
        qa_df_elem_rerank,
        qa_df_elem_rerank_hidden,
        qa_search_bar_rerank,
@@ -348,7 +349,7 @@ with demo:
    )

    set_listeners(
-        "long-doc",
+        TaskType.long_doc,
        doc_df_elem_ret_rerank,
        doc_df_elem_ret_rerank_hidden,
        search_bar,
@@ -405,7 +406,7 @@ with demo:
    )

    set_listeners(
-        "long-doc",
+        TaskType.long_doc,
        doc_df_elem_ret,
        doc_df_elem_ret_hidden,
        search_bar_ret,
@@ -462,7 +463,7 @@ with demo:
    )

    set_listeners(
-        "long-doc",
+        TaskType.long_doc,
        doc_df_elem_rerank,
        doc_df_elem_rerank_hidden,
        doc_search_bar_rerank,
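The net effect in app.py is that every handler passes a TaskType member where it previously passed a raw "qa" or "long-doc" string. A minimal sketch of why this is safer (update_metric_stub is a hypothetical stand-in for src.utils.update_metric, and the metric name is illustrative):

    from enum import Enum


    class TaskType(Enum):
        qa = "qa"
        long_doc = "long-doc"


    def update_metric_stub(task: TaskType, metric: str) -> str:
        # Dispatch on the enum member, as update_metric now does.
        if task == TaskType.qa:
            return f"QA table sorted by {metric}"
        elif task == TaskType.long_doc:
            return f"Long-Doc table sorted by {metric}"
        raise NotImplementedError


    print(update_metric_stub(TaskType.qa, "ndcg_at_10"))
    # A stale call site such as update_metric_stub("qa", ...) now falls
    # through to NotImplementedError instead of silently half-working.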
src/benchmarks.py
CHANGED

@@ -4,6 +4,7 @@ from enum import Enum
 from air_benchmark.tasks.tasks import BenchmarkTable

 from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
+from src.models import TaskType


 def get_safe_name(name: str):
@@ -23,11 +24,11 @@ class Benchmark:


 # create a function return an enum class containing all the benchmarks
-def get_benchmarks_enum(benchmark_version, task_type):
+def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
     benchmark_dict = {}
-    if task_type == "qa":
+    if task_type == TaskType.qa:
         for task, domain_dict in BenchmarkTable[benchmark_version].items():
-            if task != task_type:
+            if task != task_type.value:
                 continue
             for domain, lang_dict in domain_dict.items():
                 for lang, dataset_list in lang_dict.items():
@@ -39,9 +40,9 @@ def get_benchmarks_enum(benchmark_version, task_type):
                     benchmark_dict[benchmark_name] = Benchmark(
                         benchmark_name, metric, col_name, domain, lang, task
                     )
-    elif task_type == "long-doc":
+    elif task_type == TaskType.long_doc:
         for task, domain_dict in BenchmarkTable[benchmark_version].items():
-            if task != task_type:
+            if task != task_type.value:
                 continue
             for domain, lang_dict in domain_dict.items():
                 for lang, dataset_list in lang_dict.items():
@@ -62,14 +63,14 @@ qa_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     qa_benchmark_dict[safe_version_name] = Enum(
-        f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa")
+        f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.qa)
     )

 long_doc_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     long_doc_benchmark_dict[safe_version_name] = Enum(
-        f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc")
+        f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.long_doc)
     )
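Both module-level loops use the functional Enum() API, which builds an enum class at runtime from a name-to-value mapping. A toy sketch of the pattern (the benchmark names, payload tuples, and the "2404" slug are made up; the real members carry Benchmark instances):

    from enum import Enum

    # Member name -> payload, the shape get_benchmarks_enum returns.
    benchmark_dict = {
        "qa_wiki_en": ("ndcg_at_10", "wiki", "en"),
        "qa_web_zh": ("ndcg_at_10", "web", "zh"),
    }

    QABenchmarks_2404 = Enum("QABenchmarks_2404", benchmark_dict)

    for member in QABenchmarks_2404:
        print(member.name, member.value)
    # qa_wiki_en ('ndcg_at_10', 'wiki', 'en')
    # qa_web_zh ('ndcg_at_10', 'web', 'zh')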
src/loaders.py
CHANGED

@@ -11,7 +11,7 @@ from src.envs import (
     DEFAULT_METRIC_LONG_DOC,
     DEFAULT_METRIC_QA,
 )
-from src.models import FullEvalResult, LeaderboardDataStore
+from src.models import FullEvalResult, LeaderboardDataStore, TaskType
 from src.utils import get_default_cols, get_leaderboard_df

 pd.options.mode.copy_on_write = True
@@ -64,34 +64,27 @@ def get_safe_name(name: str):

 def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
     slug = get_safe_name(version)[-4:]
-    …
-    print(f"raw data: {len(…
-    …
-    print(f"QA data loaded: {…
-    …
-    ][shown_columns_long_doc]
-    lb_data_store.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
-
-    lb_data_store.reranking_models = sorted(
-        list(frozenset([eval_result.reranking_model for eval_result in lb_data_store.raw_data]))
-    )
-    return lb_data_store
+    datastore = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
+    datastore.raw_data = load_raw_eval_results(file_path)
+    print(f"raw data: {len(datastore.raw_data)}")
+
+    datastore.qa_raw_df = get_leaderboard_df(datastore, TaskType.qa, DEFAULT_METRIC_QA)
+    print(f"QA data loaded: {datastore.qa_raw_df.shape}")
+    datastore.qa_fmt_df = datastore.qa_raw_df.copy()
+    qa_cols, datastore.qa_types = get_default_cols(TaskType.qa, datastore.slug, add_fix_cols=True)
+    datastore.qa_fmt_df = datastore.qa_fmt_df[~datastore.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
+    datastore.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
+
+    datastore.doc_raw_df = get_leaderboard_df(datastore, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
+    print(f"Long-Doc data loaded: {len(datastore.doc_raw_df)}")
+    datastore.doc_fmt_df = datastore.doc_raw_df.copy()
+    doc_cols, datastore.doc_types = get_default_cols(TaskType.long_doc, datastore.slug, add_fix_cols=True)
+    datastore.doc_fmt_df = datastore.doc_fmt_df[~datastore.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
+    datastore.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
+
+    datastore.reranking_models = \
+        sorted(list(frozenset([eval_result.reranking_model for eval_result in datastore.raw_data])))
+    return datastore


 def load_eval_results(file_path: str) -> Dict[str, LeaderboardDataStore]:
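The qa_fmt_df and doc_fmt_df lines chain a boolean mask with a column selection: drop anonymous submissions, then keep only the default columns. A self-contained pandas sketch of that pattern (the column and model names are toy stand-ins for the COL_NAME_* constants in src.envs):

    import pandas as pd

    COL_NAME_IS_ANONYMOUS = "is_anonymous"  # assumed stand-in

    df = pd.DataFrame(
        {
            "retrieval_model": ["model-a", "model-b", "model-c"],
            "ndcg_at_10": [0.71, 0.69, 0.74],
            COL_NAME_IS_ANONYMOUS: [False, False, True],
        }
    )

    qa_cols = ["retrieval_model", "ndcg_at_10"]
    # ~mask filters out anonymous rows; [qa_cols] keeps the shown columns.
    fmt_df = df[~df[COL_NAME_IS_ANONYMOUS]][qa_cols]
    print(fmt_df)
    #   retrieval_model  ndcg_at_10
    # 0         model-a        0.71
    # 1         model-b        0.69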
src/models.py
CHANGED

@@ -1,11 +1,12 @@
 import json
+from enum import Enum
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional

 import pandas as pd

-from src.benchmarks import get_safe_name
 from src.display.formatting import make_clickable_model
 from src.envs import (
     COL_NAME_IS_ANONYMOUS,
@@ -17,6 +18,10 @@ from src.envs import (
     COL_NAME_TIMESTAMP,
 )

+def get_safe_name(name: str):
+    """Get RFC 1123 compatible safe name"""
+    name = name.replace("-", "_")
+    return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))

 @dataclass
 class EvalResult:
@@ -147,4 +152,10 @@ class LeaderboardDataStore:
     doc_fmt_df: Optional[pd.DataFrame]
     reranking_models: Optional[list]
     qa_types: Optional[list]
-    doc_types: Optional[list]
+    doc_types: Optional[list]
+
+
+# Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
+class TaskType(Enum):
+    qa = "qa"
+    long_doc = "long-doc"
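Because each member's value is the original task string, TaskType round-trips with the string keys used elsewhere (e.g. in BenchmarkTable), which is why src/benchmarks.py compares task != task_type.value rather than task != task_type. A quick behavioral sketch:

    from src.models import TaskType

    assert TaskType.qa.value == "qa"                  # .value recovers the raw string
    assert TaskType("long-doc") is TaskType.long_doc  # lookup by value
    assert TaskType["long_doc"] is TaskType.long_doc  # lookup by member name
    assert TaskType.qa != "qa"                        # members never equal plain strings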
src/utils.py
CHANGED

@@ -6,6 +6,7 @@ from pathlib import Path

 import pandas as pd

+from src.models import TaskType
 from src.benchmarks import LongDocBenchmarks, QABenchmarks
 from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
 from src.display.formatting import styled_error, styled_message
@@ -69,12 +70,12 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
     return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]


-def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
+def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
     cols = []
     types = []
-    if task == "qa":
+    if task == TaskType.qa:
         benchmarks = QABenchmarks[version_slug]
-    elif task == "long-doc":
+    elif task == TaskType.long_doc:
         benchmarks = LongDocBenchmarks[version_slug]
     else:
         raise NotImplementedError
@@ -85,7 +86,6 @@ def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
             continue
         cols.append(col_name)
         types.append(col_type)
-
     if add_fix_cols:
         _cols = []
         _types = []
@@ -104,16 +104,16 @@ def select_columns(
     df: pd.DataFrame,
     domain_query: list,
     language_query: list,
-    task: str = "qa",
+    task: TaskType = TaskType.qa,
     reset_ranking: bool = True,
     version_slug: str = None,
 ) -> pd.DataFrame:
     cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:
-        if task == "qa":
+        if task == TaskType.qa:
             eval_col = QABenchmarks[version_slug].value[c].value
-        elif task == "long-doc":
+        elif task == TaskType.long_doc:
             eval_col = LongDocBenchmarks[version_slug].value[c].value
         else:
             raise NotImplementedError
@@ -141,10 +141,10 @@ def get_safe_name(name: str):
     return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))


-def _update_table(
-    task: str,
+def _update_df_elem(
+    task: TaskType,
     version: str,
-    hidden_df: pd.DataFrame,
+    source_df: pd.DataFrame,
     domains: list,
     langs: list,
     reranking_query: list,
@@ -154,7 +154,7 @@ def _update_table(
     show_revision_and_timestamp: bool = False,
 ):
     version_slug = get_safe_name(version)[-4:]
-    filtered_df = hidden_df.copy()
+    filtered_df = source_df.copy()
     if not show_anonymous:
         filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
     filtered_df = filter_models(filtered_df, reranking_query)
@@ -165,7 +165,7 @@ def _update_table(
     return filtered_df


-def update_table_long_doc(
+def update_doc_df_elem(
     version: str,
     hidden_df: pd.DataFrame,
     domains: list,
@@ -176,8 +176,8 @@ def update_table_long_doc(
     show_revision_and_timestamp: bool = False,
     reset_ranking: bool = True,
 ):
-    return _update_table(
-        "long-doc",
+    return _update_df_elem(
+        TaskType.long_doc,
         version,
         hidden_df,
         domains,
@@ -192,7 +192,7 @@ def update_table_long_doc(

 def update_metric(
     datastore,
-    task: str,
+    task: TaskType,
     metric: str,
     domains: list,
     langs: list,
@@ -201,33 +201,24 @@ def update_metric(
     show_anonymous: bool = False,
     show_revision_and_timestamp: bool = False,
 ) -> pd.DataFrame:
-    …
-        version,
-        leaderboard_df,
-        domains,
-        langs,
-        reranking_model,
-        query,
-        show_anonymous,
-        show_revision_and_timestamp,
-    )
+    if task == TaskType.qa:
+        update_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_func = update_doc_df_elem
+    else:
+        raise NotImplemented
+    df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
+    version = datastore.version
+    return update_func(
+        version,
+        df_elem,
+        domains,
+        langs,
+        reranking_model,
+        query,
+        show_anonymous,
+        show_revision_and_timestamp,
+    )


 def upload_file(filepath: str):
@@ -341,7 +332,7 @@ def reset_rank(df):
     return df


-def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
+def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
     """
     Creates a dataframe from all the individual experiment results
     """
@@ -349,9 +340,9 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
     cols = [
         COL_NAME_IS_ANONYMOUS,
     ]
-    if task == "qa":
+    if task == TaskType.qa:
         benchmarks = QABenchmarks[datastore.slug]
-    elif task == "long-doc":
+    elif task == TaskType.long_doc:
         benchmarks = LongDocBenchmarks[datastore.slug]
     else:
         raise NotImplementedError
@@ -360,7 +351,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
     benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]
     all_data_json = []
     for v in raw_data:
-        all_data_json += v.to_dict(task=task, metric=metric)
+        all_data_json += v.to_dict(task=task.value, metric=metric)
     df = pd.DataFrame.from_records(all_data_json)

     _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
@@ -385,7 +376,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:


 def set_listeners(
-    task,
+    task: TaskType,
     target_df,
     source_df,
     search_bar,
@@ -396,10 +387,10 @@ def set_listeners(
     show_anonymous,
     show_revision_and_timestamp,
 ):
-    if task == "qa":
-        update_table_func = update_table
-    elif task == "long-doc":
-        update_table_func = update_table_long_doc
+    if task == TaskType.qa:
+        update_table_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_table_func = update_doc_df_elem
     else:
         raise NotImplementedError
     selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
@@ -427,7 +418,7 @@ def set_listeners(
     )


-def update_table(
+def update_qa_df_elem(
     version: str,
     hidden_df: pd.DataFrame,
     domains: list,
@@ -438,8 +429,8 @@ def update_table(
     show_revision_and_timestamp: bool = False,
     reset_ranking: bool = True,
 ):
-    return _update_table(
-        "qa",
+    return _update_df_elem(
+        TaskType.qa,
         version,
         hidden_df,
         domains,
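One detail worth flagging in the new update_metric: its fallback branch raises NotImplemented, the binary-operator sentinel constant, where the sibling functions raise NotImplementedError. In Python 3 the branch still fails, but with a misleading TypeError rather than the intended exception, as this small demonstration shows:

    def dispatch(task):
        if task == "qa":
            return "qa table"
        raise NotImplemented  # bug: NotImplemented is a constant, not an exception


    try:
        dispatch("unknown")
    except TypeError as err:
        print(err)  # "exceptions must derive from BaseException"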
tests/test_utils.py
CHANGED

@@ -18,7 +18,7 @@ from src.utils import (
     get_iso_format_timestamp,
     search_table,
     select_columns,
-    update_table_long_doc,
+    update_doc_df_elem,
 )


@@ -90,7 +90,7 @@ def test_select_columns(toy_df):


 def test_update_table_long_doc(toy_df_long_doc):
-    df_result = update_table_long_doc(
+    df_result = update_doc_df_elem(
         toy_df_long_doc,
         [
             "law",