# WaveLSFromer / exp_timeseries.py
# ducheng678 - Initial WaveLSFromer project - 093b0a5
import torch
import pytorch_lightning as pl
from collections import defaultdict
from models.Basic import MLP
from models.Lstm import LSTM
from models.Informer import Informer, InformerStack
from models.Stockformer import Stockformer
from utils.stock_metrics import get_stock_algo, pct_direction_torch
from torchmetrics import MeanSquaredError, MeanAbsoluteError
from torch_optimizer import Ranger
class ExpTimeseries(pl.LightningModule):
    """LightningModule that trains and evaluates a time-series model.

    The wrapped network is chosen by ``config.model`` and the criterion by
    ``config.loss`` / ``config.pre_loss``. The class defines the train,
    validation, test and predict steps plus the optimizer / LR-scheduler setup.
    """
def __init__(self, config):
    """Store the config, build criteria, metric and model, and reset state.

    Args:
        config: Experiment configuration object. This constructor reads
            ``learning_rate`` and ``pre_loss`` from it; the helpers it calls
            read further fields.
    """
    super().__init__()
    self.config = config
    # Kept as a direct attribute because pl treats self.learning_rate specially.
    self.learning_rate = config.learning_rate

    # Train and val/test use separate criterion objects: torchmetrics carry
    # state that resets, and both loops can be active in unison.
    # When pre_loss is None, _select_criterion falls back to config.loss.
    warmup_loss = self.config.pre_loss
    self.train_criterion = self._select_criterion(loss_override=warmup_loss)
    self.other_criterion = self._select_criterion(loss_override=warmup_loss)
    self.metric = self._select_criterion(metric=True)
    self.loss_switched = False

    self._build_model()

    # Placeholders that the step / epoch hooks fill in later.
    for attr_name in (
        "loss_reg",
        "scale",
        "val_log_growth_sum",
        "val_log_growth_count",
        "test_log_growth_sum",
        "test_log_growth_count",
    ):
        setattr(self, attr_name, None)
def _build_model(self):
    """Instantiate the network named by ``config.model`` and optionally load weights.

    The model class is looked up by ``self.config.model``, constructed with
    the config object and cast to float. When ``self.config.load_model_path``
    is not None, the weights stored at that path are loaded into this module
    in place.

    Raises:
        AssertionError: If ``self.config.model`` is not a known model name.
    """
    model_dict = {
        "informer": Informer,
        "informerstack": InformerStack,
        "mlp": MLP,
        "stockformer": Stockformer,
        "lstm": LSTM,
    }
    assert (
        self.config.model in model_dict
    ), f"Invalid config.model: {self.config.model}, options: {list(model_dict.keys())}"
    self.model = model_dict[self.config.model](self.config).float()

    # Load model
    checkpoint_path = self.config.load_model_path
    if checkpoint_path is not None:
        # load_from_checkpoint is a classmethod that builds and returns a NEW
        # module, so calling it on self and dropping the result never changed
        # this instance's weights. Load the state dict into self instead.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Lightning checkpoints keep the weights under "state_dict"; a bare
        # state dict is accepted as well.
        state_dict = checkpoint.get("state_dict", checkpoint)
        self.load_state_dict(state_dict)
def _select_criterion(self, loss_override=None, metric=False):
loss = self.config.loss
if loss_override is not None:
loss = loss_override
def combine_loss(loss, weights=None):
if weights is None:
weights = [1.0] * len(loss)
def combined(pred, target, inv_pred, input_scale=None):
return loss[0](pred, target, input_scale=input_scale)
# return sum(w*l(inv_pred, target) if "Mean" in l.__class__.__name__ else w*l(pred, target) for w,l in zip(weights, loss))
return combined
def loss_lib(loss: str):
if "stock" in loss:
# Using Stock Loss
_, stock_loss_mode = loss.split("_")
target_type = self.config.target.split("_")[1]
assert (
target_type == "pctchange" or target_type == "logpctchange"
), "Can't use stock loss unless target is pctchange or logpctchange"
assert (
self.config.scale and
self.config.inverse_pred
# and not self.config.inverse_output
), "Can't use stock loss without scale, inverse pred, and not inverse output"
criterion = get_stock_algo(target_type, stock_loss_mode)
print("criterion:", criterion)
if metric:
def mt(x, y, input_scale):
return criterion.metric(x, y, input_scale=input_scale)
return mt
else:
return lambda x, y, input_scale: [-1 * criterion.loss(x, y).mean(), criterion.sharpe(x, y).mean()]
# return lambda x, y: -LogPctProfitTanhV1.loss(x, y).mean()
# return get_stock_loss(target_type, stock_loss_mode, threshold=0.0)
elif loss == "mae":
assert (
self.config.scale
and self.config.inverse_pred
# and self.config.inverse_output
), "Can't use mae loss without scale, inverse pred, and inverse output"
return MeanAbsoluteError().cuda()
elif loss == "mse":
assert (
self.config.scale
and self.config.inverse_pred
# and self.config.inverse_output
), "Can't use mse loss without scale, inverse pred, and inverse output"
return MeanSquaredError().cuda()
loss_list = [ loss_lib(loss_type) for loss_type in loss.split('+') ]
weights = [1.0] if '+' not in loss else [1.0, 0.1]
return combine_loss(loss_list, weights)
raise Exception(f"Invalid loss: {loss}")
def forward(self, x):
    """Run the wrapped model on ``x`` and return its output.

    In Lightning, ``forward`` defines the prediction / inference action.
    """
    network = self.model
    return network(x)
