| import torch |
| import pytorch_lightning as pl |
|
|
| from collections import defaultdict |
| from models.Basic import MLP |
| from models.Lstm import LSTM |
| from models.Informer import Informer, InformerStack |
| from models.Stockformer import Stockformer |
| from utils.stock_metrics import get_stock_algo, pct_direction_torch |
| from torchmetrics import MeanSquaredError, MeanAbsoluteError |
| from torch_optimizer import Ranger |
|
|
|
|
class ExpTimeseries(pl.LightningModule):
    """Lightning module that trains and evaluates a time-series model.

    The wrapped network is chosen by ``config.model`` (see ``_build_model``)
    and the training / evaluation criteria by ``config.loss`` and
    ``config.pre_loss`` (see ``_select_criterion``).  Validation and test
    hooks additionally accumulate the value returned by the metric callable
    and report ``exp(sum) - 1`` as an ROI figure at epoch end.
    """

    def __init__(self, config):
        """Store ``config``, build the criteria and the model.

        Args:
            config: experiment configuration object; the attributes read by
                this class include ``learning_rate``, ``loss``, ``pre_loss``,
                ``pre_epochs``, ``model``, ``load_model_path``, ``optim`` and
                ``lradj``.
        """
        super().__init__()
        self.config = config

        # Kept as its own attribute (rather than read from config each time)
        # and used by configure_optimizers.
        self.learning_rate = config.learning_rate

        # Both criteria start from ``pre_loss`` (``_select_criterion`` falls
        # back to ``config.loss`` when it is None); training_step swaps them
        # to ``config.loss`` once ``pre_epochs`` is reached.
        self.train_criterion = self._select_criterion(
            loss_override=self.config.pre_loss
        )
        self.other_criterion = self._select_criterion(
            loss_override=self.config.pre_loss
        )
        self.metric = self._select_criterion(metric=True)
        self.loss_switched = False

        self._build_model()

        # Regularisation term returned by the model; set in _process_one_batch.
        self.loss_reg = None
        # Second value returned by the metric on the validation split, reused
        # as ``input_scale`` for the test split.
        self.scale = None
        # Per-dataloader accumulators, (re)created at each epoch start.
        self.val_log_growth_sum = None
        self.val_log_growth_count = None
        self.test_log_growth_sum = None
        self.test_log_growth_count = None

    def _build_model(self):
        """Instantiate ``self.model`` and optionally restore saved weights.

        Raises:
            AssertionError: if ``config.model`` is not a known model name.
        """
        model_dict = {
            "informer": Informer,
            "informerstack": InformerStack,
            "mlp": MLP,
            "stockformer": Stockformer,
            "lstm": LSTM,
        }
        assert (
            self.config.model in model_dict
        ), f"Invalid config.model: {self.config.model}, options: {list(model_dict.keys())}"
        self.model = model_dict[self.config.model](self.config).float()

        if self.config.load_model_path is not None:
            # ``load_from_checkpoint`` is a classmethod that builds and returns
            # a *new* module, so calling it on ``self`` and dropping the result
            # never restored any weights.  Load the state dict into this
            # instance instead.
            checkpoint = torch.load(self.config.load_model_path, map_location="cpu")
            # Lightning checkpoints keep the weights under "state_dict"; also
            # accept a bare state dict.
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint
            self.load_state_dict(state_dict)

    def _log_epoch_metric(self, name, value):
        """Log ``value`` under ``name`` as an epoch-level progress-bar metric."""
        self.log(
            name,
            value,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=False,
            add_dataloader_idx=False,
        )

    def _select_criterion(self, loss_override=None, metric=False):
        """Build the callable used as loss (or metric) for this experiment.

        Args:
            loss_override: loss name to use instead of ``config.loss``;
                ignored when None.
            metric: when True, stock losses return the criterion's ``metric``
                function instead of the ``[loss, sharpe]`` pair.

        Returns:
            A callable ``f(pred, target, inv_pred, input_scale=None)`` that
            forwards ``pred``, ``target`` and ``input_scale`` to the first
            loss named in the (possibly ``+``-joined) loss string.

        Raises:
            ValueError: if a loss name is not recognised.
            AssertionError: if the config does not satisfy the requirements
                of the selected loss.
        """
        loss = self.config.loss
        if loss_override is not None:
            loss = loss_override

        def combine_loss(loss, weights=None):
            if weights is None:
                weights = [1.0] * len(loss)

            # NOTE(review): only the first loss is evaluated; ``weights`` and
            # any further losses are currently unused, and ``inv_pred`` is
            # accepted but not forwarded.  Confirm whether a weighted sum was
            # intended.
            def combined(pred, target, inv_pred, input_scale=None):
                return loss[0](pred, target, input_scale=input_scale)

            return combined

        def loss_lib(loss_name: str):
            if "stock" in loss_name:
                _, stock_loss_mode = loss_name.split("_")
                target_type = self.config.target.split("_")[1]
                assert (
                    target_type == "pctchange" or target_type == "logpctchange"
                ), "Can't use stock loss unless target is pctchange or logpctchange"
                assert (
                    self.config.scale and
                    self.config.inverse_pred
                ), "Can't use stock loss without scale, inverse pred, and not inverse output"

                criterion = get_stock_algo(target_type, stock_loss_mode)
                print("criterion:", criterion)
                if metric:
                    def mt(x, y, input_scale):
                        return criterion.metric(x, y, input_scale=input_scale)
                    return mt
                # The criterion's loss is negated so that minimising the
                # returned value maximises it; the second entry is its sharpe.
                return lambda x, y, input_scale: [
                    -1 * criterion.loss(x, y).mean(),
                    criterion.sharpe(x, y).mean(),
                ]
            elif loss_name == "mae":
                assert (
                    self.config.scale
                    and self.config.inverse_pred
                ), "Can't use mae loss without scale, inverse pred, and inverse output"
                # NOTE(review): ``combined`` calls this with an ``input_scale``
                # keyword and the step methods unpack two return values;
                # verify this branch against torchmetrics' call signature.
                return MeanAbsoluteError().cuda()
            elif loss_name == "mse":
                assert (
                    self.config.scale
                    and self.config.inverse_pred
                ), "Can't use mse loss without scale, inverse pred, and inverse output"
                # NOTE(review): same caveat as the "mae" branch.
                return MeanSquaredError().cuda()
            # Previously an unknown name fell through and returned None, which
            # only failed later when the criterion was first called.
            raise ValueError(f"Invalid loss: {loss_name}")

        loss_list = [loss_lib(loss_type) for loss_type in loss.split('+')]
        weights = [1.0] if '+' not in loss else [1.0, 0.1]
        return combine_loss(loss_list, weights)

    def forward(self, x):
        """Call the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Run one training batch and return the total loss.

        Args:
            batch: 5-tuple ``(batch_x, batch_y, batch_x_mark, batch_y_mark, _)``;
                the last element is ignored.
            batch_idx: index of the batch (unused).

        Returns:
            ``loss + exp(-2.5 * sharpe) + loss_reg`` where ``loss`` and
            ``sharpe`` come from ``self.train_criterion``.
        """
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

        # Input-noise augmentation.
        # NOTE(review): ``randn_like`` is Gaussian, so ``* 2 - 1`` shifts the
        # noise mean to -1 (times sigma_x) instead of keeping it zero-mean;
        # that form usually goes with ``rand_like`` -- confirm the intent.
        sigma_x = 0.001
        batch_x = batch_x + (torch.randn_like(batch_x) * 2 - 1) * sigma_x

        pred, true, inv_pred = self._process_one_batch(
            self.trainer.datamodule.data_train,
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )

        loss, sharpe = self.train_criterion(pred, true, inv_pred)

        # ``_process_one_batch`` only sets ``self.loss_reg`` when
        # ``config.output_attention`` is false; fall back to a zero scalar so
        # logging and the sum below do not fail on None.
        loss_reg = self.loss_reg
        if loss_reg is None:
            loss_reg = pred.new_zeros(())

        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_sharpe", sharpe, prog_bar=True, on_step=False, on_epoch=True)
        self.log("wavelet_loss", loss_reg, prog_bar=True, on_step=False, on_epoch=True)

        # One-time switch from the pre-training loss to the main loss once
        # ``pre_epochs`` has been reached.
        if (
            self.config.pre_epochs is not None
            and self.config.pre_loss is not None
            and self.current_epoch == self.config.pre_epochs
            and not self.loss_switched
        ):
            self.train_criterion = self._select_criterion(
                loss_override=self.config.loss
            )
            self.other_criterion = self._select_criterion(
                loss_override=self.config.loss
            )
            self.loss_switched = True

        return loss + torch.exp(-2.5 * sharpe) + 1e0 * loss_reg

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Evaluate one batch of validation dataloader 0 (val) or 1 (test).

        Logs ``<split>_loss`` and ``<split>_sharpe`` and adds the metric's
        first return value to the per-dataloader accumulators.  Batches from
        any other dataloader index are forwarded through the model but not
        scored.
        """
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

        pred, true, inv_pred = self._process_one_batch(
            self.trainer.datamodule.data_val,
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )

        if dataloader_idx == 0:
            split = "val"
        elif dataloader_idx == 1:
            # The second validation loader is expected to carry the test set.
            assert self.trainer.val_dataloaders[1].dataset.flag == "test"
            split = "test"
        else:
            return

        loss, sharpe = self.other_criterion(pred, true, inv_pred)
        self._log_epoch_metric(f"{split}_loss", loss)
        self._log_epoch_metric(f"{split}_sharpe", sharpe)

        if dataloader_idx == 0:
            # The validation split provides the scale later fed to the test
            # split as ``input_scale``.
            raw, self.scale = self.metric(pred, true, inv_pred)
        else:
            raw, _ = self.metric(pred, true, inv_pred, self.scale)

        self.val_log_growth_sum[dataloader_idx] += raw.detach().sum()
        self.val_log_growth_count[dataloader_idx] += raw.numel()
        return

    def on_validation_epoch_start(self):
        """Reset the per-dataloader validation accumulators."""
        self.val_log_growth_sum = defaultdict(lambda: 0.0)
        self.val_log_growth_count = defaultdict(int)

    def on_validation_epoch_end(self):
        """Log ``exp(sum) - 1`` per validation dataloader as val/test ROI.

        Raises:
            Exception: if an accumulator exists for an index other than 0 or 1.
        """
        roi_names = {0: "val_roi", 1: "test_roi"}
        for dl_idx, sum_log in self.val_log_growth_sum.items():
            if dl_idx not in roi_names:
                raise Exception(f"Unexpected validation dataloader index: {dl_idx}")
            roi = torch.exp(sum_log) - 1
            self._log_epoch_metric(roi_names[dl_idx], roi)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Accumulate the metric for one test batch.

        ``dataloader_idx`` selects the dataset used for the inverse transform:
        0 = train, 1 = val, 2 = test.  The val loader stores the metric's
        scale, which the test loader then passes back as ``input_scale``.
        """
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

        data_sets = [
            self.trainer.datamodule.data_train,
            self.trainer.datamodule.data_val,
            self.trainer.datamodule.data_test,
        ]

        pred, true, inv_pred = self._process_one_batch(
            data_sets[dataloader_idx],
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )

        if dataloader_idx == 1:
            raw, self.scale = self.metric(pred, true, inv_pred)
        elif dataloader_idx == 2:
            raw, _ = self.metric(pred, true, inv_pred, self.scale)
        else:
            raw, _ = self.metric(pred, true, inv_pred)

        self.test_log_growth_sum[dataloader_idx] += raw.detach().sum()
        self.test_log_growth_count[dataloader_idx] += raw.numel()

    def on_test_epoch_start(self):
        """Reset the per-dataloader test accumulators."""
        self.test_log_growth_sum = defaultdict(lambda: 0.0)
        self.test_log_growth_count = defaultdict(int)

    def on_test_epoch_end(self):
        """Log ``exp(sum) - 1`` for every test dataloader.

        With several loaders each split gets its own key (``train_roi``,
        ``val_roi``, ``test_roi``), matching ``on_validation_epoch_end``.
        Logging all of them under the single key ``test_roi`` would let the
        logger merge the splits into one reported value.  With exactly one
        loader the key stays ``test_roi``.
        """
        split_names = {0: "train_roi", 1: "val_roi", 2: "test_roi"}
        single_loader = len(self.test_log_growth_sum) == 1
        for dl_idx, sum_log in self.test_log_growth_sum.items():
            roi = torch.exp(sum_log) - 1
            name = "test_roi" if single_loader else split_names[dl_idx]
            self.log(
                name,
                roi,
                sync_dist=False,
            )

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return predictions and targets for one batch as float32 tensors.

        Returns:
            dict with keys ``"pred"`` and ``"true"``.  For mse/mae losses the
            inverse-transformed prediction is returned as ``"pred"``.
        """
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

        data_sets = [
            self.trainer.datamodule.data_train,
            self.trainer.datamodule.data_val,
            self.trainer.datamodule.data_test,
        ]

        pred, true, inv_pred = self._process_one_batch(
            data_sets[dataloader_idx],
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )

        if "mse" in self.config.loss or "mae" in self.config.loss:
            pred = inv_pred
        return {
            "pred": pred.detach().to(torch.float32),
            "true": true.detach().to(torch.float32),
        }

    def _process_one_batch(
        self,
        dataset_object,
        batch_x,
        batch_y,
        batch_x_mark,
        batch_y_mark,
        ds_index=None,
    ):
        """Forward one batch through the model.

        Args:
            dataset_object: dataset whose ``inverse_transform`` is applied to
                the model output.
            batch_x, batch_y, batch_x_mark, batch_y_mark: batch tensors as
                delivered by the dataloader.
            ds_index: unused.

        Returns:
            Tuple ``(outputs, batch_y, inv_outputs)`` where ``batch_y`` has
            been cut to the last ``config.pred_len`` steps.
        """
        # No decoder input is constructed; the model is called with None.
        dec_inp = None

        if self.config.output_attention:
            # Only the first element of the model's return value is used here;
            # ``self.loss_reg`` is left untouched in this branch.
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
            outputs, loss_reg = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            self.loss_reg = loss_reg

        # "MS" keeps only the last feature column of the target; otherwise all
        # columns are kept.
        f_dim = -1 if self.config.features == "MS" else 0
        batch_y = batch_y[:, -self.config.pred_len:, f_dim:]

        inv_outputs = dataset_object.inverse_transform(outputs)
        return outputs, batch_y, inv_outputs

    def configure_optimizers(self):
        """Create the optimizer and, depending on ``config.lradj``, a scheduler.

        Returns:
            The bare optimizer when ``config.lradj`` is not one of
            ``"type1"``, ``"type2"``, ``"type3"``; otherwise
            ``([optimizer], [scheduler])``.
        """
        if self.config.optim == "AdamW":
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        elif self.config.optim == "Ranger":
            optimizer = Ranger(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        if self.config.lradj == "type1":
            # Multiply the learning rate by 0.5 at every scheduler step.
            scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
                optimizer, lr_lambda=lambda epoch: 0.5, verbose=True
            )
        elif self.config.lradj == "type2":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=0.5,
                patience=10,
                threshold=0,
                cooldown=0,
                verbose=True,
                min_lr=1e-8,
            )
            scheduler = {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss",
            }
        elif self.config.lradj == "type3":
            # NOTE(review): floor division under-counts steps when the train
            # loader keeps a final partial batch -- confirm against the
            # dataloader's drop_last setting.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.config.learning_rate,
                steps_per_epoch=len(self.trainer.datamodule.data_train)
                // self.config.batch_size,
                epochs=self.config.max_epochs,
            )
            scheduler = {
                "scheduler": scheduler,
                "interval": "step",
            }
        else:
            return optimizer

        return [optimizer], [scheduler]
|
|