File size: 1,966 Bytes
"""
Display utilities and column definitions for the MUSEval Leaderboard.
"""
from dataclasses import dataclass
from typing import List, Dict, Any
from enum import Enum
# Column definitions for model information
@dataclass
class ModelInfoColumn:
    """Describe a single model-information column.

    Fields are declared in constructor (positional) order; only ``name``
    is required, every other field has a default.
    """

    # Column identifier (required).
    name: str

    # Label for the kind of value held; the module uses "str" and "number".
    type: str = "str"

    # NOTE(review): the three flags below are only stored here, not acted on.
    # Presumably they drive column visibility in the UI -- confirm against
    # the code that consumes these objects.
    displayed_by_default: bool = True
    never_hidden: bool = False
    hidden: bool = False
# Model information columns
model_info_columns = [
    # "model" is the only entry created with never_hidden=True.
    ModelInfoColumn(
        name="model",
        type="str",
        displayed_by_default=True,
        never_hidden=True,
        hidden=False,
    ),
    ModelInfoColumn(
        name="organization",
        type="str",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="submission_date",
        type="str",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="task",
        type="str",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="dataset_version",
        type="str",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    # The two URL columns are the only ones with displayed_by_default=False.
    ModelInfoColumn(
        name="paper_url",
        type="str",
        displayed_by_default=False,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="code_url",
        type="str",
        displayed_by_default=False,
        never_hidden=False,
        hidden=False,
    ),
    # The remaining columns are typed "number" rather than "str".
    ModelInfoColumn(
        name="domains",
        type="number",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="categories",
        type="number",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
    ModelInfoColumn(
        name="datasets",
        type="number",
        displayed_by_default=True,
        never_hidden=False,
        hidden=False,
    ),
]
# Benchmark columns (metrics)
# Names of the metric columns, in the order they are listed here.
BENCHMARK_COLS = [
    "MAE",
    "Uni-MAE",
    "RMSE",
    "MAPE",
    "R²",
    "SMAPE",
    "Uni-Multi",
]
# Evaluation columns
# Names of the evaluation columns.
# NOTE(review): "submitter", "domain", "category" and "dataset" do not match
# the names used in model_info_columns ("organization", "domains",
# "categories", "datasets") -- confirm the difference is intentional.
EVAL_COLS = [
    "model",
    "submitter",
    "submission_date",
    "domain",
    "category",
    "dataset",
    "task",
    "dataset_version",
    "paper_url",
    "code_url",
]
# Evaluation types
# Identifiers of the evaluation types; a single entry is defined.
EVAL_TYPES = [
    "multivariate_forecasting",
]
# Model types
class ModelType(Enum):
    """Closed set of model-type labels; each value is the label string."""

    FOUNDATION = "Foundation Model"
    TRADITIONAL = "Traditional"
    NEURAL = "Neural Network"
    TRANSFORMER = "Transformer"
# Weight types
class WeightType(Enum):
    """Closed set of weight-class labels; each value is the label string."""

    LIGHTWEIGHT = "Lightweight"
    MEDIUM = "Medium"
    HEAVY = "Heavy"
# Precision types
class Precision(Enum):
    """Closed set of precision labels; each value is the label string."""

    FLOAT16 = "FP16"
    FLOAT32 = "FP32"
    MIXED = "Mixed"
# Fields function for dataclass
def fields(cls: Any) -> List[Any]:
    """Return the dataclass field descriptors of ``cls`` as a list.

    Args:
        cls: Object to inspect, typically a dataclass class or instance.

    Returns:
        The values of ``cls.__dataclass_fields__`` in declaration order, or
        an empty list when ``cls`` has no such attribute. No filtering is
        applied to the entries.
    """
    # Always materialise a list so both branches share one return type:
    # a bare ``dict.values()`` view cannot be indexed and never compares
    # equal to a list, unlike the empty-list fallback.
    return list(getattr(cls, "__dataclass_fields__", {}).values())
|