def training_step(self, batch, batch_idx):
    """Run one training batch and return the scalar to optimise.

    Adds small random noise to the encoder input, runs the batch through
    ``_process_one_batch``, logs ``train_loss`` / ``train_sharpe`` (and
    ``wavelet_loss`` when a regulariser value is available), and switches
    from the pre-training loss to ``config.loss`` once
    ``current_epoch == config.pre_epochs``.

    Args:
        batch: Tuple ``(batch_x, batch_y, batch_x_mark, batch_y_mark, _)``.
        batch_idx: Index of the batch (unused).

    Returns:
        ``loss + exp(-2.5 * sharpe)``, plus ``self.loss_reg`` when it is set.
    """
    batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

    # NOTE(review): randn_like is standard normal, so "*2-1" shifts the noise
    # mean to -sigma_x; that transform is usually paired with rand_like.
    # Left unchanged here; confirm the intended distribution.
    sigma_x = 0.001
    batch_x = batch_x + (torch.randn_like(batch_x)*2-1) * sigma_x

    pred, true, inv_pred = self._process_one_batch(
        self.trainer.datamodule.data_train,
        batch_x,
        batch_y,
        batch_x_mark,
        batch_y_mark,
        ds_index=None,
    )

    loss, sharpe = self.train_criterion(pred, true, inv_pred)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    self.log("train_sharpe", sharpe, prog_bar=True, on_step=False, on_epoch=True)

    total = loss + torch.exp(-2.5*sharpe)
    # loss_reg stays None when _process_one_batch takes the output_attention
    # path; logging or adding None would raise, so skip it in that case.
    if self.loss_reg is not None:
        self.log("wavelet_loss", self.loss_reg, prog_bar=True, on_step=False, on_epoch=True)
        total = total + 1e0*self.loss_reg

    if (
        self.config.pre_epochs is not None
        and self.config.pre_loss is not None
        and self.current_epoch == self.config.pre_epochs
        and not self.loss_switched
    ):
        # Revert to default loss
        self.train_criterion = self._select_criterion(
            loss_override=self.config.loss
        )
        self.other_criterion = self._select_criterion(
            loss_override=self.config.loss
        )
        self.loss_switched = True
    return total
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Evaluate one batch from the validation loop.

    Dataloader 0 is treated as the validation set and dataloader 1 as the
    test set; any other index is processed but nothing is logged. For both
    handled loaders the loss and sharpe values are logged and the metric's
    raw output is accumulated into the per-loader sum / count. The
    validation branch also stores the scale returned by the metric, which
    the test branch passes back in.

    Args:
        batch: Tuple ``(batch_x, batch_y, batch_x_mark, batch_y_mark, _)``.
        batch_idx: Index of the batch (unused).
        dataloader_idx: Index of the dataloader that produced ``batch``.
    """
    batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
    pred, true, inv_pred = self._process_one_batch(
        self.trainer.datamodule.data_val,
        batch_x,
        batch_y,
        batch_x_mark,
        batch_y_mark,
        ds_index=None,
    )

    if dataloader_idx not in (0, 1):
        return

    on_val_set = dataloader_idx == 0
    if on_val_set:
        loss_key, sharpe_key = "val_loss", "val_sharpe"
    else:
        # TODO: If we are using torch metrics we should create an additional loss function
        assert self.trainer.val_dataloaders[1].dataset.flag == "test"
        loss_key, sharpe_key = "test_loss", "test_sharpe"

    loss, sharpe = self.other_criterion(pred, true, inv_pred)
    log_options = dict(
        prog_bar=True,
        on_step=False,
        on_epoch=True,
        sync_dist=False,
        add_dataloader_idx=False,
    )
    self.log(loss_key, loss, **log_options)
    self.log(sharpe_key, sharpe, **log_options)

    if on_val_set:
        # Remember the scale from the validation set for the test set.
        raw, self.scale = self.metric(pred, true, inv_pred)
    else:
        raw, _ = self.metric(pred, true, inv_pred, self.scale)
    self.val_log_growth_sum[dataloader_idx] += raw.detach().sum()
    self.val_log_growth_count[dataloader_idx] += raw.numel()
    return
def on_validation_epoch_start(self):
    """Reset the per-dataloader log-growth accumulators for a new validation epoch."""
    # Sums default to 0.0 and counts to 0 for dataloader indices not seen yet.
    self.val_log_growth_sum = defaultdict(float)
    self.val_log_growth_count = defaultdict(int)
def on_validation_epoch_end(self):
    """Log the ROI for each dataloader seen during validation.

    For every accumulated sum the ROI is ``exp(sum) - 1``. Dataloader 0 is
    logged as ``val_roi`` and dataloader 1 as ``test_roi``.

    Raises:
        Exception: If a dataloader index other than 0 or 1 was accumulated.
    """
    roi_names = {0: "val_roi", 1: "test_roi"}
    for dl_idx, sum_log in self.val_log_growth_sum.items():
        roi = torch.exp(sum_log) - 1
        name = roi_names.get(dl_idx)
        if name is None:
            raise Exception
        self.log(
            name,
            roi,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=False,
            add_dataloader_idx=False,
        )
    # Alternative (not enabled): use the mean log-growth (sum / count) as
    # the metric; it is independent of T and more stable.
def test_step(self, batch, batch_idx, dataloader_idx=0):
# test_step defines the test loop. It is independent of forward
batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
data_sets = [
self.trainer.datamodule.data_train,
self.trainer.datamodule.data_val,
self.trainer.datamodule.data_test,
]
pred, true, inv_pred = self._process_one_batch(
data_sets[dataloader_idx],
batch_x,
batch_y,
batch_x_mark,
batch_y_mark,
ds_index=None,
)
# loss = self.other_criterion(pred, true, inv_pred)
# # if dataloader_idx == 0:
# self.log(
# "test_loss",
# loss,
# sync_dist=False,
# )
if dataloader_idx == 0:
raw, _ = self.metric(pred, true, inv_pred)
if dataloader_idx == 1:
raw, self.scale = self.metric(pred, true, inv_pred)
if dataloader_idx == 2:
raw, _ = self.metric(pred, true, inv_pred, self.scale)
self.test_log_growth_sum[dataloader_idx] += raw.detach().sum()
self.test_log_growth_count[dataloader_idx] += raw.numel()
def on_test_epoch_start(self):
    """Reset the per-dataloader log-growth accumulators for a new test epoch."""
    # Sums default to 0.0 and counts to 0 for dataloader indices not seen yet.
    self.test_log_growth_sum = defaultdict(float)
    self.test_log_growth_count = defaultdict(int)
def on_test_epoch_end(self):
    """Log the ROI (``exp(sum) - 1``) of every dataloader seen during testing.

    Previously every dataloader's ROI was logged under the single key
    ``test_roi``, so the per-split values could not be told apart. Each
    split now gets its own key, following the train / val / test order used
    by ``test_step``. When only one dataloader was evaluated, the historical
    key ``test_roi`` is kept.
    """
    split_names = {0: "train_roi", 1: "val_roi", 2: "test_roi"}
    single_loader = len(self.test_log_growth_sum) == 1
    for dl_idx, sum_log in self.test_log_growth_sum.items():
        roi = torch.exp(sum_log) - 1
        if single_loader:
            name = "test_roi"
        else:
            name = split_names.get(dl_idx, f"test_roi_{dl_idx}")
        self.log(
            name,
            roi,
            sync_dist=False,
        )
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return detached float32 predictions and targets for one batch.

    Dataloaders 0, 1 and 2 are paired with the datamodule's train, val and
    test datasets. When ``config.loss`` mentions "mse" or "mae", the
    inverse-transformed prediction is returned instead of the raw one.

    Args:
        batch: Tuple ``(batch_x, batch_y, batch_x_mark, batch_y_mark, _)``.
        batch_idx: Index of the batch (unused).
        dataloader_idx: Index of the dataloader that produced ``batch``.

    Returns:
        Dict with keys ``"pred"`` and ``"true"``.
    """
    batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
    datamodule = self.trainer.datamodule
    candidates = (
        datamodule.data_train,
        datamodule.data_val,
        datamodule.data_test,
    )
    pred, true, inv_pred = self._process_one_batch(
        candidates[dataloader_idx],
        batch_x,
        batch_y,
        batch_x_mark,
        batch_y_mark,
        ds_index=None,
    )

    loss_spec = self.config.loss
    chosen = inv_pred if any(tag in loss_spec for tag in ("mse", "mae")) else pred
    result = {}
    result["pred"] = chosen.detach().to(torch.float32)
    result["true"] = true.detach().to(torch.float32)
    return result
