File size: 5,240 Bytes
093b0a5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import pickle
from typing import Any
import pandas as pd
from pytorch_lightning.callbacks import BasePredictionWriter
import os
import torch
import numpy as np
class PredTrueDateWriter(BasePredictionWriter):
    """Lightning prediction writer that saves predictions, targets and dates.

    At the end of a predict epoch, for every predict dataloader it saves one
    ``.npy`` file per key into ``<trainer.log_dir>/results/``::

        pred_<flag>_<rank>.npy
        true_<flag>_<rank>.npy
        date_<flag>_<rank>.npy

    where ``<flag>`` is the dataloader's ``dataset.flag`` and ``<rank>`` is
    ``trainer.global_rank``, so every process writes its own files.
    """

    def __init__(self, output_dir, write_interval="epoch"):
        """Store the writer configuration.

        Args:
            output_dir: Fallback output root, used only when
                ``trainer.log_dir`` is ``None``. Normally results are written
                under ``trainer.log_dir``.
            write_interval: Forwarded to ``BasePredictionWriter``. Only the
                epoch-end hook is implemented by this class.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Concatenate and save predictions, targets and target dates.

        Args:
            trainer: The running trainer. Its ``log_dir``,
                ``predict_dataloaders``, ``datamodule.config.pred_len`` and
                ``global_rank`` are read here.
            pl_module: Unused; required by the hook signature.
            predictions: One sequence per predict dataloader of per-batch
                ``predict_step`` outputs. Each output is indexed with the keys
                ``"pred"`` and ``"true"``.
            batch_indices: One sequence per predict dataloader of per-batch
                sample indices, aligned with ``predictions``.
        """
        # Read the property once. Lightning types it Optional[str], so fall
        # back to the configured output_dir instead of crashing in os.path.join.
        log_dir = trainer.log_dir
        root_dir = log_dir if log_dir is not None else self.output_dir
        folder_path = os.path.join(root_dir, "results/")
        os.makedirs(folder_path, exist_ok=True)

        # NOTE(review): assumes both `predictions` and `batch_indices` are
        # nested per dataloader; confirm the single-dataloader case against
        # the installed Lightning version.
        for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)):
            dataset = trainer.predict_dataloaders[dataloader_idx].dataset

            # Stack the per-batch outputs along dim 0 (presumably samples).
            dataloader_data = {
                key: torch.concat([batch_out[key] for batch_out in data_dl])
                for key in ("pred", "true")
            }

            # Map this dataloader's sample indices back to dates; only the
            # second returned array (the y-side dates) is used.
            _, batch_y_raw_dates = dataset.index_to_dates(np.concatenate(bi))
            # TODO: Make safety shape assert (date rows vs. pred/true rows).
            # Keep the last `pred_len` columns, presumably the prediction window.
            dataloader_data["date"] = batch_y_raw_dates[
                :, -trainer.datamodule.config.pred_len :
            ]

            for key, value in dataloader_data.items():
                # Convert tensors explicitly so GPU / grad-tracking tensors can
                # be saved too; CPU tensors yield the same array as before.
                if isinstance(value, torch.Tensor):
                    value = value.detach().cpu().numpy()
                np.save(
                    os.path.join(
                        folder_path,
                        f"{key}_{dataset.flag}_{trainer.global_rank}.npy",
                    ),
                    value,
                )
# class PredTrueDateWriterV2(BasePredictionWriter):
# def __init__(self, output_dir, write_interval="epoch"):
# super().__init__(write_interval)
# self.output_dir = output_dir
# def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# # this will create N (num processes) sets of files in `trainer.log_dir/results/`,
# # each containing the predictions of its respective rank
# folder_path = os.path.join(trainer.log_dir, "results/")
# os.makedirs(folder_path, exist_ok=True)
# tpd_dict_tuple: dict[str, tuple[Any, Any, Any]] = {}
# data = {}
# for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)):
# dataloader_data = {}
# dataset = trainer.predict_dataloaders[dataloader_idx].dataset
# for key in ["pred", "true"]:
# dataloader_data[key] = torch.concat(
# [data_dl[i][key] for i in range(len(data_dl))]
# )
# # Date
# batch_x_raw_dates, batch_y_raw_dates = dataset.index_to_dates(
# np.concatenate(bi)
# )
# # TODO: Make safety shape assert
# batch_y_raw_dates = batch_y_raw_dates[
# :, -trainer.datamodule.config.pred_len :
# ]
# dataloader_data["date"] = batch_y_raw_dates
# data[dataset.flag] = dataloader_data
# for key in dataloader_data:
# np.save(
# os.path.join(
# folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy"
# ),
# dataloader_data[key],
# )
# for data_group in ["train", "val", "test"]:
# dp = [
# data[data_group]["true"][:, 0, 0],
# data[data_group]["pred"][:, 0, 0],
# data[data_group]["date"][:, 0],
# ]
# tpd_dict_tuple[data_group] = dp
# s = np.argsort(tpd_dict_tuple[data_group][2], axis=None)
# tpd_dict_tuple[data_group] = list(
# map(lambda x: x[s], tpd_dict_tuple[data_group])
# )
# tpd_dict_tuple[data_group][2] = pd.DatetimeIndex(
# tpd_dict_tuple[data_group][2], tz="UTC"
# )
# # # Override trues with df target data to get original numerical precision
# # if not ("mse" in args.loss and not args.inverse_output) and df is not None:
# # print("OVERRIDING trues with df target")
# # df_data_group = df.loc[tpd_dict_tuple[data_group][2]]
# # t = args.target.split("_")
# # df_target = df_data_group[t[0]][t[1]].to_numpy()
# # tpd_dict_tuple[data_group][0] = df_target
# tpd_dict: dict[str, dict[str, Any]] = {}
# for data_group in tpd_dict_tuple:
# tpd_dict[data_group] = {
# "trues": tpd_dict_tuple[data_group][0],
# "preds": tpd_dict_tuple[data_group][1],
# "dates": tpd_dict_tuple[data_group][2],
# }
# get_metrics(args: dotdict | None, pred: np.ndarray, true: np.ndarray, thresh: float = 0.0)
# with open(os.path.join(folder_path, "tpd_dict.pickle"), "wb") as handle:
# pickle.dump(tpd_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
# return tpd_dict
|