# File size: 2,988 Bytes
# 093b0a5
import torch
from torch import nn
from layers.embed import Time2Vec
class LSTM(nn.Module):
    """Single-step LSTM forecaster with an optional Time2Vec time embedding.

    The input sequence is run through a stacked, unidirectional ``nn.LSTM``;
    the final hidden state of the top layer is passed through a small
    feed-forward head (``ReLU -> Linear -> ReLU -> Linear``) to produce one
    prediction step.

    Config attributes read by this class: ``pred_len``, ``label_len``,
    ``d_model``, ``e_layers``, ``enc_in``, ``t_embed``, ``dropout``,
    ``d_ff``, ``c_out`` and, when a time embedding is enabled, ``freq`` and
    ``emb_t2v_app_dim``.
    """

    def __init__(self, config):
        """Build the LSTM, the optional time embedding and the readout head.

        Args:
            config: Object exposing the attributes listed in the class
                docstring. ``t_embed`` must be ``None``, ``"time2vec_app"``
                (embedding is concatenated to the features) or
                ``"time2vec_add"`` (embedding is added to the features).

        Raises:
            AssertionError: If ``pred_len != 1`` or ``label_len != 0``.
            ValueError: If ``t_embed`` is not a supported option, or if
                ``"time2vec_app"`` is selected without a positive
                ``emb_t2v_app_dim``.
        """
        super().__init__()
        # The head emits exactly one step and forward() takes no decoder input.
        assert config.pred_len == 1, "LSTM only supports pred_len == 1"
        assert config.label_len == 0, "LSTM only supports label_len == 0"

        # Hidden dimensions
        self.d_model = config.d_model
        # Number of hidden layers
        self.e_layers = config.e_layers
        # Feature width fed to the LSTM; widened below for "time2vec_app".
        self.enc_in = config.enc_in

        # Time embedding
        self.t_embed = config.t_embed
        if self.t_embed is not None:
            if self.t_embed == "time2vec_app":
                if not (config.emb_t2v_app_dim > 0):
                    raise ValueError("Need to specify a valid emb_t2v_app_dim")
                # The embedding is appended, so the LSTM input grows by its width.
                self.enc_in += config.emb_t2v_app_dim
                self.temporal_embedding = Time2Vec(
                    time_emb_dim=config.emb_t2v_app_dim, freq=config.freq
                )
            elif self.t_embed == "time2vec_add":
                # The embedding is summed with x, so it must match the feature width.
                self.temporal_embedding = Time2Vec(
                    time_emb_dim=self.enc_in, freq=config.freq
                )
            else:
                raise ValueError(
                    "The only options for t_embed with LSTM are null, "
                    "time2vec_app and time2vec_add"
                )

        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(
            input_size=self.enc_in,
            hidden_size=config.d_model,
            num_layers=config.e_layers,
            batch_first=True,
            dropout=config.dropout,
            bidirectional=False,
        )
        self.fc_1 = nn.Linear(config.d_model, config.d_ff)
        self.relu = nn.ReLU()
        # Readout layer
        self.fc = nn.Linear(config.d_ff, config.c_out)

    def forward(self, x, x_mark, *args, **kwargs):
        """Predict one step from an input sequence.

        Args:
            x: Input of shape ``(batch, seq_len, config.enc_in)``.
            x_mark: Time features passed to the time embedding; unused when
                ``t_embed`` is ``None``. Expected layout is defined by
                ``Time2Vec`` (not visible here) -- TODO confirm.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Tensor of shape ``(batch, 1, config.c_out)``.
        """
        if self.t_embed == "time2vec_app":
            x = torch.cat([x, self.temporal_embedding(x_mark)], dim=-1)
        elif self.t_embed == "time2vec_add":
            x = x + self.temporal_embedding(x_mark)

        # Hidden and cell state start from zeros on every call, so nothing is
        # carried over between batches and no detach is required. new_zeros
        # allocates directly with x's dtype and device.
        batch_size = x.size(0)
        h0 = x.new_zeros(self.e_layers, batch_size, self.d_model)
        c0 = x.new_zeros(self.e_layers, batch_size, self.d_model)

        _, (hn, _) = self.lstm(x, (h0, c0))

        # hn[-1] is the top layer's hidden state after the last time step,
        # shape (batch, d_model).
        hidden = self.relu(self.fc_1(self.relu(hn[-1])))
        prediction = self.fc(hidden)
        # Insert the length-1 prediction axis: (batch, c_out) -> (batch, 1, c_out).
        return prediction[:, None]
|