WaveLSFromer / utils /callbacks.py
ducheng678
Initial WaveLSFromer project
093b0a5
Raw
History Blame Contribute Delete
5.24 kB
import pickle
from typing import Any
import pandas as pd
from pytorch_lightning.callbacks import BasePredictionWriter
import os
import torch
import numpy as np
class PredTrueDateWriter(BasePredictionWriter):
    """Write predictions, targets and target dates of each predict dataloader to disk.

    For every predict dataloader three ``.npy`` files are saved under
    ``<trainer.log_dir>/results/``:

    * ``pred_<flag>_<rank>.npy``
    * ``true_<flag>_<rank>.npy``
    * ``date_<flag>_<rank>.npy``

    where ``<flag>`` is ``dataset.flag`` of the dataloader's dataset and
    ``<rank>`` is ``trainer.global_rank``.

    Requirements read off the code below:

    * every per-batch prediction is indexable by the keys ``"pred"`` and
      ``"true"`` and holds tensors that can be concatenated;
    * each dataset exposes ``flag`` and ``index_to_dates(indices)``, the latter
      returning a pair whose second element supports ``[:, -pred_len:]`` slicing;
    * ``trainer.datamodule.config.pred_len`` is available.
    """

    def __init__(self, output_dir, write_interval="epoch"):
        """Store ``output_dir`` and forward ``write_interval`` to the base writer.

        NOTE(review): ``output_dir`` is only stored on the instance; the files are
        written relative to ``trainer.log_dir`` in ``write_on_epoch_end``.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Save the collected ``pred``/``true``/``date`` arrays per dataloader.

        Args:
            trainer: Trainer providing ``log_dir``, ``global_rank``,
                ``predict_dataloaders`` and ``datamodule.config.pred_len``.
            pl_module: Unused.
            predictions: Iterated per dataloader; each entry is a sequence of
                per-batch mappings with ``"pred"`` and ``"true"`` tensors.
            batch_indices: Iterated in lockstep with ``predictions``; each entry
                is a sequence of per-batch index sequences.
        """
        # The rank is part of every file name, so each process writes its own
        # set of files into the shared results folder.
        folder_path = os.path.join(trainer.log_dir, "results/")
        os.makedirs(folder_path, exist_ok=True)

        for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)):
            dataset = trainer.predict_dataloaders[dataloader_idx].dataset

            # Stitch the per-batch outputs back together along the first axis.
            dataloader_data = {
                key: torch.concat([batch[key] for batch in data_dl])
                for key in ("pred", "true")
            }

            # Only the target-side dates are kept; the input-side dates are unused.
            _, batch_y_raw_dates = dataset.index_to_dates(np.concatenate(bi))
            # TODO: Make safety shape assert
            # Keep the trailing `pred_len` entries along the second axis.
            pred_len = trainer.datamodule.config.pred_len
            dataloader_data["date"] = batch_y_raw_dates[:, -pred_len:]

            for key, value in dataloader_data.items():
                if isinstance(value, torch.Tensor):
                    # Convert explicitly instead of relying on np.save's implicit
                    # conversion, which raises for tensors that require grad or
                    # live on an accelerator device.
                    value = value.detach().cpu().numpy()
                np.save(
                    os.path.join(
                        folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy"
                    ),
                    value,
                )
# class PredTrueDateWriterV2(BasePredictionWriter):
# def __init__(self, output_dir, write_interval="epoch"):
# super().__init__(write_interval)
# self.output_dir = output_dir
# def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# # this will create N (num processes) files in `output_dir`, each containing
# # the predictions of its respective rank
# folder_path = os.path.join(trainer.log_dir, "results/")
# os.makedirs(folder_path, exist_ok=True)
# tpd_dict_tuple: dict[str, tuple[Any, Any, Any]] = {}
# data = {}
# for dataloader_idx, (data_dl, bi) in enumerate(zip(predictions, batch_indices)):
# dataloader_data = {}
# dataset = trainer.predict_dataloaders[dataloader_idx].dataset
# for key in ["pred", "true"]:
# dataloader_data[key] = torch.concat(
# [data_dl[i][key] for i in range(len(data_dl))]
# )
# # Date
# batch_x_raw_dates, batch_y_raw_dates = dataset.index_to_dates(
# np.concatenate(bi)
# )
# # TODO: Make safety shape assert
# batch_y_raw_dates = batch_y_raw_dates[
# :, -trainer.datamodule.config.pred_len :
# ]
# dataloader_data["date"] = batch_y_raw_dates
# data[dataset.flag] = dataloader_data
# for key in dataloader_data:
# np.save(
# os.path.join(
# folder_path, f"{key}_{dataset.flag}_{trainer.global_rank}.npy"
# ),
# dataloader_data[key],
# )
# for data_group in ["train", "val", "test"]:
# dp = [
# data[data_group]["true"][:, 0, 0],
# data[data_group]["pred"][:, 0, 0],
# data[data_group]["date"][:, 0],
# ]
# tpd_dict_tuple[data_group] = dp
# s = np.argsort(tpd_dict_tuple[data_group][2], axis=None)
# tpd_dict_tuple[data_group] = list(
# map(lambda x: x[s], tpd_dict_tuple[data_group])
# )
# tpd_dict_tuple[data_group][2] = pd.DatetimeIndex(
# tpd_dict_tuple[data_group][2], tz="UTC"
# )
# # # Override trues with df target data to get original numerical precision
# # if not ("mse" in args.loss and not args.inverse_output) and df is not None:
# # print("OVERRIDING trues with df target")
# # df_data_group = df.loc[tpd_dict_tuple[data_group][2]]
# # t = args.target.split("_")
# # df_target = df_data_group[t[0]][t[1]].to_numpy()
# # tpd_dict_tuple[data_group][0] = df_target
# tpd_dict: dict[str, dict[str, Any]] = {}
# for data_group in tpd_dict_tuple:
# tpd_dict[data_group] = {
# "trues": tpd_dict_tuple[data_group][0],
# "preds": tpd_dict_tuple[data_group][1],
# "dates": tpd_dict_tuple[data_group][2],
# }
# get_metrics(args: dotdict | None, pred: np.ndarray, true: np.ndarray, thresh: float = 0.0)
# with open(os.path.join(folder_path, "tpd_dict.pickle"), "wb") as handle:
# pickle.dump(tpd_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
# return tpd_dict