# def on_predict_epoch_end(self, results):
# pass
# def on_predict_end(self):
# pass
def _process_one_batch(
self,
dataset_object,
batch_x,
batch_y,
batch_x_mark,
batch_y_mark,
ds_index=None,
):
# Decoder input if self.config.dec_in
dec_inp = None
# if self.config.dec_in and (
# self.config.padding == 0 or self.config.padding == 1
# ):
# # FF: dec_inp = torch.zeros_like(batch_y[:, -self.config.pred_len:, :]).float()
# dec_inp = torch.full(
# [batch_y.shape[0], self.config.pred_len, batch_y.shape[-1]],
# self.config.padding,
# ).float()
# dec_inp = (
# torch.cat([batch_y[:, : self.config.label_len, :], dec_inp], dim=1)
# .float()
# .to(self.device)
# )
# Encoder - Decoder
if self.config.output_attention:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
else:
outputs, loss_reg = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
self.loss_reg = loss_reg
# if self.config.inverse_output:
f_dim = -1 if self.config.features == "MS" else 0
# if ds_index is None:
batch_y = batch_y[:, -self.config.pred_len :, f_dim:]
# print(batch_y.std())
# batch_y = dataset_object.inverse_transform(batch_y)
# print(batch_y.std())
# while 1:pass
inv_outputs = dataset_object.inverse_transform(outputs)
return outputs, batch_y, inv_outputs
# else:
# batch_x_raw_dates, batch_y_raw_dates = dataset_object.index_to_dates(
# ds_index
# )
# assert batch_y_raw_dates.shape == batch_y.shape[0:2]
# batch_y = batch_y[:, -self.config.pred_len :, f_dim:].to(self.device)
# batch_y_raw_dates = batch_y_raw_dates[:, -self.config.pred_len :]
# return outputs, batch_y, batch_y_raw_dates
def configure_optimizers(self):
    """Create the optimizer and, depending on ``config.lradj``, a scheduler.

    ``config.optim`` selects AdamW or Ranger; any other value uses Adam.
    ``config.lradj`` selects the scheduler: "type1" halves the LR on every
    scheduler step, "type2" reduces on plateau of ``val_loss``, "type3" is a
    one-cycle schedule stepped every training step. Any other value means
    no scheduler.

    Returns:
        The bare optimizer when no scheduler is configured, otherwise
        ``([optimizer], [scheduler])``.
    """
    # self.learning_rate (not config.learning_rate) is the value pl treats
    # specially, so it is the single source of truth for every LR below.
    lr = self.learning_rate
    if self.config.optim == "AdamW":
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
    elif self.config.optim == "Ranger":
        optimizer = Ranger(self.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    # Learning rate scheduler
    if self.config.lradj == "type1":
        lmbda = lambda epoch: 0.5
        scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
            optimizer, lr_lambda=lmbda, verbose=True
        )
    elif self.config.lradj == "type2":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=0.5,
            patience=10,
            threshold=0,
            cooldown=0,
            verbose=True,
            min_lr=1e-8,
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",  # called after each training epoch
            "monitor": "val_loss",
        }
    elif self.config.lradj == "type3":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            # Previously config.learning_rate, which ignored any update
            # made to self.learning_rate and disagreed with the optimizer.
            max_lr=lr,
            # Would be nicer to use self.trainer.train_dataloader.dataset
            # but there is a pl bug.
            # NOTE(review): floor division undercounts steps when the last
            # partial batch is not dropped; confirm against the dataloader.
            steps_per_epoch=len(self.trainer.datamodule.data_train)
            // self.config.batch_size,
            epochs=self.config.max_epochs,
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "step",  # called after each training step
        }
    else:
        return optimizer
    return [optimizer], [scheduler]