|
|
import json |
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
from gluonts.ev.metrics import ( |
|
|
MAE, |
|
|
MAPE, |
|
|
MASE, |
|
|
MSE, |
|
|
MSIS, |
|
|
ND, |
|
|
NRMSE, |
|
|
RMSE, |
|
|
SMAPE, |
|
|
MeanWeightedSumQuantileLoss, |
|
|
) |
|
|
|
|
|
# Logger scoped to this module (getLogger returns one shared object per name).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
# Per the CUDA/PyTorch reproducibility notes, cuBLAS needs a fixed workspace
# configuration (":4096:8" or ":16:8") for deterministic results, and the
# variable must be set before cuBLAS is initialised -- hence this import-time
# side effect. setdefault() keeps a value the user already exported (e.g. the
# lower-memory ":16:8") instead of silently overwriting it.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
|
|
|
|
|
|
|
|
|
|
|
# Directory three levels above this file; the data/ folder is resolved from it.
_MODULE_DIR = Path(__file__).parent.parent.parent


DATASET_PROPERTIES_PATH = _MODULE_DIR / "data" / "dataset_properties.json"


def _load_dataset_properties(path) -> dict:
    """Return the top-level JSON object stored at *path*, or ``{}`` on failure.

    The file is decoded as UTF-8 rather than with the locale-dependent default
    encoding. A missing or unreadable file, invalid JSON / UTF-8, or a
    top-level value that is not a JSON object is reported as a warning and
    yields an empty dict, so callers fall back to their defaults.

    Args:
        path: Filesystem path (``str`` or path-like) of the JSON file.

    Returns:
        The decoded ``dict``, or an empty ``dict`` when it cannot be used.
    """
    # Same logger object as the module-level ``logger`` (loggers are keyed by name).
    log = logging.getLogger(__name__)
    try:
        with open(path, encoding="utf-8") as fh:
            properties = json.load(fh)
    except (OSError, ValueError) as exc:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError (both subclasses).
        log.warning(
            "Could not load dataset properties from %s: %s. Domain and num_variates will fall back to defaults.",
            path,
            exc,
        )
        return {}
    if not isinstance(properties, dict):
        # Downstream code treats this as a mapping; a list/scalar would break it.
        log.warning(
            "Dataset properties in %s must be a JSON object, got %s. Domain and num_variates will fall back to defaults.",
            path,
            type(properties).__name__,
        )
        return {}
    return properties


