File size: 5,784 Bytes
093b0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
import torch.nn as nn

from layers.encoder import Encoder, EncoderLayer, ConvLayer
from layers.attn import FullAttention, AttentionLayer, ProbAttention
from layers.embed import DataEmbedding
from utils.masking import QuestionMask
from .wavelet import WaveletFront


class Stockformer(nn.Module):
    """Encoder-only transformer that predicts a single step from a wavelet front-end.

    The raw input series goes through ``WaveletFront``; its output is embedded
    with ``DataEmbedding``, run through an ``Encoder`` stack and projected to
    ``c_out`` values by ``self.final``. How the encoder output is reduced
    before that projection depends on ``config.final_mode``:

    * ``"mode1"`` - flatten every encoder position (``d_model * seq_len`` features).
    * ``"mode2"`` - use only the last encoder position.
    * ``"mode3"`` - append a zero "question" token (time marks taken from
      ``x_mark_dec``), attend with a ``QuestionMask`` and use that token's output.

    Only ``pred_len == 1`` is supported.
    """

    # Accepted values of ``config.final_mode``.
    _FINAL_MODES = ("mode1", "mode2", "mode3")

    # Default weights for the regularisation terms returned by WaveletFront,
    # keyed by the entry name in its ``reg`` dict. Dict order is the order in
    # which the weighted terms are summed.
    DEFAULT_REG_WEIGHTS = {
        "L_low": 100,
        "L_high": 100,
        "L_overlap": 100,
        "L_parseval": 1e-2,
        "L_shape": 10,
    }

    # Optional config attribute that overrides each entry of DEFAULT_REG_WEIGHTS.
    _REG_WEIGHT_ATTRS = {
        "L_low": "lambda_low",
        "L_high": "lambda_high",
        "L_overlap": "lambda_overlap",
        "L_parseval": "lambda_parseval",
        "L_shape": "lambda_shape",
    }

    @staticmethod
    def _config_value(config, name, default):
        """Return ``config.<name>``, or *default* when it is missing or ``None``."""
        value = getattr(config, name, None)
        return default if value is None else value

    def __init__(self, config):
        """Build the model from an attribute-style ``config`` object.

        Required attributes: ``pred_len`` (must be 1), ``attn``,
        ``output_attention``, ``seq_len``, ``final_mode``, ``enc_in``,
        ``d_model``, ``emb_t2v_app_dim``, ``t_embed``, ``freq``,
        ``dropout_emb``, ``tok_emb``, ``factor``, ``dropout``, ``n_heads``,
        ``d_ff``, ``activation``, ``ln_mode``, ``e_layers``, ``distil``,
        ``c_out`` and ``load_model_path`` (plus ``checkpoints`` when
        ``load_model_path`` is not ``None``).

        Optional attributes (missing or ``None`` -> default):
        ``wavelet_kernel_size`` (31), ``wavelet_n_fft`` (128), and the
        regularisation weights ``lambda_low``, ``lambda_high``,
        ``lambda_overlap``, ``lambda_parseval``, ``lambda_shape``
        (defaults in ``DEFAULT_REG_WEIGHTS``).

        Raises:
            AssertionError: if ``config.pred_len != 1``.
            ValueError: if ``config.final_mode`` is not a supported mode.
        """
        super().__init__()
        self.pred_len = config.pred_len
        assert self.pred_len == 1, "Stockformer needs pred_len to be 1"
        if config.final_mode not in self._FINAL_MODES:
            # Fail before building any submodule.
            raise ValueError(f"Invalid final_mode: {config.final_mode}")

        self.attn = config.attn
        self.output_attention = config.output_attention
        self.seq_len = config.seq_len
        self.final_mode = config.final_mode
        self.reg_weights = {
            key: self._config_value(config, attr, self.DEFAULT_REG_WEIGHTS[key])
            for key, attr in self._REG_WEIGHT_ATTRS.items()
        }

        # Width of the wavelet features; the embedding below is built with the
        # same width. NOTE(review): presumably DataEmbedding appends the
        # remaining emb_t2v_app_dim channels to reach d_model - confirm.
        front_dim = config.d_model - config.emb_t2v_app_dim
        self.wave_model = WaveletFront(
            in_channels=config.enc_in,
            d_model=front_dim,
            kernel_size=self._config_value(config, "wavelet_kernel_size", 31),
            n_fft=self._config_value(config, "wavelet_n_fft", 128),
        )

        self.enc_embedding = DataEmbedding(
            front_dim,
            config.d_model,
            config.t_embed,
            config.freq,
            config.dropout_emb,
            emb_t2v_app_dim=config.emb_t2v_app_dim,
            tok_emb=config.tok_emb,
        )

        attn_cls = ProbAttention if config.attn == "prob" else FullAttention
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        # First flag is enabled only in mode3 (presumably the
                        # mask flag, so the QuestionMask from forward is used).
                        attn_cls(
                            config.final_mode == "mode3",
                            config.factor,
                            attention_dropout=config.dropout,
                            output_attention=config.output_attention,
                        ),
                        config.d_model,
                        config.n_heads,
                        mix=False,
                    ),
                    config.d_model,
                    config.d_ff,
                    dropout=config.dropout,
                    activation=config.activation,
                    ln_mode=config.ln_mode,
                )
                for _ in range(config.e_layers)
            ],
            [ConvLayer(config.d_model) for _ in range(config.e_layers - 1)]
            if config.distil
            else None,
            norm_layer=torch.nn.RMSNorm(config.d_model),
        )

        if config.final_mode == "mode1":
            # NOTE(review): sized for seq_len encoder positions; if the distil
            # ConvLayers shorten the sequence this head will not match - confirm.
            self.final = nn.Linear(
                config.d_model * config.seq_len, config.c_out, bias=True
            )
        else:  # "mode2" / "mode3": project a single encoder position.
            self.final = nn.Linear(config.d_model, config.c_out, bias=True)

        if config.load_model_path is not None:
            path = os.path.join(config.checkpoints, config.load_model_path)
            print(f"Loading Model from {path}")
            # map_location="cpu" lets checkpoints saved from GPU tensors load on
            # CPU-only hosts; load_state_dict then copies into the parameters.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            self.load_state_dict(torch.load(path, map_location="cpu"))

    def _regularization_loss(self, reg):
        """Return the weighted sum of WaveletFront's regularisation terms.

        Each term ``reg[key]`` is multiplied by ``self.reg_weights[key]`` and
        summed in ``self.reg_weights`` order. Raises ``KeyError`` if ``reg``
        lacks a weighted key.
        """
        loss = 0
        for key, weight in self.reg_weights.items():
            loss = loss + weight * reg[key]
        return loss

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
        pre_train=False,
    ):
        """Predict one step for a batch.

        Args:
            x_enc: 3-D input with ``x_enc.shape[1] == seq_len``; expected to be
                (batch, seq_len, enc_in).
            x_mark_enc: Time features for ``x_enc``; expected to be
                (batch, seq_len, n_time_features).
            x_dec: Unused; kept for interface compatibility.
            x_mark_dec: Time features for the question token. Only read in
                ``"mode3"``, where it is concatenated after ``x_mark_enc``.
            enc_self_mask: Encoder attention mask. Must be ``None`` in
                ``"mode3"``, where a ``QuestionMask`` is built instead.
            dec_self_mask, dec_enc_mask, pre_train: Unused.

        Returns:
            ``(out, attns)`` when ``output_attention`` is set, otherwise
            ``(out, loss_reg)``. ``out`` is (batch, 1, c_out) and ``loss_reg``
            is the weighted wavelet regularisation loss.
        """
        assert len(x_enc.shape) == 3
        assert x_enc.shape[1] == self.seq_len

        # The wavelet front-end receives (batch, enc_in, seq_len) and returns
        # features, a dict of regularisation terms, and an unused third value.
        x_enc, reg, _ = self.wave_model(x_enc.permute(0, 2, 1))
        loss_reg = self._regularization_loss(reg)

        if self.final_mode == "mode3":
            # Append a zero "question" token as the last encoder position.
            # TODO: Maybe this should be initialized differently, like to the mean of x_enc, random, mean of dataset?
            question = x_enc.new_zeros((x_enc.shape[0], 1, x_enc.shape[2]))
            x_enc = torch.cat([x_enc, question], 1)
            x_mark_enc = torch.cat([x_mark_enc, x_mark_dec], 1)
            assert enc_self_mask is None
            enc_self_mask = QuestionMask(
                x_enc.shape[0], x_enc.shape[1], device=x_enc.device
            )

        emb_out = self.enc_embedding(x_enc, x_mark_enc)
        # The sequence length of enc_out may change if distil is enabled.
        enc_out, attns = self.encoder(emb_out, attn_mask=enc_self_mask)

        if self.final_mode == "mode1":
            out = self.final(enc_out.flatten(start_dim=1))
        elif self.final_mode in ("mode2", "mode3"):
            out = self.final(enc_out[:, -1, :])
        else:
            assert False, f"Forward missing valid final mode {self.final_mode}"

        out = out[:, None, :]  # add a dummy step axis: (batch_size, 1, c_out)
        if self.output_attention:
            # NOTE(review): loss_reg is not returned on this path.
            return out, attns
        return out, loss_reg