# File size: 4,995 Bytes | 093b0a5
import os
import torch
import torch.nn as nn
from layers.encoder import Encoder, EncoderLayer, ConvLayer
from layers.attn import FullAttention, AttentionLayer, ProbAttention
from layers.embed import DataEmbedding
from utils.masking import QuestionMask
class Stockformer(nn.Module):
    """Encoder-only transformer that produces a single-step prediction.

    The input window is embedded with ``DataEmbedding``, passed through an
    ``Encoder`` built from ``EncoderLayer`` blocks (with ``ConvLayer`` blocks
    in between when ``config.distil`` is set), and projected to
    ``config.c_out`` values by one linear layer.  ``config.final_mode``
    selects how the encoder output is reduced before that projection:

    * ``"mode1"`` -- every encoder position is flattened into one vector.
    * ``"mode2"`` -- only the last encoder position is used.
    * ``"mode3"`` -- an all-zero "question" token is appended to the input,
      attention is masked with ``QuestionMask`` and the last encoder
      position is used.
    """

    # Values of ``config.final_mode`` understood by __init__ and forward.
    _FINAL_MODES = ("mode1", "mode2", "mode3")

    def __init__(self, config):
        """Build the embedding, the encoder stack and the output projection.

        Args:
            config: Object exposing the hyper-parameters as attributes. This
                constructor reads ``pred_len``, ``attn``,
                ``output_attention``, ``seq_len``, ``final_mode``,
                ``enc_in``, ``d_model``, ``t_embed``, ``freq``,
                ``dropout_emb``, ``emb_t2v_app_dim``, ``tok_emb``,
                ``factor``, ``dropout``, ``n_heads``, ``d_ff``,
                ``activation``, ``ln_mode``, ``e_layers``, ``distil``,
                ``c_out``, ``checkpoints`` and ``load_model_path``.

        Raises:
            ValueError: If ``config.pred_len`` is not 1, or if
                ``config.final_mode`` is not one of the supported modes.
        """
        super().__init__()

        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``.
        if config.pred_len != 1:
            raise ValueError("Stockformer needs pred_len to be 1")
        if config.final_mode not in self._FINAL_MODES:
            raise ValueError(f"Invalid final_mode: {config.final_mode}")

        self.pred_len = config.pred_len
        self.attn = config.attn
        self.output_attention = config.output_attention
        self.seq_len = config.seq_len
        self.final_mode = config.final_mode

        # Embedding
        self.enc_embedding = DataEmbedding(
            config.enc_in,
            config.d_model,
            config.t_embed,
            config.freq,
            config.dropout_emb,
            emb_t2v_app_dim=config.emb_t2v_app_dim,
            tok_emb=config.tok_emb,
        )

        # Attention: "prob" selects ProbAttention, anything else FullAttention.
        attn_cls = ProbAttention if config.attn == "prob" else FullAttention
        # First positional argument of the attention class.  It is True only
        # for mode3, the one mode in which forward() builds a QuestionMask.
        # NOTE(review): presumably the mask on/off flag -- confirm in
        # layers.attn.
        use_mask = config.final_mode == "mode3"

        # Encoder
        encoder_layers = [
            EncoderLayer(
                AttentionLayer(
                    attn_cls(
                        use_mask,
                        config.factor,
                        attention_dropout=config.dropout,
                        output_attention=config.output_attention,
                    ),
                    config.d_model,
                    config.n_heads,
                    mix=False,
                ),
                config.d_model,
                config.d_ff,
                dropout=config.dropout,
                activation=config.activation,
                ln_mode=config.ln_mode,
            )
            for _ in range(config.e_layers)
        ]
        # One ConvLayer fewer than encoder layers, and only when distilling.
        conv_layers = (
            [ConvLayer(config.d_model) for _ in range(config.e_layers - 1)]
            if config.distil
            else None
        )
        self.encoder = Encoder(
            encoder_layers,
            conv_layers,
            norm_layer=torch.nn.RMSNorm(config.d_model),
        )

        # Output projection.  mode1 consumes the flattened encoder output, so
        # its input width is d_model * seq_len; mode2/mode3 consume a single
        # position of width d_model.
        # NOTE(review): the forward() comments say the encoder output length
        # changes when distilling, which would not match the mode1 width
        # computed here -- confirm mode1 is never combined with distil.
        if config.final_mode == "mode1":
            final_in_features = config.d_model * config.seq_len
        else:
            final_in_features = config.d_model
        self.final = nn.Linear(final_in_features, config.c_out, bias=True)

        # Load pre-trained model
        if config.load_model_path is not None:
            path = os.path.join(config.checkpoints, config.load_model_path)
            print(f"Loading Model from {path}")
            # map_location="cpu" lets a checkpoint saved on a GPU be restored
            # on a host without that device; load_state_dict copies the
            # values into this module's existing parameters.
            # NOTE(security): torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            self.load_state_dict(torch.load(path, map_location="cpu"))

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
        pre_train=False,
    ):
        """Run the encoder and project its output to the prediction.

        Args:
            x_enc: Input window, a 3-D tensor
                ``(batch, seq_len, enc_in)`` whose second dimension must
                equal ``self.seq_len``.
            x_mark_enc: Time/date features that accompany ``x_enc``; passed
                to the embedding together with ``x_enc``.
            x_dec: Not read by this method.
            x_mark_dec: Only read in mode3, where it is concatenated to
                ``x_mark_enc`` along dimension 1 to accompany the appended
                question token.
            enc_self_mask: Attention mask handed to the encoder.  Must be
                ``None`` in mode3, which builds its own ``QuestionMask``.
            dec_self_mask: Not read by this method.
            dec_enc_mask: Not read by this method.
            pre_train: Not read by this method.

        Returns:
            The prediction with a dummy middle dimension,
            ``(batch, 1, c_out)``.  When ``self.output_attention`` is set,
            a tuple of that tensor and the attention value returned by the
            encoder.

        Raises:
            ValueError: If ``x_enc`` is not 3-D, its sequence length differs
                from ``self.seq_len``, a mask is supplied in mode3, or
                ``self.final_mode`` is not a supported mode.
        """
        # x_enc is (batch_size / num gpus, seq_len, enc_in)
        # x_mark_enc is (batch_size / num gpus, seq_len, date representation)
        if x_enc.dim() != 3:
            raise ValueError(
                f"x_enc must be 3-D, got shape {tuple(x_enc.shape)}"
            )
        if x_enc.shape[1] != self.seq_len:
            raise ValueError(
                f"x_enc has sequence length {x_enc.shape[1]}, "
                f"expected {self.seq_len}"
            )

        if self.final_mode == "mode3":
            if enc_self_mask is not None:
                raise ValueError("enc_self_mask must be None in mode3")
            # This gives the encoder a question input as the last token.
            # TODO: Maybe this should be initialized differently, like to the
            # mean of x_enc, random, or the mean of the dataset?
            # new_zeros allocates directly with x_enc's dtype and device.
            question = x_enc.new_zeros((x_enc.shape[0], 1, x_enc.shape[2]))
            x_enc = torch.cat([x_enc, question], 1)
            # NOTE(review): assumes x_mark_dec holds the time features of the
            # single appended step -- confirm against the caller.
            x_mark_enc = torch.cat([x_mark_enc, x_mark_dec], 1)
            enc_self_mask = QuestionMask(
                x_enc.shape[0], x_enc.shape[1], device=x_enc.device
            )

        # emb_out is (batch_size / num gpus, seq_len, d_model)
        emb_out = self.enc_embedding(x_enc, x_mark_enc)
        # enc_out is (batch_size / num gpus, seq_len, d_model), but seq_len
        # will change if distil is enabled.
        enc_out, attns = self.encoder(emb_out, attn_mask=enc_self_mask)

        if self.final_mode == "mode1":
            out = self.final(enc_out.flatten(start_dim=1))
        elif self.final_mode in ("mode2", "mode3"):
            out = self.final(enc_out[:, -1, :])
        else:
            raise ValueError(
                f"Forward missing valid final mode {self.final_mode}"
            )

        # unsqueeze(1) just adds a dummy dimension: (batch_size, 1, c_out)
        out = out.unsqueeze(1)
        if self.output_attention:
            return out, attns
        return out
|