File size: 5,240 Bytes
093b0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pickle
from typing import Any
import pandas as pd
from pytorch_lightning.callbacks import BasePredictionWriter
import os
import torch
import numpy as np


class PredTrueDateWriter(BasePredictionWriter):
    """Save predictions, targets and target dates of a predict run as ``.npy`` files.

    At the end of each predict epoch, for every predict dataloader this writes
    ``pred_<flag>_<rank>.npy``, ``true_<flag>_<rank>.npy`` and
    ``date_<flag>_<rank>.npy`` into ``<trainer.log_dir>/results/``, where
    ``<flag>`` is the dataloader's ``dataset.flag`` and ``<rank>`` is
    ``trainer.global_rank``. Every process writes its own files, so a run with
    N processes produces N files per key and dataloader.

    Args:
        output_dir: Fallback base directory, used only when ``trainer.log_dir``
            is ``None`` (the trainer's log directory takes precedence).
        write_interval: Forwarded to ``BasePredictionWriter``.
    """

    def __init__(self, output_dir, write_interval="epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def _results_dir(self, trainer):
        """Return the ``results/`` directory for this run, creating it if needed."""
        # `Trainer.log_dir` is Optional; without a fallback a None value made
        # os.path.join raise. Read it once because the property may do work.
        log_dir = trainer.log_dir
        base_dir = log_dir if log_dir is not None else self.output_dir
        folder_path = os.path.join(base_dir, "results/")
        os.makedirs(folder_path, exist_ok=True)
        return folder_path

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor to a NumPy array; pass other values through.

        Tensors are detached and moved to CPU first so ``np.save`` also works for
        GPU tensors and tensors that require grad.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Write pred/true/date arrays for every predict dataloader.

        Args:
            trainer: The Lightning trainer running prediction.
            pl_module: Unused; part of the ``BasePredictionWriter`` hook signature.
            predictions: One entry per predict dataloader; each entry is the list
                of per-batch outputs, each indexable by ``"pred"`` and ``"true"``
                (tensors concatenated along dim 0).
            batch_indices: One entry per predict dataloader; each entry is the
                list of per-batch sample indices, passed to
                ``dataset.index_to_dates``.

        Raises:
            ValueError: If the number of date rows differs from the number of
                prediction rows for a dataloader.
        """
        folder_path = self._results_dir(trainer)

        for dataloader_idx, (dl_outputs, dl_batch_indices) in enumerate(
            zip(predictions, batch_indices)
        ):
            if len(dl_outputs) == 0:
                # Nothing was predicted for this dataloader; concatenating an
                # empty list would raise, so there is nothing to write.
                continue

            dataset = trainer.predict_dataloaders[dataloader_idx].dataset

            arrays = {
                key: self._to_numpy(torch.concat([batch[key] for batch in dl_outputs]))
                for key in ("pred", "true")
            }

            # index_to_dates returns (input-window dates, target-window dates);
            # only the target dates are saved.
            _, y_dates = dataset.index_to_dates(np.concatenate(dl_batch_indices))
            # Keep the last `pred_len` columns so the dates cover the forecast
            # horizon (assumes y_dates is (samples, steps) — TODO confirm
            # against dataset.index_to_dates).
            y_dates = y_dates[:, -trainer.datamodule.config.pred_len :]

            # Safety check: every prediction row must have a matching date row,
            # otherwise the saved files would be silently misaligned.
            if len(y_dates) != len(arrays["pred"]):
                raise ValueError(
                    f"Date/prediction row mismatch for flag {dataset.flag!r}: "
                    f"{len(y_dates)} date rows vs {len(arrays['pred'])} prediction rows"
                )
            arrays["date"] = y_dates

            for key, value in arrays.items():
                np.save(
                    os.path.join(
                        folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy"
                    ),
                    self._to_numpy(value),
                )


# class PredTrueDateWriterV2(BasePredictionWriter):
#     def __init__(self, output_dir, write_interval="epoch"):
#         super().__init__(write_interval)
#         self.output_dir = output_dir

#     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
#         # this will create N (num processes) files in `<trainer.log_dir>/results/`,
#         # each containing the predictions of its respective rank
#         folder_path = os.path.join(trainer.log_dir, "results/")
#         os.makedirs(folder_path, exist_ok=True)

#         tpd_dict_tuple: dict[str, tuple[Any, Any, Any]] = {}

#         data = {}
#         for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)):
#             dataloader_data = {}
#             dataset = trainer.predict_dataloaders[dataloader_idx].dataset

#             for key in ["pred", "true"]:
#                 dataloader_data[key] = torch.concat(
#                     [data_dl[i][key] for i in range(len(data_dl))]
#                 )

#             # Date
#             batch_x_raw_dates, batch_y_raw_dates = dataset.index_to_dates(
#                 np.concatenate(bi)
#             )

#             # TODO: Make safety shape assert
#             batch_y_raw_dates = batch_y_raw_dates[
#                 :, -trainer.datamodule.config.pred_len :
#             ]
#             dataloader_data["date"] = batch_y_raw_dates

#             data[dataset.flag] = dataloader_data

#             for key in dataloader_data:
#                 np.save(
#                     os.path.join(
#                         folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy"
#                     ),
#                     dataloader_data[key],
#                 )


#         for data_group in ["train", "val", "test"]:

#             dp = [
#                 data[data_group]["true"][:, 0, 0],
#                 data[data_group]["pred"][:, 0, 0],
#                 data[data_group]["date"][:, 0],
#             ]
#             tpd_dict_tuple[data_group] = dp
#             s = np.argsort(tpd_dict_tuple[data_group][2], axis=None)
#             tpd_dict_tuple[data_group] = list(
#                 map(lambda x: x[s], tpd_dict_tuple[data_group])
#             )

#             tpd_dict_tuple[data_group][2] = pd.DatetimeIndex(
#                 tpd_dict_tuple[data_group][2], tz="UTC"
#             )

#             # # Override trues with df target data to get original numerical precision
#             # if not ("mse" in args.loss and not args.inverse_output) and df is not None:
#             #     print("OVERRIDING trues with df target")
#             #     df_data_group = df.loc[tpd_dict_tuple[data_group][2]]
#             #     t = args.target.split("_")
#             #     df_target = df_data_group[t[0]][t[1]].to_numpy()
#             #     tpd_dict_tuple[data_group][0] = df_target


#         tpd_dict: dict[str, dict[str, Any]] = {}
#         for data_group in tpd_dict_tuple:
#             tpd_dict[data_group] = {
#                 "trues": tpd_dict_tuple[data_group][0],
#                 "preds": tpd_dict_tuple[data_group][1],
#                 "dates": tpd_dict_tuple[data_group][2],
#             }

#         # TODO: compute metrics here, e.g. via get_metrics(args: dotdict | None, pred: np.ndarray, true: np.ndarray, thresh: float = 0.0)

#         with open(os.path.join(folder_path, "tpd_dict.pickle"), "wb") as handle:
#             pickle.dump(tpd_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
#         return tpd_dict