import os
import torch
import torch.nn as nn

from layers.encoder import Encoder, EncoderLayer, ConvLayer
from layers.attn import FullAttention, AttentionLayer, ProbAttention
from layers.embed import DataEmbedding
from utils.masking import QuestionMask
from .wavelet import WaveletFront


class Stockformer(nn.Module):
    """Encoder-only model that emits a single prediction step.

    Pipeline (see ``forward``): ``WaveletFront`` -> ``DataEmbedding`` ->
    ``Encoder`` -> linear head. ``config.pred_len`` must be 1.

    ``config.final_mode`` selects the head:

    * ``"mode1"``: linear layer over the flattened encoder output
      (``d_model * seq_len`` input features).
    * ``"mode2"``: linear layer over the last encoder token.
    * ``"mode3"``: same head as ``"mode2"``, but ``forward`` appends an
      all-zero "question" token to the sequence and attends with a
      ``QuestionMask``.
    """

    # (key in the ``reg`` dict returned by ``self.wave_model``, weight) pairs.
    # The order is the summation order used by ``_regularization_loss``.
    _REG_WEIGHTS = (
        ("L_low", 100),
        ("L_high", 100),
        ("L_overlap", 100),
        ("L_parseval", 1e-2),
        ("L_shape", 10),
    )

    def __init__(self, config):
        """Build the model from ``config``.

        Attributes read from ``config``: ``pred_len``, ``attn``,
        ``output_attention``, ``seq_len``, ``final_mode``, ``enc_in``,
        ``d_model``, ``emb_t2v_app_dim``, ``t_embed``, ``freq``,
        ``dropout_emb``, ``tok_emb``, ``factor``, ``dropout``, ``n_heads``,
        ``d_ff``, ``activation``, ``ln_mode``, ``e_layers``, ``distil``,
        ``c_out``, ``load_model_path`` and, when a checkpoint is loaded,
        ``checkpoints``.

        Raises:
            AssertionError: if ``config.pred_len`` is not 1.
            ValueError: if ``config.final_mode`` is not one of ``"mode1"``,
                ``"mode2"``, ``"mode3"``.
        """
        super().__init__()

        self.pred_len = config.pred_len
        assert self.pred_len == 1, "Stockformer needs pred_len to be 1"
        self.attn = config.attn
        self.output_attention = config.output_attention
        self.seq_len = config.seq_len
        self.final_mode = config.final_mode

        # Width produced by the wavelet front-end and consumed by the
        # embedding as its input width (d_model minus emb_t2v_app_dim).
        front_dim = config.d_model - config.emb_t2v_app_dim
        self.wave_model = WaveletFront(
            in_channels=config.enc_in,
            d_model=front_dim,
            kernel_size=31,
            n_fft=128,
        )

        # Embedding
        self.enc_embedding = DataEmbedding(
            front_dim,
            config.d_model,
            config.t_embed,
            config.freq,
            config.dropout_emb,
            emb_t2v_app_dim=config.emb_t2v_app_dim,
            tok_emb=config.tok_emb,
        )

        # Attention
        attn_cls = ProbAttention if config.attn == "prob" else FullAttention
        # First positional argument of the attention class; only mode3 sets
        # it (presumably the mask flag -- confirm against layers.attn).
        attn_flag = config.final_mode == "mode3"

        # Encoder
        attn_layers = [
            EncoderLayer(
                AttentionLayer(
                    attn_cls(
                        attn_flag,
                        config.factor,
                        attention_dropout=config.dropout,
                        output_attention=config.output_attention,
                    ),
                    config.d_model,
                    config.n_heads,
                    mix=False,
                ),
                config.d_model,
                config.d_ff,
                dropout=config.dropout,
                activation=config.activation,
                ln_mode=config.ln_mode,
            )
            for _ in range(config.e_layers)
        ]
        if config.distil:
            conv_layers = [
                ConvLayer(config.d_model) for _ in range(config.e_layers - 1)
            ]
        else:
            conv_layers = None
        self.encoder = Encoder(
            attn_layers,
            conv_layers,
            norm_layer=torch.nn.RMSNorm(config.d_model),
        )

        # Output head
        if config.final_mode == "mode1":
            self.final = nn.Linear(
                config.d_model * config.seq_len, config.c_out, bias=True
            )
        elif config.final_mode in ("mode2", "mode3"):
            self.final = nn.Linear(config.d_model, config.c_out, bias=True)
        else:
            raise ValueError(f"Invalid final_mode: {config.final_mode}")

        nn.init.xavier_normal_(
            self.final.weight, gain=nn.init.calculate_gain("tanh")
        )

        # Load pre-trained model
        if config.load_model_path is not None:
            path = os.path.join(config.checkpoints, config.load_model_path)
            print(f"Loading Model from {path}")
            # map_location="cpu" lets a checkpoint written on a GPU be
            # restored on any host; load_state_dict then copies the values
            # into this module's existing parameters.
            self.load_state_dict(torch.load(path, map_location="cpu"))

    def _regularization_loss(self, reg):
        """Return the weighted sum of the ``reg`` entries in ``_REG_WEIGHTS``."""
        total = None
        for key, weight in self._REG_WEIGHTS:
            term = weight * reg[key]
            total = term if total is None else total + term
        return total

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
        pre_train=False,
    ):
        """Run the model on one batch.

        Args:
            x_enc: 3-D tensor whose second dimension equals ``self.seq_len``;
                (batch_size / num gpus, seq_len, enc_in).
            x_mark_enc: time marks for ``x_enc``;
                (batch_size / num gpus, seq_len, date-representation).
            x_dec: unused.
            x_mark_dec: only used in ``"mode3"``, where it is concatenated
                to ``x_mark_enc`` along dim 1 as the marks of the question
                token (presumably one step long -- confirm with the caller).
            enc_self_mask: attention mask passed to the encoder; must be
                ``None`` in ``"mode3"`` because a ``QuestionMask`` is built
                there.
            dec_self_mask: unused.
            dec_enc_mask: unused.
            pre_train: unused.

        Returns:
            ``(out, attns)`` when ``self.output_attention`` is truthy,
            otherwise ``(out, loss_reg)``. ``out`` is the head output with a
            dummy dimension inserted at position 1, i.e.
            (batch_size, 1, c_out). ``loss_reg`` is the weighted wavelet
            regularisation term; it is not returned in the attention case.

        Raises:
            AssertionError: on an unexpected input shape, a caller-supplied
                ``enc_self_mask`` in ``"mode3"``, or an invalid
                ``self.final_mode``.
        """
        assert len(x_enc.shape) == 3
        assert x_enc.shape[1] == self.seq_len

        # The front-end receives the input with dims 1 and 2 swapped.
        x_enc, reg, _ = self.wave_model(x_enc.permute(0, 2, 1))
        loss_reg = self._regularization_loss(reg)

        if self.final_mode == "mode3":
            # This gives the encoder a question input as the last token
            # TODO: Maybe this should be initialized differently, like to
            # the mean of x_enc, random, mean of dataset?
            # new_zeros allocates directly with x_enc's dtype and device.
            question = x_enc.new_zeros((x_enc.shape[0], 1, x_enc.shape[2]))
            x_enc = torch.cat([x_enc, question], 1)
            x_mark_enc = torch.cat([x_mark_enc, x_mark_dec], 1)
            assert enc_self_mask is None
            enc_self_mask = QuestionMask(
                x_enc.shape[0], x_enc.shape[1], device=x_enc.device
            )

        # emb_out is (batch_size / num gpus, seq_len, d_model)
        emb_out = self.enc_embedding(x_enc, x_mark_enc)
        # enc_out is (batch_size / num gpus, seq_len, d_model) but seq_len
        # will change if distil
        enc_out, attns = self.encoder(emb_out, attn_mask=enc_self_mask)

        if self.final_mode == "mode1":
            out = self.final(enc_out.flatten(start_dim=1))
        elif self.final_mode in ("mode2", "mode3"):
            out = self.final(enc_out[:, -1, :])
        else:
            # Explicit raise (not ``assert False``) so the check survives -O.
            raise AssertionError(
                f"Forward missing valid final mode {self.final_mode}"
            )

        # The None below is just adding a dummy dimension
        out = out[:, None, :]  # (batch_size, 1, c_out)
        if self.output_attention:
            return out, attns
        return out, loss_reg