import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.linear import Linear

from layers.SelfAttention_Family import AttentionLayer, FullAttention
from layers.Embed import DataEmbedding


def get_mask(input_size, window_size, inner_size):
    """Build the attention mask of PAM-Naive (pyramidal attention)."""
    # Number of nodes at each scale of the pyramid: the input sequence,
    # then one coarser layer per entry of window_size.
    all_size = [input_size]
    for i in range(len(window_size)):
        layer_size = math.floor(all_size[i] / window_size[i])
        all_size.append(layer_size)

    seq_length = sum(all_size)
    mask = torch.zeros(seq_length, seq_length)

    # Intra-scale connections: each node attends to its nearest neighbours
    # within the same scale, inner_size nodes wide in total.
    inner_window = inner_size // 2
    for layer_idx in range(len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = max(i - inner_window, start)
            right_side = min(i + inner_window + 1, start + all_size[layer_idx])
            mask[i, left_side:right_side] = 1

    # Inter-scale connections: each coarse node is linked (symmetrically)
    # to the window_size finer nodes it summarizes.
    for layer_idx in range(1, len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = (start - all_size[layer_idx - 1]) + \
                (i - start) * window_size[layer_idx - 1]
            if i == (start + all_size[layer_idx] - 1):
                # The last coarse node absorbs any leftover finer nodes.
                right_side = start
            else:
                right_side = (start - all_size[layer_idx - 1]) + \
                    (i - start + 1) * window_size[layer_idx - 1]
            mask[i, left_side:right_side] = 1
            mask[left_side:right_side, i] = 1

    # Invert: True marks positions that are *not* allowed to attend.
    mask = (1 - mask).bool()

    return mask, all_size
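

# Illustrative sketch (sizes assumed for this example, not taken from any
# config): with input_size=8, window_size=[2, 2] and inner_size=3, the
# pyramid scales hold [8, 4, 2] nodes, so the mask covers 8 + 4 + 2 = 14
# positions:
#
#   mask, all_size = get_mask(8, [2, 2], 3)
#   # all_size == [8, 4, 2]; mask.shape == torch.Size([14, 14])
#
# True entries mark pairs that are blocked from attending to each other.
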
def refer_points(all_sizes, window_size):
    """Gather features from PAM's pyramid sequences.

    For every position of the finest scale, record the index of its
    ancestor node at each coarser scale.
    """
    input_size = all_sizes[0]
    indexes = torch.zeros(input_size, len(all_sizes))

    for i in range(input_size):
        indexes[i][0] = i
        former_index = i
        for j in range(1, len(all_sizes)):
            start = sum(all_sizes[:j])
            inner_layer_idx = former_index - (start - all_sizes[j - 1])
            former_index = start + \
                min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
            indexes[i][j] = former_index

    # Shape (1, input_size, num_scales, 1) so the indices can be expanded
    # over the batch and feature dimensions.
    indexes = indexes.unsqueeze(0).unsqueeze(3)

    return indexes.long()
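

# Continuing the sketch above (assumed sizes): refer_points([8, 4, 2], [2, 2])
# returns a (1, 8, 3, 1) index tensor; e.g. fine-scale position 5 maps to
# node 10 at the middle scale and node 13 at the coarsest one, i.e.
# indexes[0, 5, :, 0] == tensor([5, 10, 13]).
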
class RegularMask:
    """Adapter that exposes a precomputed attention mask through the
    `.mask` property expected by `FullAttention`."""

    def __init__(self, mask):
        self._mask = mask.unsqueeze(1)

    @property
    def mask(self):
        return self._mask


class EncoderLayer(nn.Module):
    """Compose with two sub-layers: self-attention and a position-wise
    feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
        super(EncoderLayer, self).__init__()
        self.slf_attn = AttentionLayer(
            FullAttention(mask_flag=True, factor=0,
                          attention_dropout=dropout, output_attention=False),
            d_model, n_head)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before)

    def forward(self, enc_input, slf_attn_mask=None):
        attn_mask = RegularMask(slf_attn_mask)
        enc_output, _ = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output


class Encoder(nn.Module):
    """An encoder model with a pyramidal self-attention mechanism."""

    def __init__(self, configs, window_size, inner_size):
        super().__init__()

        d_bottleneck = configs.d_model // 4

        self.mask, self.all_size = get_mask(
            configs.seq_len, window_size, inner_size)
        self.indexes = refer_points(self.all_size, window_size)
        self.layers = nn.ModuleList([
            EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout,
                         normalize_before=False) for _ in range(configs.e_layers)
        ])

        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.dropout)
        self.conv_layers = Bottleneck_Construct(
            configs.d_model, window_size, d_bottleneck)

    def forward(self, x_enc, x_mark_enc):
        seq_enc = self.enc_embedding(x_enc, x_mark_enc)

        mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
        # Build the multi-scale sequence, then run pyramidal attention on it.
        seq_enc = self.conv_layers(seq_enc)

        for i in range(len(self.layers)):
            seq_enc = self.layers[i](seq_enc, mask)

        # For each input position, gather its node at every scale and
        # concatenate them along the feature dimension.
        indexes = self.indexes.repeat(seq_enc.size(
            0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
        indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
        all_enc = torch.gather(seq_enc, 1, indexes)
        seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)

        return seq_enc


class ConvLayer(nn.Module):
    """Strided 1-D convolution that downsamples a sequence by window_size."""

    def __init__(self, c_in, window_size):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=window_size,
                                  stride=window_size)
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()

    def forward(self, x):
        x = self.downConv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x


class Bottleneck_Construct(nn.Module):
    """Bottleneck convolution CSCM (coarser-scale construction module)."""

    def __init__(self, d_model, window_size, d_inner):
        super(Bottleneck_Construct, self).__init__()
        if not isinstance(window_size, list):
            self.conv_layers = nn.ModuleList([
                ConvLayer(d_inner, window_size),
                ConvLayer(d_inner, window_size),
                ConvLayer(d_inner, window_size)
            ])
        else:
            self.conv_layers = []
            for i in range(len(window_size)):
                self.conv_layers.append(ConvLayer(d_inner, window_size[i]))
            self.conv_layers = nn.ModuleList(self.conv_layers)
        self.up = Linear(d_inner, d_model)
        self.down = Linear(d_model, d_inner)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, enc_input):
        # Project down to the bottleneck width and move channels first for
        # the 1-D convolutions: (B, L, d_model) -> (B, d_inner, L).
        temp_input = self.down(enc_input).permute(0, 2, 1)
        all_inputs = []
        for i in range(len(self.conv_layers)):
            temp_input = self.conv_layers[i](temp_input)
            all_inputs.append(temp_input)

        # Concatenate the coarser scales, project back up to d_model, and
        # prepend the original fine-scale sequence.
        all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)
        all_inputs = self.up(all_inputs)
        all_inputs = torch.cat([enc_input, all_inputs], dim=1)

        all_inputs = self.norm(all_inputs)
        return all_inputs
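

# Shape sketch (assumed sizes): with d_model=16, window_size=[2, 2] and input
# of shape (B, 8, 16), the conv stack produces scales of lengths 4 and 2, so
# the module returns a tensor of shape (B, 8 + 4 + 2, 16) = (B, 14, 16).
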
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward neural network."""

    def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
        super().__init__()

        self.normalize_before = normalize_before

        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)

        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        if self.normalize_before:
            x = self.layer_norm(x)

        x = F.gelu(self.w_1(x))
        x = self.dropout(x)
        x = self.w_2(x)
        x = self.dropout(x)
        x = x + residual

        if not self.normalize_before:
            x = self.layer_norm(x)
        return x
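

if __name__ == "__main__":
    # Minimal smoke test over the helpers above. The sizes are illustrative
    # assumptions; running this file standalone also assumes the repository
    # root is on sys.path so the layers.* imports at the top resolve.
    mask, all_size = get_mask(input_size=8, window_size=[2, 2], inner_size=3)
    assert all_size == [8, 4, 2]
    assert mask.shape == (14, 14)

    indexes = refer_points(all_size, [2, 2])
    assert indexes.shape == (1, 8, 3, 1)

    cscm = Bottleneck_Construct(d_model=16, window_size=[2, 2], d_inner=4)
    out = cscm(torch.randn(2, 8, 16))
    assert out.shape == (2, 14, 16)
    print("smoke tests passed")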