"""LightningModule used to train and evaluate the time-series forecasting models."""

from collections import defaultdict

import pytorch_lightning as pl
import torch
from torch_optimizer import Ranger
from torchmetrics import MeanSquaredError, MeanAbsoluteError

from models.Basic import MLP
from models.Informer import Informer, InformerStack
from models.Lstm import LSTM
from models.Stockformer import Stockformer
from utils.stock_metrics import get_stock_algo, pct_direction_torch


class ExpTimeseries(pl.LightningModule):
    """Wrap one of the project models with its loss, metric, and optimizer setup.

    The model class is picked by ``config.model``; the training loss by
    ``config.pre_loss`` (if set) and later ``config.loss``; the optimizer and
    learning-rate schedule by ``config.optim`` and ``config.lradj``.
    """

    def __init__(self, config):
        """Store ``config``, build the criteria and the model.

        Args:
            config: Experiment configuration object. This class reads (at
                least) ``learning_rate``, ``loss``, ``pre_loss``, ``model``
                and ``load_model_path`` from it during construction.
        """
        super().__init__()
        self.config = config
        # pl makes self.learning_rate special
        self.learning_rate = config.learning_rate

        # Torch metrics has a state that resets but val and train can be called
        # in unison so we split.
        # If pre_loss isn't supplied (ie: pre_loss is None) it will default to
        # config.loss
        self.train_criterion = self._select_criterion(
            loss_override=self.config.pre_loss
        )
        self.other_criterion = self._select_criterion(
            loss_override=self.config.pre_loss
        )
        self.metric = self._select_criterion(metric=True)
        self.loss_switched = False

        self._build_model()
        # self.save_hyperparameters()

        # Regularisation term returned by the model; set in _process_one_batch.
        self.loss_reg = None
        # Second value returned by the metric on the validation split; reused
        # when the metric is evaluated on the test split.
        self.scale = None
        # Per-dataloader accumulators, (re)created at the start of each
        # validation / test epoch.
        self.val_log_growth_sum = None
        self.val_log_growth_count = None
        self.test_log_growth_sum = None
        self.test_log_growth_count = None

    def _build_model(self):
        """Instantiate ``self.model`` and optionally restore saved weights.

        Raises:
            AssertionError: If ``config.model`` is not a known model name.
        """
        model_dict = {
            "informer": Informer,
            "informerstack": InformerStack,
            "mlp": MLP,
            "stockformer": Stockformer,
            "lstm": LSTM,
        }
        assert (
            self.config.model in model_dict
        ), f"Invalid config.model: {self.config.model}, options: {list(model_dict.keys())}"
        self.model = model_dict[self.config.model](self.config).float()

        # Load model
        if self.config.load_model_path is not None:
            # `load_from_checkpoint` is a classmethod that builds and returns a
            # *new* module; calling it on `self` and dropping the result never
            # changes this instance's weights. Load the weights in place
            # instead.
            checkpoint = torch.load(self.config.load_model_path, map_location="cpu")
            # Accept both a Lightning checkpoint (weights stored under
            # "state_dict") and a bare state dict.
            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint
            self.load_state_dict(state_dict)

    def _select_criterion(self, loss_override=None, metric=False):
        """Build the callable used as training loss or as evaluation metric.

        Args:
            loss_override: Loss name to use instead of ``config.loss``.
                Ignored when ``None``.
            metric: For "stock" losses, return the criterion's ``metric``
                function instead of its ``[loss, sharpe]`` pair.

        Returns:
            A function ``combined(pred, target, inv_pred, input_scale=None)``
            that forwards ``pred``, ``target`` and ``input_scale`` to the first
            loss in the ``'+'``-separated loss specification.

        Raises:
            ValueError: If a component of the loss specification is unknown.
            AssertionError: If the configuration is incompatible with the
                selected loss.
        """
        loss = self.config.loss
        if loss_override is not None:
            loss = loss_override

        def combine_loss(losses, weights=None):
            if weights is None:
                weights = [1.0] * len(losses)

            def combined(pred, target, inv_pred, input_scale=None):
                # Only the first loss is evaluated; `weights` and `inv_pred`
                # are unused while the weighted sum below stays disabled.
                return losses[0](pred, target, input_scale=input_scale)
                # return sum(w*l(inv_pred, target) if "Mean" in l.__class__.__name__ else w*l(pred, target) for w,l in zip(weights, loss))

            return combined

        def loss_lib(loss_name: str):
            if "stock" in loss_name:
                # Using Stock Loss
                _, stock_loss_mode = loss_name.split("_")
                target_type = self.config.target.split("_")[1]
                assert (
                    target_type == "pctchange" or target_type == "logpctchange"
                ), "Can't use stock loss unless target is pctchange or logpctchange"
                assert (
                    self.config.scale
                    and self.config.inverse_pred
                    # and not self.config.inverse_output
                ), "Can't use stock loss without scale, inverse pred, and not inverse output"
                criterion = get_stock_algo(target_type, stock_loss_mode)
                print("criterion:", criterion)
                if metric:

                    def mt(x, y, input_scale):
                        return criterion.metric(x, y, input_scale=input_scale)

                    return mt
                # Negated mean loss plus the mean sharpe value, as a pair.
                return lambda x, y, input_scale: [
                    -1 * criterion.loss(x, y).mean(),
                    criterion.sharpe(x, y).mean(),
                ]
                # return lambda x, y: -LogPctProfitTanhV1.loss(x, y).mean()
                # return get_stock_loss(target_type, stock_loss_mode, threshold=0.0)
            elif loss_name == "mae":
                # NOTE(review): `combined` passes `input_scale=` as a keyword
                # and the step methods unpack two return values - confirm the
                # torchmetrics objects are compatible before using mae/mse.
                assert (
                    self.config.scale
                    and self.config.inverse_pred
                    # and self.config.inverse_output
                ), "Can't use mae loss without scale, inverse pred, and inverse output"
                return MeanAbsoluteError().cuda()
            elif loss_name == "mse":
                assert (
                    self.config.scale
                    and self.config.inverse_pred
                    # and self.config.inverse_output
                ), "Can't use mse loss without scale, inverse pred, and inverse output"
                return MeanSquaredError().cuda()
            # Previously this raise sat after the outer `return` and could
            # never run, so unknown names silently produced `None`.
            raise ValueError(f"Invalid loss: {loss_name}")

        loss_list = [loss_lib(loss_type) for loss_type in loss.split('+')]
        weights = [1.0] if '+' not in loss else [1.0, 0.1]
        return combine_loss(loss_list, weights)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        # in lightning, forward defines the prediction/inference actions
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Five-element batch; the last element is ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            ``loss + exp(-2.5 * sharpe) + loss_reg``.
        """
        # training_step defines the train loop. It is independent of forward
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch

        # Small additive perturbation of the inputs.
        # NOTE(review): `randn * 2 - 1` is Gaussian noise with mean -1; if
        # uniform noise in [-1, 1] was intended this should use `rand_like` -
        # left as is, confirm with the author.
        sigma_x = 0.001
        batch_x = batch_x + (torch.randn_like(batch_x)*2-1) * sigma_x
        # print(sigma_x.mean(), batch_x.mean(), batch_x.shape)
        # sigma_y = 0.01 * batch_y.std(dim=(1, 2), keepdim=True)
        # batch_y = batch_y + (torch.randn_like(batch_y)*2-1) * sigma_y

        pred, true, inv_pred = self._process_one_batch(
            self.trainer.datamodule.data_train,
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )
        # print(self.loss_reg)
        loss, sharpe = self.train_criterion(pred, true, inv_pred)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_sharpe", sharpe, prog_bar=True, on_step=False, on_epoch=True)
        self.log("wavelet_loss", self.loss_reg, prog_bar=True, on_step=False, on_epoch=True)
        # self.log(
        #     "tr_pct_dir",
        #     pct_direction_torch(pred, true),
        #     prog_bar=True,
        #     on_step=False,
        #     on_epoch=True,
        # )
        # self.log(
        #     "tr_mag",
        #     torch.linalg.norm(pred),  # torch.mean(torch.abs(pred))
        #     prog_bar=False,
        #     on_step=False,
        #     on_epoch=True,
        # )

        if (
            self.config.pre_epochs is not None
            and self.config.pre_loss is not None
            and self.current_epoch == self.config.pre_epochs
            and not self.loss_switched
        ):
            # Revert to default loss
            self.train_criterion = self._select_criterion(
                loss_override=self.config.loss
            )
            self.other_criterion = self._select_criterion(
                loss_override=self.config.loss
            )
            self.loss_switched = True

        return loss + torch.exp(-2.5*sharpe) + 1e0*self.loss_reg

    def _log_eval_scalars(self, prefix, loss, sharpe):
        """Log ``<prefix>_loss`` and ``<prefix>_sharpe`` as epoch-level values."""
        for suffix, value in (("loss", loss), ("sharpe", sharpe)):
            self.log(
                f"{prefix}_{suffix}",
                value,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
                sync_dist=False,
                add_dataloader_idx=False,
            )

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Evaluate one batch from the validation loop.

        Dataloader 0 is logged as ``val_*`` and dataloader 1 as ``test_*``;
        any other index is ignored. The metric's first return value is
        accumulated per dataloader for ``on_validation_epoch_end``.
        """
        # validation_step defines the validation loop. It is independent of forward
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
        pred, true, inv_pred = self._process_one_batch(
            self.trainer.datamodule.data_val,
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )

        if dataloader_idx == 0:
            # Actual val dataset
            # assert self.trainer.val_dataloaders[0].dataset.flag == "val"
            loss, sharpe = self.other_criterion(pred, true, inv_pred)
            self._log_eval_scalars("val", loss, sharpe)
            # Keep the scale returned on the validation split so the test
            # split below is evaluated with it.
            raw, self.scale = self.metric(pred, true, inv_pred)
            self.val_log_growth_sum[0] += raw.detach().sum()
            self.val_log_growth_count[0] += raw.numel()
            # self.log(
            #     "val_pct_dir",
            #     pct_direction_torch(pred, true),
            #     prog_bar=False,
            #     on_step=False,
            #     on_epoch=True,
            #     add_dataloader_idx=False,
            # )
            return
        elif dataloader_idx == 1:
            # TODO: If we are using torch metrics we should create an additional loss function
            # Test dataset
            assert self.trainer.val_dataloaders[1].dataset.flag == "test"
            loss, sharpe = self.other_criterion(pred, true, inv_pred)
            self._log_eval_scalars("test", loss, sharpe)
            raw, _ = self.metric(pred, true, inv_pred, self.scale)
            self.val_log_growth_sum[1] += raw.detach().sum()
            self.val_log_growth_count[1] += raw.numel()
            # self.log(
            #     "test_pct_dir",
            #     pct_direction_torch(pred, true),
            #     prog_bar=False,
            #     on_step=False,
            #     on_epoch=True,
            #     add_dataloader_idx=False,
            # )
            return

    def on_validation_epoch_start(self):
        """Reset the per-dataloader accumulators used during validation."""
        self.val_log_growth_sum = defaultdict(lambda: 0.0)
        self.val_log_growth_count = defaultdict(int)

    def on_validation_epoch_end(self):
        """Log ``exp(sum) - 1`` of the accumulated metric per dataloader.

        Raises:
            Exception: If a dataloader index other than 0 or 1 was accumulated.
        """
        for dl_idx, sum_log in self.val_log_growth_sum.items():
            # count = self.val_log_growth_count[dl_idx]
            factor = torch.exp(sum_log)
            roi = factor - 1
            if dl_idx == 0:
                name = "val_roi"
            elif dl_idx == 1:
                name = "test_roi"
            else:
                raise Exception(f"Unexpected validation dataloader index: {dl_idx}")
            self.log(
                name,
                roi,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
                sync_dist=False,
                add_dataloader_idx=False,
            )
        # # Alternatively, use the mean log-growth as the metric (independent
        # # of T, more stable)
        # mean_log_growth = self.val_log_growth_sum / self.val_log_growth_count
        # self.log("val_mean_log_growth", mean_log_growth, prog_bar=False)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Accumulate the metric for one batch of the test loop.

        ``dataloader_idx`` 0, 1 and 2 select the datamodule's ``data_train``,
        ``data_val`` and ``data_test`` objects respectively.
        """
        # test_step defines the test loop. It is independent of forward
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
        data_sets = [
            self.trainer.datamodule.data_train,
            self.trainer.datamodule.data_val,
            self.trainer.datamodule.data_test,
        ]
        pred, true, inv_pred = self._process_one_batch(
            data_sets[dataloader_idx],
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )
        # loss = self.other_criterion(pred, true, inv_pred)
        #
        # if dataloader_idx == 0:
        #     self.log(
        #         "test_loss",
        #         loss,
        #         sync_dist=False,
        #     )
        if dataloader_idx == 0:
            raw, _ = self.metric(pred, true, inv_pred)
        elif dataloader_idx == 1:
            # The scale from the validation split is kept for the test split.
            raw, self.scale = self.metric(pred, true, inv_pred)
        elif dataloader_idx == 2:
            raw, _ = self.metric(pred, true, inv_pred, self.scale)
        self.test_log_growth_sum[dataloader_idx] += raw.detach().sum()
        self.test_log_growth_count[dataloader_idx] += raw.numel()

    def on_test_epoch_start(self):
        """Reset the per-dataloader accumulators used during testing."""
        self.test_log_growth_sum = defaultdict(lambda: 0.0)
        self.test_log_growth_count = defaultdict(int)

    def on_test_epoch_end(self):
        """Log ``exp(sum) - 1`` of the accumulated metric as ``test_roi``."""
        # NOTE(review): every dataloader logs under the same "test_roi" key -
        # confirm this is intended when more than one test dataloader is used.
        for sum_log in self.test_log_growth_sum.values():
            factor = torch.exp(sum_log)
            roi = factor - 1
            self.log(
                "test_roi",
                roi,
                sync_dist=False,
            )

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return detached float32 predictions and targets for one batch.

        Returns:
            Dict with keys ``"pred"`` and ``"true"``. ``"pred"`` holds the
            inverse-transformed output when ``config.loss`` contains "mse" or
            "mae", otherwise the raw model output.
        """
        batch_x, batch_y, batch_x_mark, batch_y_mark, _ = batch
        data_sets = [
            self.trainer.datamodule.data_train,
            self.trainer.datamodule.data_val,
            self.trainer.datamodule.data_test,
        ]
        pred, true, inv_pred = self._process_one_batch(
            data_sets[dataloader_idx],
            batch_x,
            batch_y,
            batch_x_mark,
            batch_y_mark,
            ds_index=None,
        )
        # dataset = self.trainer.predict_dataloaders[dataloader_idx].dataset
        # batch_x_raw_date, batch_y_raw_date = dataset.index_to_dates(batch_idx)
        if "mse" in self.config.loss or "mae" in self.config.loss:
            pred = inv_pred
        return {
            "pred": pred.detach().to(torch.float32),
            "true": true.detach().to(torch.float32),
        }

    # def on_predict_epoch_end(self, results):
    #     pass

    # def on_predict_end(self):
    #     pass

    def _process_one_batch(
        self,
        dataset_object,
        batch_x,
        batch_y,
        batch_x_mark,
        batch_y_mark,
        ds_index=None,
    ):
        """Run the model on one batch and prepare outputs and targets.

        Args:
            dataset_object: Object providing ``inverse_transform``.
            batch_x: Model input.
            batch_y: Targets; sliced to the last ``config.pred_len`` steps.
            batch_x_mark: Passed through to the model.
            batch_y_mark: Passed through to the model.
            ds_index: Unused (only referenced by disabled code).

        Returns:
            Tuple ``(outputs, batch_y, inv_outputs)`` where ``inv_outputs`` is
            ``dataset_object.inverse_transform(outputs)``.
        """
        # Decoder input if self.config.dec_in
        dec_inp = None
        # if self.config.dec_in and (
        #     self.config.padding == 0 or self.config.padding == 1
        # ):
        #     # FF: dec_inp = torch.zeros_like(batch_y[:, -self.config.pred_len:, :]).float()
        #     dec_inp = torch.full(
        #         [batch_y.shape[0], self.config.pred_len, batch_y.shape[-1]],
        #         self.config.padding,
        #     ).float()
        #     dec_inp = (
        #         torch.cat([batch_y[:, : self.config.label_len, :], dec_inp], dim=1)
        #         .float()
        #         .to(self.device)
        #     )

        # Encoder - Decoder
        if self.config.output_attention:
            # NOTE(review): self.loss_reg is not refreshed on this path, yet
            # training_step still reads it - confirm this is intended.
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
            outputs, loss_reg = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            self.loss_reg = loss_reg

        # if self.config.inverse_output:
        # "MS" keeps only the last feature column; otherwise keep all columns.
        f_dim = -1 if self.config.features == "MS" else 0
        # if ds_index is None:
        batch_y = batch_y[:, -self.config.pred_len :, f_dim:]
        # print(batch_y.std())
        # batch_y = dataset_object.inverse_transform(batch_y)
        # print(batch_y.std())
        # while 1:pass
        inv_outputs = dataset_object.inverse_transform(outputs)
        return outputs, batch_y, inv_outputs
        # else:
        #     batch_x_raw_dates, batch_y_raw_dates = dataset_object.index_to_dates(
        #         ds_index
        #     )
        #     assert batch_y_raw_dates.shape == batch_y.shape[0:2]
        #     batch_y = batch_y[:, -self.config.pred_len :, f_dim:].to(self.device)
        #     batch_y_raw_dates = batch_y_raw_dates[:, -self.config.pred_len :]
        #     return outputs, batch_y, batch_y_raw_dates

    def configure_optimizers(self):
        """Create the optimizer and, if configured, the LR scheduler.

        ``config.optim`` selects AdamW, Ranger, or (by default) Adam.
        ``config.lradj`` selects the scheduler: "type1" MultiplicativeLR,
        "type2" ReduceLROnPlateau monitoring ``val_loss``, "type3"
        OneCycleLR stepped every training step; any other value disables
        scheduling.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler])``.
        """
        if self.config.optim == "AdamW":
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        elif self.config.optim == "Ranger":
            optimizer = Ranger(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)

        # Learning rate scheduler
        if self.config.lradj == "type1":
            # Halve the learning rate at every scheduler step.
            lmbda = lambda epoch: 0.5
            scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
                optimizer, lr_lambda=lmbda, verbose=True
            )
        elif self.config.lradj == "type2":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=0.5,
                patience=10,
                threshold=0,
                cooldown=0,
                verbose=True,
                min_lr=1e-8,
            )
            scheduler = {
                "scheduler": scheduler,
                "interval": "epoch",  # called after each training epoch
                "monitor": "val_loss",
            }
        elif self.config.lradj == "type3":
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.config.learning_rate,
                steps_per_epoch=len(self.trainer.datamodule.data_train)
                // self.config.batch_size,
                # Would be nicer to use self.trainer.train_dataloader.dataset but there is a pl bug
                epochs=self.config.max_epochs,
            )
            scheduler = {
                "scheduler": scheduler,
                "interval": "step",  # called after each training step
            }
        else:
            return optimizer

        return [optimizer], [scheduler]