import pickle from typing import Any import pandas as pd from pytorch_lightning.callbacks import BasePredictionWriter import os import torch import numpy as np class PredTrueDateWriter(BasePredictionWriter): def __init__(self, output_dir, write_interval="epoch"): super().__init__(write_interval) self.output_dir = output_dir def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # this will create N (num processes) files in `output_dir` each containing # the predictions of it's respective rank folder_path = os.path.join(trainer.log_dir, "results/") os.makedirs(folder_path, exist_ok=True) data = {} for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)): dataloader_data = {} dataset = trainer.predict_dataloaders[dataloader_idx].dataset for key in ["pred", "true"]: dataloader_data[key] = torch.concat( [data_dl[i][key] for i in range(len(data_dl))] ) # Date batch_x_raw_dates, batch_y_raw_dates = dataset.index_to_dates( np.concatenate(bi) ) # TODO: Make safety shape assert batch_y_raw_dates = batch_y_raw_dates[ :, -trainer.datamodule.config.pred_len : ] dataloader_data["date"] = batch_y_raw_dates data[dataset.flag] = dataloader_data for key in dataloader_data: np.save( os.path.join( folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy" ), dataloader_data[key], ) # class PredTrueDateWriterV2(BasePredictionWriter): # def __init__(self, output_dir, write_interval="epoch"): # super().__init__(write_interval) # self.output_dir = output_dir # def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # # this will create N (num processes) files in `output_dir` each containing # # the predictions of it's respective rank # folder_path = os.path.join(trainer.log_dir, "results/") # os.makedirs(folder_path, exist_ok=True) # tpd_dict_tuple: dict[str, tuple[Any, Any, Any]] = {} # data = {} # for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)): # dataloader_data = {} # dataset = trainer.predict_dataloaders[dataloader_idx].dataset # for key in ["pred", "true"]: # dataloader_data[key] = torch.concat( # [data_dl[i][key] for i in range(len(data_dl))] # ) # # Date # batch_x_raw_dates, batch_y_raw_dates = dataset.index_to_dates( # np.concatenate(bi) # ) # # TODO: Make safety shape assert # batch_y_raw_dates = batch_y_raw_dates[ # :, -trainer.datamodule.config.pred_len : # ] # dataloader_data["date"] = batch_y_raw_dates # data[dataset.flag] = dataloader_data # for key in dataloader_data: # np.save( # os.path.join( # folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy" # ), # dataloader_data[key], # ) # for data_group in ["train", "val", "test"]: # dp = [ # data[data_group]["true"][:, 0, 0], # data[data_group]["pred"][:, 0, 0], # data[data_group]["date"][:, 0], # ] # tpd_dict_tuple[data_group] = dp # s = np.argsort(tpd_dict_tuple[data_group][2], axis=None) # tpd_dict_tuple[data_group] = list( # map(lambda x: x[s], tpd_dict_tuple[data_group]) # ) # tpd_dict_tuple[data_group][2] = pd.DatetimeIndex( # tpd_dict_tuple[data_group][2], tz="UTC" # ) # # # Override trues with df target data to get original numerical precision # # if not ("mse" in args.loss and not args.inverse_output) and df is not None: # # print("OVERRIDING trues with df target") # # df_data_group = df.loc[tpd_dict_tuple[data_group][2]] # # t = args.target.split("_") # # df_target = df_data_group[t[0]][t[1]].to_numpy() # # tpd_dict_tuple[data_group][0] = df_target # tpd_dict: dict[str, dict[str, Any]] = {} # for data_group in tpd_dict_tuple: # tpd_dict[data_group] = { # "trues": tpd_dict_tuple[data_group][0], # "preds": tpd_dict_tuple[data_group][1], # "dates": tpd_dict_tuple[data_group][2], # } # get_metrics(args: dotdict | None, pred: np.ndarray, true: np.ndarray, thresh: float = 0.0) # with open(os.path.join(folder_path, "tpd_dict.pickle"), "wb") as handle: # pickle.dump(tpd_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) # return tpd_dict