import math
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

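# SCINet (Sample Convolution and Interaction Network) for time-series forecasting:
# the encoder recursively splits the input into even/odd sub-sequences, lets the two
# halves exchange information through convolutional interactor blocks, re-interleaves
# them, and finally projects the sequence onto the forecast horizon.
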
class Splitting(nn.Module):
    def __init__(self):
        super(Splitting, self).__init__()

    def even(self, x):
        return x[:, ::2, :]

    def odd(self, x):
        return x[:, 1::2, :]

    def forward(self, x):
        '''Returns the even and odd sub-sequences.'''
        return (self.even(x), self.odd(x))


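# Interactor: one interactive-learning block operating on the even/odd halves of a
# sequence. P, U, phi and psi are small grouped Conv1d branches; the replication
# padding (pad_l, pad_r) offsets the two valid convolutions so that every branch
# preserves the temporal length of its input.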
class Interactor(nn.Module):
    def __init__(self, in_planes, splitting=True,
                 kernel=5, dropout=0.5, groups=1, hidden_size=1, INN=True):
        super(Interactor, self).__init__()
        self.modified = INN
        self.kernel_size = kernel
        self.dilation = 1
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.groups = groups
        if self.kernel_size % 2 == 0:
            pad_l = self.dilation * (self.kernel_size - 2) // 2 + 1
            pad_r = self.dilation * self.kernel_size // 2 + 1
        else:
            pad_l = self.dilation * (self.kernel_size - 1) // 2 + 1
            pad_r = self.dilation * (self.kernel_size - 1) // 2 + 1
        self.splitting = splitting
        self.split = Splitting()

        size_hidden = self.hidden_size

        def branch():
            # P, U, phi and psi share the same two-layer grouped Conv1d architecture.
            return nn.Sequential(
                nn.ReplicationPad1d((pad_l, pad_r)),
                nn.Conv1d(in_planes, int(in_planes * size_hidden),
                          kernel_size=self.kernel_size, dilation=self.dilation,
                          stride=1, groups=self.groups),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.Dropout(self.dropout),
                nn.Conv1d(int(in_planes * size_hidden), in_planes,
                          kernel_size=3, stride=1, groups=self.groups),
                nn.Tanh())

        self.P = branch()
        self.U = branch()
        self.phi = branch()
        self.psi = branch()

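    # Interaction step. With INN=True the two halves are first rescaled by
    # exp(phi(.)) / exp(psi(.)) of each other and then exchange additive corrections
    # through U and P; with INN=False a plain additive predict/update pair is used.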
    def forward(self, x):
        if self.splitting:
            (x_even, x_odd) = self.split(x)
        else:
            (x_even, x_odd) = x

        if self.modified:
            x_even = x_even.permute(0, 2, 1)
            x_odd = x_odd.permute(0, 2, 1)

            d = x_odd.mul(torch.exp(self.phi(x_even)))
            c = x_even.mul(torch.exp(self.psi(x_odd)))

            x_even_update = c + self.U(d)
            x_odd_update = d - self.P(c)

            return (x_even_update, x_odd_update)
        else:
            x_even = x_even.permute(0, 2, 1)
            x_odd = x_odd.permute(0, 2, 1)

            d = x_odd - self.P(x_even)
            c = x_even + self.U(d)

            return (c, d)


class InteractorLevel(nn.Module):
    def __init__(self, in_planes, kernel, dropout, groups, hidden_size, INN):
        super(InteractorLevel, self).__init__()
        self.level = Interactor(in_planes=in_planes, splitting=True,
                                kernel=kernel, dropout=dropout, groups=groups,
                                hidden_size=hidden_size, INN=INN)

    def forward(self, x):
        (x_even_update, x_odd_update) = self.level(x)
        return (x_even_update, x_odd_update)


class LevelSCINet(nn.Module):
    def __init__(self, in_planes, kernel_size, dropout, groups, hidden_size, INN):
        super(LevelSCINet, self).__init__()
        self.interact = InteractorLevel(in_planes=in_planes, kernel=kernel_size, dropout=dropout,
                                        groups=groups, hidden_size=hidden_size, INN=INN)

    def forward(self, x):
        (x_even_update, x_odd_update) = self.interact(x)
        return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1)


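# SCINet_Tree applies one LevelSCINet block and, while current_level > 0, recurses on
# the even and odd outputs with independent sub-trees, so the encoder forms a binary
# tree of depth num_levels over progressively downsampled sub-sequences.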
class SCINet_Tree(nn.Module):
    def __init__(self, in_planes, current_level, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.current_level = current_level

        self.workingblock = LevelSCINet(
            in_planes=in_planes,
            kernel_size=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN)

        if current_level != 0:
            self.SCINet_Tree_odd = SCINet_Tree(in_planes, current_level - 1, kernel_size,
                                               dropout, groups, hidden_size, INN)
            self.SCINet_Tree_even = SCINet_Tree(in_planes, current_level - 1, kernel_size,
                                                dropout, groups, hidden_size, INN)

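    # zip_up_the_pants interleaves the processed even/odd sub-sequences back into one
    # sequence in their original temporal order; if the even half is one element
    # longer, its trailing element is appended at the end.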
    def zip_up_the_pants(self, even, odd):
        even = even.permute(1, 0, 2)
        odd = odd.permute(1, 0, 2)
        even_len = even.shape[0]
        odd_len = odd.shape[0]
        mlen = min((odd_len, even_len))
        merged = []
        for i in range(mlen):
            merged.append(even[i].unsqueeze(0))
            merged.append(odd[i].unsqueeze(0))
        if odd_len < even_len:
            merged.append(even[-1].unsqueeze(0))
        return torch.cat(merged, 0).permute(1, 0, 2)

    def forward(self, x):
        x_even_update, x_odd_update = self.workingblock(x)

        if self.current_level == 0:
            return self.zip_up_the_pants(x_even_update, x_odd_update)
        else:
            return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update),
                                         self.SCINet_Tree_odd(x_odd_update))


class EncoderTree(nn.Module):
    def __init__(self, in_planes, num_levels, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.levels = num_levels
        self.SCINet_Tree = SCINet_Tree(
            in_planes=in_planes,
            current_level=num_levels - 1,
            kernel_size=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN)

    def forward(self, x):
        x = self.SCINet_Tree(x)
        return x


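# Model wraps the SCINet encoder tree(s) behind a config-driven interface: the `args`
# namespace must provide enc_in, seq_len, pred_len, hidden_size, levels, groups,
# kernel, dropout and stacks; the remaining options are plain keyword arguments.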
class Model(nn.Module):
    def __init__(self, args, input_dim=9, hid_size=1, num_stacks=1,
                 num_levels=3, num_decoder_layer=1, concat_len=0, groups=1, kernel=5, dropout=0.5,
                 single_step_output_One=0, input_len_seg=0, positionalE=False, modified=True, RIN=False):
        super(Model, self).__init__()

        self.input_dim = args.enc_in
        self.input_len = args.seq_len
        self.output_len = args.pred_len
        self.hidden_size = args.hidden_size
        self.num_levels = args.levels
        self.groups = args.groups
        self.modified = modified
        self.kernel_size = args.kernel
        self.dropout = args.dropout
        self.stacks = args.stacks
        self.single_step_output_One = single_step_output_One
        self.concat_len = concat_len
        self.pe = positionalE
        self.RIN = RIN
        self.num_decoder_layer = num_decoder_layer

        self.blocks1 = EncoderTree(
            in_planes=self.input_dim,
            num_levels=self.num_levels,
            kernel_size=self.kernel_size,
            dropout=self.dropout,
            groups=self.groups,
            hidden_size=self.hidden_size,
            INN=modified)

        if num_stacks == 2:
            self.blocks2 = EncoderTree(
                in_planes=self.input_dim,
                num_levels=self.num_levels,
                kernel_size=self.kernel_size,
                dropout=self.dropout,
                groups=self.groups,
                hidden_size=self.hidden_size,
                INN=modified)

        self.stacks = num_stacks

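        # Generic init loop for Conv2d/BatchNorm2d/Linear layers. Note that it runs
        # before the projection layers are registered and the encoder only contains
        # Conv1d layers, so it has no practical effect in this model.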
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

        self.projection1 = nn.Conv1d(self.input_len, self.output_len, kernel_size=1, stride=1, bias=False)
        self.div_projection = nn.ModuleList()
        self.overlap_len = self.input_len // 4
        self.div_len = self.input_len // 6

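        # Optional deeper decoder: each extra layer maps six overlapping input windows
        # (length input_len//4, starting every input_len//6 steps) onto contiguous
        # chunks of length input_len//6 before the final projection to the horizon.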
        if self.num_decoder_layer > 1:
            self.projection1 = nn.Linear(self.input_len, self.output_len)
            for layer_idx in range(self.num_decoder_layer - 1):
                div_projection = nn.ModuleList()
                for i in range(6):
                    lens = min(i * self.div_len + self.overlap_len, self.input_len) - i * self.div_len
                    div_projection.append(nn.Linear(lens, self.div_len))
                self.div_projection.append(div_projection)

        if self.single_step_output_One:
            if self.stacks == 2:
                if self.concat_len:
                    self.projection2 = nn.Conv1d(self.concat_len + self.output_len, 1,
                                                 kernel_size=1, bias=False)
                else:
                    self.projection2 = nn.Conv1d(self.input_len + self.output_len, 1,
                                                 kernel_size=1, bias=False)
        else:
            if self.stacks == 2:
                if self.concat_len:
                    self.projection2 = nn.Conv1d(self.concat_len + self.output_len, self.output_len,
                                                 kernel_size=1, bias=False)
                else:
                    self.projection2 = nn.Conv1d(self.input_len + self.output_len, self.output_len,
                                                 kernel_size=1, bias=False)

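        # Transformer-style sinusoidal position-encoding timescales, sized to the
        # number of input channels (padded to an even size).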
        self.pe_hidden_size = self.input_dim
        if self.pe_hidden_size % 2 == 1:
            self.pe_hidden_size += 1

        num_timescales = self.pe_hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0

        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

        if self.RIN:
            # Learnable affine parameters for reversible instance normalization (RIN).
            self.affine_weight = nn.Parameter(torch.ones(1, 1, self.input_dim))
            self.affine_bias = nn.Parameter(torch.zeros(1, 1, self.input_dim))

    def get_position_encoding(self, x):
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32, device=x.device)
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
        signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
        signal = signal.view(1, max_length, self.pe_hidden_size)

        return signal

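    # Forward pass: subtract the last observed value (simple de-trending), optionally
    # add positional encoding and RIN normalization, run the encoder tree with a
    # residual connection, project the time axis to the forecast horizon, then undo
    # the normalization and add the last value back.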
    def forward(self, x):
        assert self.input_len % (np.power(2, self.num_levels)) == 0, \
            'seq_len must be divisible by 2**levels so the even/odd splits stay aligned'
        last_value = x[:, -1:, :].detach()
        x = x - last_value
        if self.pe:
            pe = self.get_position_encoding(x)
            if pe.shape[2] > x.shape[2]:
                x += pe[:, :, :-1]
            else:
                x += self.get_position_encoding(x)

        if self.RIN:
            print('/// RIN ACTIVATED ///\r', end='')
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x /= stdev
            x = x * self.affine_weight + self.affine_bias

        res1 = x
        x = self.blocks1(x)
        x += res1
        if self.num_decoder_layer == 1:
            x = self.projection1(x)
        else:
            x = x.permute(0, 2, 1)
            for div_projection in self.div_projection:
                output = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
                for i, div_layer in enumerate(div_projection):
                    div_x = x[:, :, i * self.div_len:min(i * self.div_len + self.overlap_len, self.input_len)]
                    output[:, :, i * self.div_len:(i + 1) * self.div_len] = div_layer(div_x)
                x = output
            x = self.projection1(x)
            x = x.permute(0, 2, 1)

        if self.stacks == 1:
            if self.RIN:
                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x + last_value

        elif self.stacks == 2:
            # Second stack: re-encode (part of) the input together with the first
            # forecast; MidOutPut is the intermediate prediction kept for supervision.
            MidOutPut = x
            if self.concat_len:
                x = torch.cat((res1[:, -self.concat_len:, :], x), dim=1)
            else:
                x = torch.cat((res1, x), dim=1)

            res2 = x
            x = self.blocks2(x)
            x += res2
            x = self.projection2(x)

            if self.RIN:
                MidOutPut = MidOutPut - self.affine_bias
                MidOutPut = MidOutPut / (self.affine_weight + 1e-10)
                MidOutPut = MidOutPut * stdev
                MidOutPut = MidOutPut + means

                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x + last_value, MidOutPut


def get_variable(x):
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4,
    # so simply move the tensor to the GPU when one is available.
    return x.cuda() if torch.cuda.is_available() else x


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--window_size', type=int, default=96)
    parser.add_argument('--horizon', type=int, default=12)

    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--groups', type=int, default=1)

    parser.add_argument('--hidden-size', default=1, type=int, help='hidden channel of module')
    parser.add_argument('--INN', default=1, type=int, help='use INN or basic strategy')
    parser.add_argument('--kernel', default=3, type=int, help='kernel size')
    parser.add_argument('--dilation', default=1, type=int, help='dilation')
    parser.add_argument('--positionalEcoding', type=bool, default=True)

    parser.add_argument('--single_step_output_One', type=int, default=0)

    args = parser.parse_args()

    # Smoke test. Model reads its core settings from the args namespace, so map the
    # CLI flags onto the attribute names it expects (enc_in/levels/stacks below keep
    # the values used by the original test: 9 input channels, 3 levels, 1 stack).
    args.enc_in = 9
    args.seq_len = args.window_size
    args.pred_len = args.horizon
    args.levels = 3
    args.stacks = 1

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(args,
                  single_step_output_One=args.single_step_output_One,
                  positionalE=args.positionalEcoding,
                  modified=True).to(device)
    x = torch.randn(32, args.seq_len, args.enc_in).to(device)
    y = model(x)
    print(y.shape)