Spaces:
Running
Running
yangzhitao
refactor: enhance submission functionality with new tabs and improved benchmark handling, and update editorconfig for consistent formatting
3f84332
| """Based on https://huggingface.co/spaces/demo-leaderboard-backend/leaderboard/blob/main/src/display/utils.py | |
| Enhanced with Pydantic models. | |
| """ | |
| from enum import Enum | |
| from typing import Literal, Union | |
| from pydantic import BaseModel, ConfigDict, create_model | |
| from typing_extensions import Self | |
| from src.prepare import get_benchmarks | |
def fields(
    raw_class: Union[
        type["_AutoEvalColumnBase"],
        "_AutoEvalColumnBase",
        type["EvalQueueColumnCls"],
        "EvalQueueColumnCls",
    ],
) -> list["ColumnContent"]:
    """Collect the ColumnContent defaults declared on a column model.

    Entries whose names start or end with a double underscore are skipped;
    every remaining field is expected to carry a ColumnContent default.
    """
    columns: list["ColumnContent"] = []
    for field_name, field_info in raw_class.model_fields.items():
        if field_name.startswith("__") or field_name.endswith("__"):
            continue
        columns.append(field_info.default)
    return columns
# These classes define the user-facing column names in one place,
# so a rename only needs to happen here instead of being changed
# throughout the code base.
class ColumnContent(BaseModel):
    """Metadata for a single leaderboard column: display name, cell type, and visibility flags."""

    name: str
    type: Literal["str", "number", "bool", "markdown"]
    # "Original" is a special marker value used by the queue's weight_type column.
    displayed_by_default: bool | Literal["Original"] = False
    hidden: bool = False
    never_hidden: bool = False
    not_supported: bool = False  # not-supported columns should not be displayed

    # NOTE: this must be a classmethod — it is invoked as ColumnContent.new(...)
    # throughout this module; without the decorator the first positional
    # argument would bind to `cls` and construction would fail.
    @classmethod
    def new(
        cls,
        name: str,
        type: Literal["str", "number", "bool", "markdown"],
        displayed_by_default: bool | Literal["Original"] = False,
        *,
        hidden: bool = False,
        never_hidden: bool = False,
        not_supported: bool = False,
    ) -> Self:
        """Convenience constructor mirroring the field order, with the flag arguments keyword-only."""
        return cls(
            name=name,
            type=type,
            displayed_by_default=displayed_by_default,
            hidden=hidden,
            never_hidden=never_hidden,
            not_supported=not_supported,
        )
class _AutoEvalColumnBase(BaseModel):
    """Static (benchmark-independent) columns of the results table.

    One score column per benchmark is appended dynamically via ``create_model``
    further down in this module.
    """

    # Unannotated on purpose: pydantic v2 treats `model_config` specially,
    # matching the style used by EvalQueueColumnCls below.
    model_config = ConfigDict(extra="forbid", frozen=True)

    model_type_symbol: ColumnContent = ColumnContent.new("T", "str", True)
    model: ColumnContent = ColumnContent.new("Model", "markdown", True, never_hidden=True)
    average: ColumnContent = ColumnContent.new("Average ⬆️", "number", True)
    model_type: ColumnContent = ColumnContent.new("Type", "str", not_supported=True)  # TODO: hidden for now
    architecture: ColumnContent = ColumnContent.new("Architecture", "str", not_supported=True)
    weight_type: ColumnContent = ColumnContent.new("Weight type", "str", hidden=True)
    precision: ColumnContent = ColumnContent.new("Precision", "str", not_supported=True)
    license: ColumnContent = ColumnContent.new("Hub License", "str", not_supported=True)
    params: ColumnContent = ColumnContent.new("#Params (B)", "number", not_supported=True)
    likes: ColumnContent = ColumnContent.new("Hub ❤️", "number", not_supported=True)
    still_on_hub: ColumnContent = ColumnContent.new("Available on the hub", "bool", not_supported=True)
    revision: ColumnContent = ColumnContent.new("Model sha", "str", not_supported=True)
# Benchmarks are discovered once at import time; one score column is added per benchmark.
BENCHMARKS = get_benchmarks()

