Spaces:
Running
on
T4
Running
on
T4
| # Copyright (c) NXAI GmbH. | |
| # This software may be used and distributed according to the terms of the NXAI Community License Agreement. | |
| import pandas as pd | |
| import torch | |
| from gluonts.dataset.common import Dataset | |
| from gluonts.dataset.field_names import FieldName | |
| from gluonts.model.forecast import QuantileForecast | |
| from .standard_adapter import _batch_pad_iterable | |
| DEF_TARGET_COLUMN = FieldName.TARGET # target | |
| DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID) | |
| def _get_gluon_ts_map(**gluon_kwargs): | |
| target_col = gluon_kwargs.get("target_column", DEF_TARGET_COLUMN) | |
| meta_columns = gluon_kwargs.get("meta_columns", DEF_META_COLUMNS) | |
| def extract_gluon(series): | |
| ctx = torch.Tensor(series[target_col]) | |
| meta = {k: series[k] for k in meta_columns if k in series} | |
| meta["length"] = len(ctx) | |
| return ctx, meta | |
| return extract_gluon | |
| def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs): | |
| return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size) | |
| def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels): | |
| forecasts = [] | |
| for i in range(quantile_forecasts.shape[0]): | |
| start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h"))) | |
| start_date += meta[i].get("length", 0) | |
| forecasts.append( | |
| QuantileForecast( | |
| forecast_arrays=torch.cat((quantile_forecasts[i], mean_forecasts[i].unsqueeze(1)), dim=1) | |
| .T.cpu() | |
| .numpy(), | |
| start_date=start_date, | |
| item_id=meta[i].get(FieldName.ITEM_ID, None), | |
| forecast_keys=list(map(str, quantile_levels)) + ["mean"], | |
| ) | |
| ) | |
| return forecasts | |