TempoPFN / src /gift_eval /constants.py
Vladyslav Moroshan
Apply ruff formatting
0a58567
import json
import logging
import os
from pathlib import Path
from gluonts.ev.metrics import (
MAE,
MAPE,
MASE,
MSE,
MSIS,
ND,
NRMSE,
RMSE,
SMAPE,
MeanWeightedSumQuantileLoss,
)
logger = logging.getLogger(__name__)

# Environment setup
# This unconditionally overwrites any value already present in the environment.
# NOTE(review): presumably set so cuBLAS runs deterministically — confirm.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Use absolute path relative to the project root
_MODULE_DIR = Path(__file__).parent.parent.parent  # Goes to project root
DATASET_PROPERTIES_PATH = _MODULE_DIR / "data" / "dataset_properties.json"

# Load the per-dataset properties once at import time. Loading is best-effort:
# any failure (missing file, unreadable file, malformed JSON) is logged and the
# mapping falls back to an empty dict so that importing this module never fails.
try:
    # JSON text is UTF-8 by specification; pin the encoding so decoding does
    # not depend on the platform's locale default.
    with open(DATASET_PROPERTIES_PATH, encoding="utf-8") as f:
        DATASET_PROPERTIES = json.load(f)
except Exception as exc:  # pragma: no cover - logging path
    DATASET_PROPERTIES = {}
    logger.warning(
        "Could not load dataset properties from %s: %s. Domain and num_variates will fall back to defaults.",
        DATASET_PROPERTIES_PATH,
        exc,
    )
# Datasets
# Dataset identifiers, grouped below by name prefix. The order is significant
# to anything that iterates this tuple, so groups must stay in this sequence.
# NOTE(review): the "/<suffix>" part presumably selects a sampling frequency
# (pandas-style alias such as H, D, W, 15T) — confirm against the data loader.
SHORT_DATASETS = (
    # m4_*
    ("m4_yearly", "m4_quarterly", "m4_monthly", "m4_weekly", "m4_daily", "m4_hourly")
    # electricity/*
    + ("electricity/15T", "electricity/H", "electricity/D", "electricity/W")
    # solar/*
    + ("solar/10T", "solar/H", "solar/D", "solar/W")
    # entries without a "/<suffix>" part
    + ("hospital", "covid_deaths")
    # us_births/*
    + ("us_births/D", "us_births/M", "us_births/W")
    # saugeenday/*
    + ("saugeenday/D", "saugeenday/M", "saugeenday/W")
    # *_with_missing
    + (
        "temperature_rain_with_missing",
        "kdd_cup_2018_with_missing/H",
        "kdd_cup_2018_with_missing/D",
        "car_parts_with_missing",
    )
    # restaurant and hierarchical_sales/*
    + ("restaurant", "hierarchical_sales/D", "hierarchical_sales/W")
    # LOOP_SEATTLE/*
    + ("LOOP_SEATTLE/5T", "LOOP_SEATTLE/H", "LOOP_SEATTLE/D")
    # SZ_TAXI/* and M_DENSE/*
    + ("SZ_TAXI/15T", "SZ_TAXI/H", "M_DENSE/H", "M_DENSE/D")
    # ett1/*
    + ("ett1/15T", "ett1/H", "ett1/D", "ett1/W")
    # ett2/*
    + ("ett2/15T", "ett2/H", "ett2/D", "ett2/W")
    # jena_weather/*
    + ("jena_weather/10T", "jena_weather/H", "jena_weather/D")
    # bitbrains_*
    + ("bitbrains_fast_storage/5T", "bitbrains_fast_storage/H", "bitbrains_rnd/5T", "bitbrains_rnd/H")
    # bizitobs_*
    + ("bizitobs_application", "bizitobs_service", "bizitobs_l2c/5T", "bizitobs_l2c/H")
)
# Dataset identifiers, grouped below by name prefix; order is preserved for
# anything that iterates this tuple.
# NOTE(review): presumably the datasets evaluated at the "medium" and "long"
# terms — confirm against the evaluation driver.
MED_LONG_DATASETS = (
    # electricity/* and solar/*
    ("electricity/15T", "electricity/H", "solar/10T", "solar/H")
    # kdd_cup_2018_with_missing/*, LOOP_SEATTLE/*, SZ_TAXI/*, M_DENSE/*
    + ("kdd_cup_2018_with_missing/H", "LOOP_SEATTLE/5T", "LOOP_SEATTLE/H", "SZ_TAXI/15T", "M_DENSE/H")
    # ett1/* and ett2/*
    + ("ett1/15T", "ett1/H", "ett2/15T", "ett2/H")
    # jena_weather/*
    + ("jena_weather/10T", "jena_weather/H")
    # bitbrains_*
    + ("bitbrains_fast_storage/5T", "bitbrains_rnd/5T")
    # bizitobs_*
    + ("bizitobs_application", "bizitobs_service", "bizitobs_l2c/5T", "bizitobs_l2c/H")
)
# Every dataset identifier exactly once, as a list. dict.fromkeys drops repeats
# while keeping first-seen order: SHORT_DATASETS entries come first, followed
# by any MED_LONG_DATASETS entries that were not already present.
ALL_DATASETS = [*dict.fromkeys((*SHORT_DATASETS, *MED_LONG_DATASETS))]
# Evaluation terms, in this fixed order.
TERMS = (
    "short",
    "medium",
    "long",
)
# Pretty-name overrides (following GIFT eval standard).
# Keys are dataset names without any "/<suffix>" part; three of the four
# entries simply drop the "_with_missing" suffix.
# NOTE(review): names absent from this mapping presumably keep their original
# spelling — confirm at the lookup site.
PRETTY_NAMES = {
    "saugeenday": "saugeen",
    # "_with_missing" variants
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}
# Metric objects from gluonts.ev.metrics, built once at import time.
# NOTE(review): the eleven entries appear to line up one-to-one, in order, with
# STANDARD_METRIC_NAMES — confirm before reordering or inserting entries.
METRICS = (
    # Squared error, evaluated on the mean forecast and on the 0.5 quantile.
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    # Metrics constructed with their library defaults.
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    # Quantile loss over the nine deciles 0.1 .. 0.9.
    MeanWeightedSumQuantileLoss(
        quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    ),
)
# Metric name strings, in a fixed order.
# NOTE(review): presumably the column/key names produced by the entries of
# METRICS, position by position — confirm against the evaluation output.
STANDARD_METRIC_NAMES = (
    # Bracketed suffix is either "mean" or "0.5".
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    # No bracketed suffix.
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    # No bracketed suffix.
    "mean_weighted_sum_quantile_loss",
)
# Public names of this module, i.e. what `from <module> import *` exports.
# `logger` and `_MODULE_DIR` are defined in this module but are not listed here.
# Kept as a plain literal list of strings, in alphabetical order.
__all__ = [
    "ALL_DATASETS",
    "DATASET_PROPERTIES",
    "DATASET_PROPERTIES_PATH",
    "MED_LONG_DATASETS",
    "METRICS",
    "PRETTY_NAMES",
    "SHORT_DATASETS",
    "STANDARD_METRIC_NAMES",
    "TERMS",
]