# We use create_model to dynamically fill the scores from Tasks.
# Each benchmark becomes a displayed-by-default "number" column titled after the task.
field_definitions = {
    task.key: (
        ColumnContent,
        ColumnContent.new(task.title, "number", True),
    )
    for task in BENCHMARKS
}
AutoEvalColumnCls: type[_AutoEvalColumnBase] = create_model(  # pyright: ignore[reportCallIssue]
    '_AutoEvalColumnCls',
    __base__=_AutoEvalColumnBase,
    **field_definitions,  # pyright: ignore[reportArgumentType]
)
# Singleton instance used across the app to look up column metadata by attribute.
AutoEvalColumn = AutoEvalColumnCls()
# For the queue columns in the submission tab
class EvalQueueColumnCls(BaseModel):
    """Columns shown for entries waiting in the evaluation queue (submission tab)."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    model: ColumnContent = ColumnContent.new("model", "markdown", displayed_by_default=True)
    revision: ColumnContent = ColumnContent.new("revision", "str", displayed_by_default=True)
    private: ColumnContent = ColumnContent.new("private", "bool", displayed_by_default=True)
    precision: ColumnContent = ColumnContent.new("precision", "str", displayed_by_default=True)
    # "Original" is the sentinel display value for the weight-type column.
    weight_type: ColumnContent = ColumnContent.new("weight_type", "str", displayed_by_default="Original")
    status: ColumnContent = ColumnContent.new("status", "str", displayed_by_default=True)


# Singleton instance used across the app to reference queue-column metadata.
EvalQueueColumn = EvalQueueColumnCls()
# All the model information that we might need
class ModelDetails(BaseModel):
    """Display metadata for a model category: canonical name, optional label, and emoji symbol."""

    name: str
    display_name: str = ""
    symbol: str = ""  # emoji shown next to the name in the UI
class ModelType(Enum):
    """Training regime of a submitted model, with its display name and symbol."""

    PT = ModelDetails(name="pretrained", symbol="🟢")
    FT = ModelDetails(name="fine-tuned", symbol="🔶")
    IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
    RL = ModelDetails(name="RL-tuned", symbol="🟦")
    Unknown = ModelDetails(name="", symbol="?")

    def to_str(self, separator: str = " ") -> str:
        """Return ``"<symbol><separator><name>"`` for display."""
        return f"{self.value.symbol}{separator}{self.value.name}"

    # Marked @staticmethod: the original bare function only worked when
    # accessed via the class; on an instance, `type` would have bound to self.
    @staticmethod
    def from_str(type: str) -> "ModelType":
        """Parse a free-form type string (substring or emoji match) into a ModelType.

        Falls back to ``ModelType.Unknown`` when nothing matches.
        """
        if "fine-tuned" in type or "🔶" in type:
            return ModelType.FT
        if "pretrained" in type or "🟢" in type:
            return ModelType.PT
        if "RL-tuned" in type or "🟦" in type:
            return ModelType.RL
        if "instruction-tuned" in type or "⭕" in type:
            return ModelType.IFT
        return ModelType.Unknown
class WeightType(Enum):
    """How the submitted weights relate to the base model: full original weights, an adapter, or a delta."""

    Adapter = ModelDetails(name="Adapter")
    Original = ModelDetails(name="Original")
    Delta = ModelDetails(name="Delta")
class Precision(Enum):
    """Numeric precision a model was evaluated in."""

    float16 = ModelDetails(name="float16")
    bfloat16 = ModelDetails(name="bfloat16")
    float32 = ModelDetails(name="float32")
    float64 = ModelDetails(name="float64")
    int8 = ModelDetails(name="int8")
    uint8 = ModelDetails(name="uint8")
    int16 = ModelDetails(name="int16")
    int32 = ModelDetails(name="int32")
    int64 = ModelDetails(name="int64")
    Unknown = ModelDetails(name="?")

    # Decorated @classmethod: the original took `cls` without the decorator,
    # so `Precision.from_str(x)` would have bound x to `cls` and raised.
    @classmethod
    def from_str(cls, precision: str) -> "Precision":
        """Parse a precision string, accepting both bare names ("float16")
        and torch-qualified names ("torch.float16").

        Returns ``Precision.Unknown`` for anything unrecognized.
        """
        # Strip an optional "torch." prefix, then look the member up by name.
        name = precision.removeprefix("torch.")
        return cls.__members__.get(name, cls.Unknown)
# Column selection
# COLS: list[str] = [c.name for c in fields(AutoEvalColumnCls) if not c.hidden]
# Names of the static (non-benchmark) columns that are not hidden.
BASE_COLS: list[str] = [c.name for c in fields(_AutoEvalColumnBase) if not c.hidden]
# Names and cell types of the submission-queue columns, in declaration order.
EVAL_COLS: list[str] = [c.name for c in fields(EvalQueueColumnCls)]
EVAL_TYPES: list[Literal["str", "number", "bool", "markdown"]] = [c.type for c in fields(EvalQueueColumnCls)]
# Columns flagged not_supported — excluded from display.
NOT_SUPPORTED_COLS: list[str] = [c.name for c in fields(AutoEvalColumnCls) if c.not_supported]
# BENCHMARK_COLS: list[str] = [t.value.col_name for t in Tasks]
# One column title per discovered benchmark.
BENCHMARK_COLS: list[str] = [t.title for t in BENCHMARKS]