import torch import pytorch_lightning as pl from models.Basic import MLP from models.Lstm import LSTM from models.Informer import Informer, InformerStack from models.Stockformer import Stockformer from utils.stock_metrics import get_stock_algo, pct_direction_torch from torchmetrics import MeanSquaredError, MeanAbsoluteError from torch_optimizer import Ranger class ExpTimeseries(pl.LightningModule): def __init__(self, config): super().__init__() self.config = config # pl makes self.learning_rate special self.learning_rate = config.learning_rate # Torch metrics has a state that resets but val and train can be called in unison so we split # If pre_loss isn't supplied (ie: pre_loss is None) it will default to config.loss self.train_criterion = self._select_criterion( loss_override=self.config.pre_loss ) self.other_criterion = self._select_criterion( loss_override=self.config.pre_loss ) self.loss_switched = False self._build_model() # self.save_hyperparameters() def _build_model(self): model_dict = { "informer": Informer, "informerstack": InformerStack, "mlp": MLP, "stockformer": Stockformer, "lstm": LSTM, } assert ( self.config.model in model_dict ), f"Invalid config.model: {self.config.model}, options: {list(model_dict.keys())}" self.model = model_dict[self.config.model](self.config).float() # Load model if self.config.load_model_path is not None: self.load_from_checkpoint(self.config.load_model_path) def _select_criterion(self, loss_override=None): loss = self.config.loss if loss_override is not None: loss = loss_override def combine_loss(loss, weights=None): if weights is None: weights = [1.0] * len(loss) def combined(pred, target, inv_pred): # print(pred.shape, target.shape) # while(1):pass return sum(w*l(inv_pred, target) if "Mean" in l.__class__.__name__ else w*l(pred, target) for w,l in zip(weights, loss)) return combined def loss_lib(loss: str): if "stock" in loss: # Using Stock Loss _, stock_loss_mode = loss.split("_") target_type = self.config.target.split("_")[1] assert ( target_type == "pctchange" or target_type == "logpctchange" ), "Can't use stock loss unless target is pctchange or logpctchange" assert ( self.config.scale and self.config.inverse_pred # and not self.config.inverse_output ), "Can't use stock loss without scale, inverse pred, and not inverse output" criterion = get_stock_algo(target_type, stock_loss_mode) print("criterion:", criterion) return lambda x, y: -1 * criterion.loss(x, y).mean() # return lambda x, y: -LogPctProfitTanhV1.loss(x, y).mean() # return get_stock_loss(target_type, stock_loss_mode, threshold=0.0) elif loss == "mae": assert ( self.config.scale and self.config.inverse_pred # and self.config.inverse_output ), "Can't use mae loss without scale, inverse pred, and inverse output" return MeanAbsoluteError().cuda() elif loss == "mse": assert ( self.config.scale and self.config.inverse_pred # and self.config.inverse_output ), "Can't use mse loss without scale, inverse pred, and inverse output" return MeanSquaredError().cuda() loss_list = [ loss_lib(loss_type) for loss_type in loss.split('+') ] weights = [1.0] if '+' not in loss else [10.0, 1.0] return combine_loss(loss_list, weights) raise Exception(f"Invalid loss: {loss}") def forward(self, x): # in lightning, forward defines the prediction/inference actions return self.model(x) def training_step(self, batch, batch_idx): # training_step defines the train loop. It is independent of forward batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch pred, true, inv_pred = self._process_one_batch( self.trainer.datamodule.data_train, batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index=None, ) loss = self.train_criterion(pred, true, inv_pred) self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True) self.log( "train_pct_dir", pct_direction_torch(pred, true), prog_bar=True, on_step=False, on_epoch=True, ) self.log( "train_mag", torch.linalg.norm(pred), # torch.mean(torch.abs(pred)) prog_bar=False, on_step=False, on_epoch=True, ) if ( self.config.pre_epochs is not None and self.config.pre_loss is not None and self.current_epoch == self.config.pre_epochs and not self.loss_switched ): # Revert to default loss self.train_criterion = self._select_criterion( loss_override=self.config.loss ) self.other_criterion = self._select_criterion( loss_override=self.config.loss ) self.loss_switched = True return loss def validation_step(self, batch, batch_idx, dataloader_idx=0): # validation_step defines the validation loop. It is independent of forward batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch pred, true, inv_pred = self._process_one_batch( self.trainer.datamodule.data_val, batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index=None, ) if dataloader_idx == 0: # Actual val dataset assert self.trainer.val_dataloaders[0].dataset.flag == "val" loss = self.other_criterion(pred, true, inv_pred) self.log( "val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False, add_dataloader_idx=False, ) self.log( "val_pct_dir", pct_direction_torch(pred, true), prog_bar=False, on_step=False, on_epoch=True, add_dataloader_idx=False, ) return loss elif dataloader_idx == 1: # TODO: If we are using torch metrics we should create an additional loss function # Test dataset assert self.trainer.val_dataloaders[1].dataset.flag == "test" loss = self.other_criterion(pred, true, inv_pred) self.log( "test_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False, add_dataloader_idx=False, ) self.log( "test_pct_dir", pct_direction_torch(pred, true), prog_bar=False, on_step=False, on_epoch=True, add_dataloader_idx=False, ) return loss def test_step(self, batch, batch_idx, dataloader_idx=0): # test_step defines the test loop. It is independent of forward batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch data_sets = [ self.trainer.datamodule.data_train, self.trainer.datamodule.data_val, self.trainer.datamodule.data_test, ] pred, true, inv_pred = self._process_one_batch( data_sets[dataloader_idx], batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index=None, ) loss = self.other_criterion(pred, true, inv_pred) # if dataloader_idx == 0: self.log( "test_loss", loss, sync_dist=False, ) def predict_step(self, batch, batch_idx, dataloader_idx=0): batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch data_sets = [ self.trainer.datamodule.data_train, self.trainer.datamodule.data_val, self.trainer.datamodule.data_test, ] pred, true, inv_pred = self._process_one_batch( data_sets[dataloader_idx], batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index=None, ) # dataset = self.trainer.predict_dataloaders[dataloader_idx].dataset # batch_x_raw_date, batch_y_raw_date = dataset.index_to_dates(batch_idx) if "mse" in self.config.loss or "mae" in self.config.loss: pred = inv_pred return { "pred": pred, "true": true, } # def on_predict_epoch_end(self, results): # pass # def on_predict_end(self): # pass def _process_one_batch( self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index=None, ): # Decoder input if self.config.dec_in dec_inp = None # if self.config.dec_in and ( # self.config.padding == 0 or self.config.padding == 1 # ): # # FF: dec_inp = torch.zeros_like(batch_y[:, -self.config.pred_len:, :]).float() # dec_inp = torch.full( # [batch_y.shape[0], self.config.pred_len, batch_y.shape[-1]], # self.config.padding, # ).float() # dec_inp = ( # torch.cat([batch_y[:, : self.config.label_len, :], dec_inp], dim=1) # .float() # .to(self.device) # ) # Encoder - Decoder if self.config.output_attention: outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] else: outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) # if self.config.inverse_output: inv_outputs = dataset_object.inverse_transform(outputs) f_dim = -1 if self.config.features == "MS" else 0 # if ds_index is None: batch_y = batch_y[:, -self.config.pred_len :, f_dim:] return outputs, batch_y, inv_outputs # else: # batch_x_raw_dates, batch_y_raw_dates = dataset_object.index_to_dates( # ds_index # ) # assert batch_y_raw_dates.shape == batch_y.shape[0:2] # batch_y = batch_y[:, -self.config.pred_len :, f_dim:].to(self.device) # batch_y_raw_dates = batch_y_raw_dates[:, -self.config.pred_len :] # return outputs, batch_y, batch_y_raw_dates def configure_optimizers(self): if self.config.optim == "AdamW": optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) elif self.config.optim == "Ranger": optimizer = Ranger(self.parameters(), lr=self.learning_rate) else: optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) # optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate) # Learning rate scheduler if self.config.lradj == "type1": lmbda = lambda epoch: 0.5 scheduler = torch.optim.lr_scheduler.MultiplicativeLR( optimizer, lr_lambda=lmbda, verbose=True ) elif self.config.lradj == "type2": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, factor=0.5, patience=10, threshold=0, cooldown=0, verbose=True, min_lr=1e-8, ) scheduler = { "scheduler": scheduler, "interval": "epoch", # called after each training epoch "monitor": "val_loss", } elif self.config.lradj == "type3": scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=self.config.learning_rate, steps_per_epoch=len(self.trainer.datamodule.data_train) // self.config.batch_size, # Would be nicer to use self.trainer.train_dataloader.dataset but there is a pl bug epochs=self.config.max_epochs, ) scheduler = { "scheduler": scheduler, "interval": "step", # called after each training step } else: return optimizer return [optimizer], [scheduler]