Spaces: Running on CPU Upgrade
| from dataclasses import dataclass | |
| from enum import Enum | |
| from air_benchmark.tasks.tasks import BenchmarkTable | |
| from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST | |
def get_safe_name(name: str):
    """Normalize *name* into an RFC 1123 compatible identifier.

    Hyphens become underscores; every other character is kept only if it is
    alphanumeric (or an underscore) and is lowercased.
    """
    underscored = name.replace("-", "_")
    kept = [ch.lower() for ch in underscored if ch.isalnum() or ch == "_"]
    return "".join(kept)
@dataclass
class Benchmark:
    """One leaderboard benchmark entry (a task/domain/language/metric tuple).

    NOTE(review): the original class had only annotations and no @dataclass
    decorator, yet `dataclass` is imported at the top of the file and the class
    is instantiated with six positional arguments in get_benchmarks_enum —
    without the decorator those calls raise TypeError. Restored it here.
    """

    name: str  # [domain]_[language]_[metric], task_key in the json file
    metric: str  # e.g. ndcg_at_1, metric_key in the json file
    col_name: str  # [domain]_[language], name to display in the leaderboard
    domain: str
    lang: str
    task: str
# Build the mapping used to create an Enum class containing all the benchmarks.
def get_benchmarks_enum(benchmark_version, task_type):
    """Collect Benchmark entries for one benchmark version and task type.

    Walks BenchmarkTable[benchmark_version] (task -> domain -> lang -> dataset
    table) and returns a dict mapping a safe benchmark name to a Benchmark.
    Only datasets that expose a "test" split are included.
    """
    collected = {}
    if task_type == "qa":
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    # QA entries are keyed per domain/language pair.
                    safe_name = get_safe_name(f"{domain}_{lang}")
                    for metric, dataset_info in dataset_list.items():
                        if "test" in dataset_info["splits"]:
                            collected[safe_name] = Benchmark(
                                safe_name, metric, safe_name, domain, lang, task
                            )
    elif task_type == "long-doc":
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    for dataset, dataset_info in dataset_list.items():
                        if "test" not in dataset_info["splits"]:
                            continue
                        safe_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                        # NOTE(review): each metric overwrites the same key, so
                        # only the last entry of METRIC_LIST survives — this
                        # mirrors the original behavior; confirm it is intended.
                        for metric in METRIC_LIST:
                            collected[safe_name] = Benchmark(
                                safe_name, metric, safe_name, domain, lang, task
                            )
    return collected
# Per-version enums of benchmarks, keyed by the last four characters of the
# safe version name (e.g. "air_bench_24_04" -> "4_04").
qa_benchmark_dict = {}
long_doc_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    suffix = get_safe_name(version)[-4:]
    qa_benchmark_dict[suffix] = Enum(
        f"QABenchmarks_{suffix}", get_benchmarks_enum(version, "qa")
    )
    long_doc_benchmark_dict[suffix] = Enum(
        f"LongDocBenchmarks_{suffix}", get_benchmarks_enum(version, "long-doc")
    )

# Top-level enums whose members are the per-version enum classes above.
QABenchmarks = Enum("QABenchmarks", qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", long_doc_benchmark_dict)