| import os |
| import torch |
| import torch.nn as nn |
|
|
| from layers.encoder import Encoder, EncoderLayer, ConvLayer |
| from layers.attn import FullAttention, AttentionLayer, ProbAttention |
| from layers.embed import DataEmbedding |
| from utils.masking import QuestionMask |
| from .wavelet import WaveletFront |
|
|
|
|
class Stockformer(nn.Module):
    """Encoder-only one-step forecaster fed by a wavelet front end.

    The forward pass runs ``WaveletFront`` on the input, embeds its output
    with ``DataEmbedding``, passes the embedding through an ``Encoder`` stack
    and maps the encoder output to ``config.c_out`` values with one linear
    layer. ``pred_len`` must be 1.

    ``final_mode`` selects how the encoder output is reduced:

    * ``"mode1"``: flatten every position and project the flattened vector.
    * ``"mode2"``: project the last position only.
    * ``"mode3"``: append one all-zero position to the front-end output,
      build a ``QuestionMask`` for the encoder and project the last position.
    """

    # Weights applied to the regularisation terms returned by the wavelet
    # front end when ``forward`` folds them into a single scalar. Insertion
    # order is the summation order.
    REG_WEIGHTS = {
        "L_low": 100,
        "L_high": 100,
        "L_overlap": 100,
        "L_parseval": 1e-2,
        "L_shape": 10,
    }

    def __init__(self, config):
        """Build the front end, embedding, encoder and output head.

        Args:
            config: Object exposing the hyper-parameters read below as
                attributes (``pred_len``, ``seq_len``, ``d_model``, ...).
                If ``config.load_model_path`` is not ``None`` the state dict
                stored at ``config.checkpoints / config.load_model_path`` is
                loaded into the freshly built model.

        Raises:
            AssertionError: If ``config.pred_len`` is not 1.
            ValueError: If ``config.final_mode`` is not ``"mode1"``,
                ``"mode2"`` or ``"mode3"``.
        """
        super().__init__()
        self.pred_len = config.pred_len
        assert self.pred_len == 1, "Stockformer needs pred_len to be 1"
        self.attn = config.attn
        self.output_attention = config.output_attention
        self.seq_len = config.seq_len
        self.final_mode = config.final_mode

        # Width shared by the front-end output and the embedding input.
        # NOTE(review): presumably the remaining ``emb_t2v_app_dim`` channels
        # of d_model are appended inside DataEmbedding -- confirm there.
        value_dim = config.d_model - config.emb_t2v_app_dim

        self.wave_model = WaveletFront(
            in_channels=config.enc_in,
            d_model=value_dim,
            kernel_size=31,
            n_fft=128,
        )

        self.enc_embedding = DataEmbedding(
            value_dim,
            config.d_model,
            config.t_embed,
            config.freq,
            config.dropout_emb,
            emb_t2v_app_dim=config.emb_t2v_app_dim,
            tok_emb=config.tok_emb,
        )

        attn_cls = ProbAttention if config.attn == "prob" else FullAttention
        # First positional argument of the attention class is switched on
        # only for mode3 (presumably the mask flag -- confirm in layers.attn).
        use_mask = config.final_mode == "mode3"

        encoder_layers = [
            EncoderLayer(
                AttentionLayer(
                    attn_cls(
                        use_mask,
                        config.factor,
                        attention_dropout=config.dropout,
                        output_attention=config.output_attention,
                    ),
                    config.d_model,
                    config.n_heads,
                    mix=False,
                ),
                config.d_model,
                config.d_ff,
                dropout=config.dropout,
                activation=config.activation,
                ln_mode=config.ln_mode,
            )
            for _ in range(config.e_layers)
        ]
        # One ConvLayer between consecutive encoder layers when distilling.
        conv_layers = (
            [ConvLayer(config.d_model) for _ in range(config.e_layers - 1)]
            if config.distil
            else None
        )
        self.encoder = Encoder(
            encoder_layers,
            conv_layers,
            norm_layer=torch.nn.RMSNorm(config.d_model),
        )

        if config.final_mode == "mode1":
            # Head consumes all seq_len positions flattened together.
            self.final = nn.Linear(
                config.d_model * config.seq_len, config.c_out, bias=True
            )
        elif config.final_mode in ("mode2", "mode3"):
            # Head consumes a single position.
            self.final = nn.Linear(config.d_model, config.c_out, bias=True)
        else:
            raise ValueError(f"Invalid final_mode: {config.final_mode}")

        if config.load_model_path is not None:
            path = os.path.join(config.checkpoints, config.load_model_path)
            print(f"Loading Model from {path}")
            # map_location="cpu" lets a checkpoint saved on an accelerator be
            # restored on a host without one; load_state_dict copies the
            # values into the existing parameters.
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints, or consider passing weights_only=True.
            self.load_state_dict(torch.load(path, map_location="cpu"))

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
        pre_train=False,
    ):
        """Run the model on one batch.

        Args:
            x_enc: 3-D tensor whose second dimension equals ``seq_len``; it is
                permuted to put that dimension last before the front end.
            x_mark_enc: Passed to the embedding together with the front-end
                output.
            x_dec: Unused.
            x_mark_dec: Concatenated after ``x_mark_enc`` along dim 1 in
                ``"mode3"``; otherwise unused.
            enc_self_mask: Attention mask handed to the encoder. Must be
                ``None`` in ``"mode3"``, where a ``QuestionMask`` is built.
            dec_self_mask: Unused.
            dec_enc_mask: Unused.
            pre_train: Unused.

        Returns:
            ``(out, attns)`` when ``output_attention`` is set, otherwise
            ``(out, loss_reg)``. ``out`` is the head output with a singleton
            dimension inserted at index 1; ``loss_reg`` is the weighted sum of
            the front-end regularisation terms.
            NOTE(review): the regularisation loss is computed but not returned
            when ``output_attention`` is set -- confirm this is intended.

        Raises:
            AssertionError: On a wrongly shaped ``x_enc``, on a caller-supplied
                ``enc_self_mask`` in ``"mode3"``, or on an unknown
                ``final_mode``.
        """
        assert len(x_enc.shape) == 3
        assert x_enc.shape[1] == self.seq_len

        x_enc, reg, _ = self.wave_model(x_enc.permute(0, 2, 1))

        # Same terms, weights and left-to-right order as REG_WEIGHTS lists.
        loss_reg = sum(
            weight * reg[term] for term, weight in self.REG_WEIGHTS.items()
        )

        if self.final_mode == "mode3":
            # Append one all-zero position (same dtype/device as x_enc) and
            # extend the time marks to match.
            question = x_enc.new_zeros((x_enc.shape[0], 1, x_enc.shape[2]))
            x_enc = torch.cat([x_enc, question], 1)
            x_mark_enc = torch.cat([x_mark_enc, x_mark_dec], 1)
            assert enc_self_mask is None
            enc_self_mask = QuestionMask(
                x_enc.shape[0], x_enc.shape[1], device=x_enc.device
            )

        emb_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(emb_out, attn_mask=enc_self_mask)

        if self.final_mode == "mode1":
            out = self.final(enc_out.flatten(start_dim=1))
        elif self.final_mode in ("mode2", "mode3"):
            out = self.final(enc_out[:, -1, :])
        else:
            # Explicit raise: a bare ``assert False`` is stripped under -O and
            # would fall through to an unbound ``out``.
            raise AssertionError(
                f"Forward missing valid final mode {self.final_mode}"
            )

        out = out[:, None, :]
        if self.output_attention:
            return out, attns
        return out, loss_reg
|
|