yangzhitao
refactor: update benchmark display in submit tab
ba366ad
from functools import cached_property
from pydantic import BaseModel, ConfigDict, computed_field
from typing_extensions import Self
class MetaToml(BaseModel):
    # Container for the lists of models, benchmarks and model repos, plus
    # lookup tables derived from them. Each table is built on first access
    # and then cached on the instance via `cached_property`.
    model_config = ConfigDict(extra="allow", frozen=True)

    # --- Fields ---
    models: list["MetaToml_Model"]
    benchmarks: list["MetaToml_Benchmark"]
    model_repos: list["MetaToml_ModelRepo"]

    # --- Lookups by key / title ---
    @cached_property
    def model_key_to_model(self) -> dict[str, "MetaToml_Model"]:
        """Map each model key to its model entry (last entry wins on duplicates)."""
        return {entry.key: entry for entry in self.models}

    @cached_property
    def model_title_to_model(self) -> dict[str, "MetaToml_Model"]:
        """Map each lower-cased model title to its model entry."""
        return {entry.title.lower(): entry for entry in self.models}

    @cached_property
    def benchmark_key_to_benchmark(self) -> dict[str, "MetaToml_Benchmark"]:
        """Map each benchmark key to its benchmark entry."""
        return {entry.key: entry for entry in self.benchmarks}

    @cached_property
    def model_key_to_repo(self) -> dict[str, "MetaToml_ModelRepo"]:
        """Map each repo entry's key to the repo entry itself."""
        return {entry.key: entry for entry in self.model_repos}

    # --- Helper properties ---
    @cached_property
    def model_title_to_repo(self) -> dict[str, "MetaToml_ModelRepo"]:
        """Map each lower-cased model title to the repo sharing that model's key.

        Models without a matching entry in ``model_key_to_repo`` are left out.
        """
        return {
            entry.title.lower(): found
            for entry in self.models
            if (found := self.model_key_to_repo.get(entry.key))
        }

    @cached_property
    def model_title_to_key(self) -> dict[str, str]:
        """Map each lower-cased model title to the model's key."""
        return {entry.title.lower(): entry.key for entry in self.models}

    @cached_property
    def benchmark_title_to_key(self) -> dict[str, str]:
        """Map each lower-cased benchmark title to the benchmark's key."""
        return {entry.title.lower(): entry.key for entry in self.benchmarks}

    @cached_property
    def model_key_to_repo_id(self) -> dict[str, str]:
        """Map each repo entry's key to its ``repo_id``, skipping entries whose ``repo_id`` is None."""
        return {
            entry.key: entry.repo_id
            for entry in self.model_repos
            if entry.repo_id is not None
        }
class _HashableComparableMixin(BaseModel):
    # Shared base for entries identified by `key` and `title`.
    # Hashing uses `key` alone; equality and ordering use the `(key, title)`
    # tuple, so instances that compare equal always hash equal.
    model_config = ConfigDict(extra="allow", frozen=True)

    key: str
    title: str

    def __hash__(self) -> int:
        """Return the hash of ``key``."""
        return hash(self.key)

    def __eq__(self, other: object) -> bool:
        """Compare by ``(key, title)``.

        Returns ``NotImplemented`` for operands that are not instances of this
        mixin, so ``entry == None`` evaluates to ``False`` instead of raising
        ``AttributeError``.
        """
        if not isinstance(other, _HashableComparableMixin):
            return NotImplemented
        return (self.key, self.title) == (other.key, other.title)

    def __lt__(self, other: object) -> bool:
        """Order by ``(key, title)``; unsupported operand types yield ``NotImplemented``."""
        if not isinstance(other, _HashableComparableMixin):
            return NotImplemented
        return (self.key, self.title) < (other.key, other.title)

    def __gt__(self, other: object) -> bool:
        """Order by ``(key, title)``; unsupported operand types yield ``NotImplemented``."""
        if not isinstance(other, _HashableComparableMixin):
            return NotImplemented
        return (self.key, self.title) > (other.key, other.title)

    def __le__(self, other: object) -> bool:
        """Order by ``(key, title)``; unsupported operand types yield ``NotImplemented``."""
        if not isinstance(other, _HashableComparableMixin):
            return NotImplemented
        return (self.key, self.title) <= (other.key, other.title)

    def __ge__(self, other: object) -> bool:
        """Order by ``(key, title)``; unsupported operand types yield ``NotImplemented``."""
        if not isinstance(other, _HashableComparableMixin):
            return NotImplemented
        return (self.key, self.title) >= (other.key, other.title)
class MetaToml_Benchmark(_HashableComparableMixin):
    # Benchmark entry: the inherited `key`/`title` plus a `disabled` flag
    # (False unless set) and the default metric derived from the key.
    disabled: bool = False

    # NOTE: documented with a comment rather than a docstring, because
    # pydantic uses a computed field's docstring as its schema description.
    # Keys starting with "site" default to the "caa" metric; all others to "acc".
    @computed_field
    @property
    def default_metric(self) -> str:
        if self.key.startswith("site"):
            return "caa"
        return "acc"

    @property
    def default_metric_label(self) -> str:
        """Return the display label for ``default_metric``: "CAA" for "caa", otherwise "Acc."."""
        if self.default_metric != "caa":
            return "Acc."
        return "CAA"
class MetaToml_Model(_HashableComparableMixin):
    # Model entry: declares nothing of its own; `key`, `title`, hashing and
    # comparison all come from `_HashableComparableMixin`.
    pass
class MetaToml_ModelRepo(BaseModel):
    # Repo entry for a model. Instances are frozen and accept extra fields.
    model_config = ConfigDict(extra="allow", frozen=True)

    # Required identifier; `MetaToml.model_title_to_repo` matches it against model keys.
    key: str

    # Optional attributes; both default to None when not provided.
    repo_id: str | None = None
    link: str | None = None