File size: 6,395 Bytes
f9d3aeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
from layers.SelfAttention_Family import ProbAttention, AttentionLayer
from layers.Embed import DataEmbedding
class Model(nn.Module):
    """
    Informer with ProbSparse attention in O(L log L) complexity
    Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132

    One model class serving several tasks, selected by ``configs.task_name``:

    * ``'long_term_forecast'`` / ``'short_term_forecast'``: encoder + decoder;
      ``forward`` keeps only the last ``pred_len`` steps of the decoder output.
    * ``'imputation'`` / ``'anomaly_detection'``: encoder followed by a linear
      projection from ``d_model`` to ``c_out``.
    * ``'classification'``: encoder followed by GELU, dropout, masking with
      ``x_mark_enc`` and a linear layer over the flattened sequence.

    For any other task name ``forward`` returns ``None``.
    """

    def __init__(self, configs):
        """Build embeddings, encoder, decoder and the task-specific head.

        Args:
            configs: object exposing the attributes read below (``task_name``,
                ``pred_len``, ``label_len``, ``enc_in``, ``dec_in``, ``c_out``,
                ``d_model``, ``d_ff``, ``n_heads``, ``factor``, ``dropout``,
                ``activation``, ``embed``, ``freq``, ``e_layers``, ``d_layers``,
                ``distil``; plus ``seq_len`` and ``num_class`` for
                classification).
        """
        super().__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.label_len = configs.label_len

        def prob_attention(mask_flag):
            # ProbAttention wrapped in an AttentionLayer; only the first
            # positional flag differs between the call sites.
            return AttentionLayer(
                ProbAttention(
                    mask_flag,
                    configs.factor,
                    attention_dropout=configs.dropout,
                    output_attention=False,
                ),
                configs.d_model,
                configs.n_heads,
            )

        # Embedding (encoder side first, then decoder side)
        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout
        )
        self.dec_embedding = DataEmbedding(
            configs.dec_in, configs.d_model, configs.embed, configs.freq, configs.dropout
        )

        # Encoder
        encoder_layers = [
            EncoderLayer(
                prob_attention(False),
                configs.d_model,
                configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
            )
            for _ in range(configs.e_layers)
        ]
        # Distilling conv layers are only built for forecasting tasks when
        # `configs.distil` is truthy; otherwise the encoder receives None.
        conv_layers = None
        if configs.distil and 'forecast' in configs.task_name:
            conv_layers = [ConvLayer(configs.d_model) for _ in range(configs.e_layers - 1)]
        self.encoder = Encoder(
            encoder_layers,
            conv_layers,
            norm_layer=nn.LayerNorm(configs.d_model),
        )

        # Decoder (constructed for every task, as in the original layout)
        decoder_layers = [
            DecoderLayer(
                prob_attention(True),   # first attention block: flag True
                prob_attention(False),  # second attention block: flag False
                configs.d_model,
                configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
            )
            for _ in range(configs.d_layers)
        ]
        self.decoder = Decoder(
            decoder_layers,
            norm_layer=nn.LayerNorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True),
        )

        # Task-specific heads
        if self.task_name in ('imputation', 'anomaly_detection'):
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        elif self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

    def _seq2seq(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Embed both inputs, run the encoder, then the decoder; return the decoder output."""
        # Both embeddings are computed before the encoder runs.
        source = self.enc_embedding(x_enc, x_mark_enc)
        target = self.dec_embedding(x_dec, x_mark_dec)
        memory, _ = self.encoder(source, attn_mask=None)
        return self.decoder(target, memory, x_mask=None, cross_mask=None)

    def _encode(self, x_enc, x_mark_enc):
        """Embed ``x_enc`` and return the encoder output (second encoder result is dropped)."""
        embedded = self.enc_embedding(x_enc, x_mark_enc)
        encoded, _ = self.encoder(embedded, attn_mask=None)
        return encoded

    def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Return the full decoder output for the given encoder/decoder inputs."""
        return self._seq2seq(x_enc, x_mark_enc, x_dec, x_mark_dec)  # [B, L, D] per the original note

    def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast on a per-series normalised encoder input, then undo the normalisation.

        ``x_enc`` is centred and scaled with statistics taken along dim 1; the
        same (detached) statistics are applied in reverse to the decoder output.
        """
        # Normalization statistics are detached, so no gradient flows through them.
        center = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E per the original note
        centered = x_enc - center
        variance = torch.var(centered, dim=1, keepdim=True, unbiased=False)
        scale = torch.sqrt(variance + 1e-5).detach()  # B x 1 x E per the original note

        prediction = self._seq2seq(centered / scale, x_mark_enc, x_dec, x_mark_dec)
        return prediction * scale + center

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Encode ``x_enc`` (with its marks) and project; ``x_dec``, ``x_mark_dec`` and ``mask`` are unused."""
        return self.projection(self._encode(x_enc, x_mark_enc))

    def anomaly_detection(self, x_enc):
        """Encode ``x_enc`` (no marks are passed to the embedding) and project."""
        return self.projection(self._encode(x_enc, None))

    def classification(self, x_enc, x_mark_enc):
        """Encode ``x_enc`` and map the flattened, masked features to class scores.

        ``x_mark_enc`` is not given to the embedding here; it multiplies the
        encoder features (after ``unsqueeze(-1)``), which the original comment
        describes as zeroing out padding positions.
        """
        # The encoder output carries no final non-linearity, hence the activation.
        features = self.dropout(self.act(self._encode(x_enc, None)))
        features = features * x_mark_enc.unsqueeze(-1)
        flattened = features.reshape(features.shape[0], -1)  # (batch_size, seq_length * d_model)
        return self.projection(flattened)  # (batch_size, num_classes)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``; returns ``None`` for an unrecognised task."""
        task = self.task_name

        if task in ('long_term_forecast', 'short_term_forecast'):
            forecaster = self.long_forecast if task == 'long_term_forecast' else self.short_forecast
            prediction = forecaster(x_enc, x_mark_enc, x_dec, x_mark_dec)
            # Only the trailing `pred_len` steps are returned.
            return prediction[:, -self.pred_len:, :]

        if task == 'imputation':
            return self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)

        if task == 'anomaly_detection':
            return self.anomaly_detection(x_enc)

        if task == 'classification':
            return self.classification(x_enc, x_mark_enc)

        return None
|