"""
TimeSeriesDataSet builder for pytorch_forecasting.

Wraps the feature_store output into train / validation / test splits
with proper temporal ordering (no leakage).
"""

from __future__ import annotations

import logging
import os
from typing import Optional

import numpy as np
import pandas as pd

from deep_learning.config import TFTASROConfig, get_tft_config

logger = logging.getLogger(__name__)


def build_datasets(
    master_df: pd.DataFrame,
    time_varying_unknown_reals: list[str],
    time_varying_known_reals: list[str],
    target_cols: list[str],
    cfg: Optional[TFTASROConfig] = None,
):
    """
    Create pytorch_forecasting TimeSeriesDataSet objects for train / val / test.

    Uses chronological splitting: [train | val | test]. The split sizes are
    row counts derived from ``cfg.training.val_ratio`` and
    ``cfg.training.test_ratio``; the cutoffs are the ``time_idx`` values found
    at those positions in the *sorted* ``time_idx`` column, so the split is
    chronological even when ``master_df`` itself is not ordered by time.

    Args:
        master_df: Frame holding at least the ``time_idx`` and ``group_id``
            columns, the target column and every listed real-valued feature.
        time_varying_unknown_reals: Passed through to ``TimeSeriesDataSet``.
        time_varying_known_reals: Passed through to ``TimeSeriesDataSet``.
        target_cols: Candidate target columns; only the first one is used.
            Falls back to the column name ``"target"`` when empty.
        cfg: TFT configuration; loaded via ``get_tft_config()`` when omitted.

    Returns:
        (training_dataset, validation_dataset, test_dataset)

    Raises:
        ValueError: If the training split has fewer rows than
            ``max_encoder_length + max_prediction_length``.
    """
    # Imported lazily so that importing this module does not require
    # pytorch_forecasting to be installed.
    from pytorch_forecasting import TimeSeriesDataSet

    if cfg is None:
        cfg = get_tft_config()

    encoder_len = cfg.model.max_encoder_length
    prediction_len = cfg.model.max_prediction_length

    n = len(master_df)
    test_size = int(n * cfg.training.test_ratio)
    val_size = int(n * cfg.training.val_ratio)
    train_size = n - val_size - test_size

    min_rows = encoder_len + prediction_len
    if train_size < min_rows:
        raise ValueError(
            f"Not enough data for TFT: {train_size} train rows, "
            f"need at least {min_rows}"
        )

    # Take the cutoffs from the sorted time index rather than from row
    # position: positional lookup is only chronological when the frame is
    # already ordered by time_idx, which is not guaranteed for unsorted or
    # multi-group input.
    time_idx = master_df["time_idx"]
    sorted_time_idx = np.sort(time_idx.to_numpy())
    train_cutoff = sorted_time_idx[train_size - 1]
    val_cutoff = sorted_time_idx[train_size + val_size - 1]

    logger.info(
        "Data split: train=%d (idx<=%.0f), val=%d (idx<=%.0f), test=%d",
        train_size,
        train_cutoff,
        val_size,
        val_cutoff,
        test_size,
    )

    target = target_cols[0] if target_cols else "target"

    training = TimeSeriesDataSet(
        master_df[time_idx <= train_cutoff],
        time_idx="time_idx",
        target=target,
        group_ids=["group_id"],
        max_encoder_length=encoder_len,
        max_prediction_length=prediction_len,
        time_varying_unknown_reals=time_varying_unknown_reals,
        time_varying_known_reals=time_varying_known_reals,
        static_categoricals=["group_id"],
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
        allow_missing_timesteps=True,
    )

    # Validation and test frames reach back ``encoder_len`` time steps before
    # their own start, so the rows just before each cutoff are included as
    # encoder history for the first windows of the split.
    validation = TimeSeriesDataSet.from_dataset(
        training,
        master_df[
            (time_idx > train_cutoff - encoder_len) & (time_idx <= val_cutoff)
        ],
        stop_randomization=True,
    )

    test = TimeSeriesDataSet.from_dataset(
        training,
        master_df[time_idx > val_cutoff - encoder_len],
        stop_randomization=True,
    )

    logger.info(
        "Datasets created: train=%d samples, val=%d, test=%d | "
        "encoder_len=%d, prediction_len=%d | "
        "%d unknown reals, %d known reals",
        len(training),
        len(validation),
        len(test),
        encoder_len,
        prediction_len,
        len(time_varying_unknown_reals),
        len(time_varying_known_reals),
    )

    return training, validation, test


def _resolve_num_workers(configured: int) -> int:
    """
    Return a safe num_workers value for the current platform.

    On Windows (os.name == 'nt'), PyTorch DataLoader multiprocessing requires
    the script to be inside an ``if __name__ == '__main__'`` guard, which is
    not the case in training scripts. Force 0 to avoid deadlocks.

    On every other platform (e.g. Linux/macOS on GitHub Actions, HF Spaces)
    the result is ``max(configured, 2)``: a config that still carries the old
    0 is upgraded to 2, and larger configured values are used as-is. Note
    that 2 acts as a floor, so a configured value of 1 also becomes 2.
    """
    if os.name == "nt":
        return 0
    # On POSIX: honour config; upgrade 0 → 2 as a sensible floor
    return max(configured, 2)


def create_dataloaders(
    training_dataset,
    validation_dataset,
    test_dataset=None,
    cfg: Optional[TFTASROConfig] = None,
):
    """
    Create PyTorch DataLoaders from TimeSeriesDataSet objects.

    Args:
        training_dataset: Dataset exposing ``to_dataloader``; loaded with
            ``train=True``.
        validation_dataset: Dataset exposing ``to_dataloader``; loaded with
            ``train=False``.
        test_dataset: Optional dataset; loaded with ``train=False`` when given.
        cfg: TFT configuration supplying ``training.batch_size`` and
            ``training.num_workers``; loaded via ``get_tft_config()`` when
            omitted.

    Returns:
        (train_dl, val_dl, test_dl) where ``test_dl`` is ``None`` when no
        test dataset was supplied.
    """
    if cfg is None:
        cfg = get_tft_config()

    nw = _resolve_num_workers(cfg.training.num_workers)
    logger.info(
        "DataLoader workers: %d (platform=%s, configured=%d)",
        nw,
        os.name,
        cfg.training.num_workers,
    )

    # Batch size and worker count are shared by every loader.
    loader_kwargs = {
        "batch_size": cfg.training.batch_size,
        "num_workers": nw,
    }

    train_dl = training_dataset.to_dataloader(train=True, **loader_kwargs)
    val_dl = validation_dataset.to_dataloader(train=False, **loader_kwargs)

    test_dl = None
    if test_dataset is not None:
        test_dl = test_dataset.to_dataloader(train=False, **loader_kwargs)

    return train_dl, val_dl, test_dl