DATASET_PROPERTIES = _load_dataset_properties(DATASET_PROPERTIES_PATH)
|
|
|
|
|
|
|
|
|
|
|
# Dataset configurations for the "short" term, written one family per line.
# Entries are either a bare name or "name/<suffix>"; the suffix appears to be
# a pandas-style frequency alias (15T, H, D, W, ...). Order is significant.
# fmt: off
SHORT_DATASETS = (
    "m4_yearly", "m4_quarterly", "m4_monthly",
    "m4_weekly", "m4_daily", "m4_hourly",
    "electricity/15T", "electricity/H", "electricity/D", "electricity/W",
    "solar/10T", "solar/H", "solar/D", "solar/W",
    "hospital",
    "covid_deaths",
    "us_births/D", "us_births/M", "us_births/W",
    "saugeenday/D", "saugeenday/M", "saugeenday/W",
    "temperature_rain_with_missing",
    "kdd_cup_2018_with_missing/H", "kdd_cup_2018_with_missing/D",
    "car_parts_with_missing",
    "restaurant",
    "hierarchical_sales/D", "hierarchical_sales/W",
    "LOOP_SEATTLE/5T", "LOOP_SEATTLE/H", "LOOP_SEATTLE/D",
    "SZ_TAXI/15T", "SZ_TAXI/H",
    "M_DENSE/H", "M_DENSE/D",
    "ett1/15T", "ett1/H", "ett1/D", "ett1/W",
    "ett2/15T", "ett2/H", "ett2/D", "ett2/W",
    "jena_weather/10T", "jena_weather/H", "jena_weather/D",
    "bitbrains_fast_storage/5T", "bitbrains_fast_storage/H",
    "bitbrains_rnd/5T", "bitbrains_rnd/H",
    "bizitobs_application",
    "bizitobs_service",
    "bizitobs_l2c/5T", "bizitobs_l2c/H",
)
# fmt: on
|
|
|
|
|
# Dataset configurations presumably evaluated at the "medium" and "long" terms
# (see TERMS). Every entry here also appears in SHORT_DATASETS. Order matters.
# fmt: off
MED_LONG_DATASETS = (
    "electricity/15T", "electricity/H",
    "solar/10T", "solar/H",
    "kdd_cup_2018_with_missing/H",
    "LOOP_SEATTLE/5T", "LOOP_SEATTLE/H",
    "SZ_TAXI/15T",
    "M_DENSE/H",
    "ett1/15T", "ett1/H",
    "ett2/15T", "ett2/H",
    "jena_weather/10T", "jena_weather/H",
    "bitbrains_fast_storage/5T",
    "bitbrains_rnd/5T",
    "bizitobs_application", "bizitobs_service",
    "bizitobs_l2c/5T", "bizitobs_l2c/H",
)
# fmt: on
|
|
|
|
|
|
|
|
# Union of both configuration tuples as a list, duplicates dropped while
# keeping first-occurrence order (dict keys preserve insertion order). Since
# every MED_LONG_DATASETS entry also appears in SHORT_DATASETS, this currently
# equals list(SHORT_DATASETS).
ALL_DATASETS = [*dict.fromkeys([*SHORT_DATASETS, *MED_LONG_DATASETS])]
|
|
|
|
|
|
|
|
|
|
|
# Term labels, in order; presumably the horizons paired with SHORT_DATASETS
# ("short") and MED_LONG_DATASETS ("medium", "long").
TERMS = "short", "medium", "long"
|
|
|
|
|
|
|
|
|
|
|
# Shorter replacement names for dataset-name prefixes (the part before "/"),
# presumably used when reporting results. Names not listed are left unchanged
# by whatever consumes this mapping -- confirm at call sites.
PRETTY_NAMES = dict(
    saugeenday="saugeen",
    temperature_rain_with_missing="temperature_rain",
    kdd_cup_2018_with_missing="kdd_cup_2018",
    car_parts_with_missing="car_parts",
)
|
|
|
|
|
|
|
|
# gluonts evaluation metrics, constructed in a fixed order. The column names
# they produce appear to be listed, in the same order, in STANDARD_METRIC_NAMES.
METRICS = (
    # Squared error of the mean forecast, then of the 0.5-quantile forecast.
    *(MSE(forecast_type=kind) for kind in ("mean", 0.5)),
    # Remaining point / interval metrics with the library's default arguments.
    *(
        metric_cls()
        for metric_cls in (MAE, MASE, MAPE, SMAPE, MSIS, RMSE, NRMSE, ND)
    ),
    # Quantile levels 0.1, 0.2, ..., 0.9. k / 10 is correctly rounded, so each
    # level is bit-identical to the decimal literal (e.g. 3 / 10 == 0.3).
    MeanWeightedSumQuantileLoss(
        quantile_levels=[k / 10 for k in range(1, 10)]
    ),
)
|
|
|
|
|
|
|
|
# Metric column names, one per entry of METRICS and in the same order.
# The bracket suffix names the forecast used: "mean" or the 0.5 quantile.
STANDARD_METRIC_NAMES = (
    "MSE[mean]", "MSE[0.5]",
    "MAE[0.5]", "MASE[0.5]", "MAPE[0.5]", "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]", "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
)
|
|
|
|
|
|
|
|
# Public API of this module, kept in alphabetical order.
__all__ = [
    "ALL_DATASETS", "DATASET_PROPERTIES", "DATASET_PROPERTIES_PATH",
    "MED_LONG_DATASETS", "METRICS", "PRETTY_NAMES",
    "SHORT_DATASETS", "STANDARD_METRIC_NAMES", "TERMS",
]
|
|
|