Update modeling_super_linear.py
Browse files- modeling_super_linear.py +4 -12
modeling_super_linear.py
CHANGED
|
@@ -98,16 +98,8 @@ class moving_avg(nn.Module):
|
|
| 98 |
super(moving_avg, self).__init__()
|
| 99 |
self.kernel_size = kernel_size
|
| 100 |
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
# padding on the both ends of time series
|
| 104 |
-
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
| 105 |
-
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
| 106 |
-
x = torch.cat([front, x, end], dim=1)
|
| 107 |
-
x = self.avg(x.permute(0, 2, 1))
|
| 108 |
-
x = x.permute(0, 2, 1)
|
| 109 |
-
return x
|
| 110 |
-
"""
|
| 111 |
def forward(self, x):
|
| 112 |
# x: [Batch, Input length]
|
| 113 |
# padding on the both ends of time series
|
|
@@ -236,7 +228,7 @@ class RLinear(nn.Module):
|
|
| 236 |
def forward(self, x):
|
| 237 |
# x: [Batch, Input length,Channel]
|
| 238 |
x_shape = x.shape
|
| 239 |
-
if x.shape[1] < self.seq_len:
|
| 240 |
#if self.zero_shot_Linear is None:
|
| 241 |
#print(F"new Lookkback : {x.shape[1]}")
|
| 242 |
|
|
@@ -244,7 +236,7 @@ class RLinear(nn.Module):
|
|
| 244 |
x = self.revin_layer(x, 'norm')
|
| 245 |
x = F.linear(x, self.zero_shot_Linear)
|
| 246 |
x = self.revin_layer(x, 'denorm')
|
| 247 |
-
return x
|
| 248 |
|
| 249 |
|
| 250 |
if len(x_shape) == 2:
|
|
|
|
| 98 |
super(moving_avg, self).__init__()
|
| 99 |
self.kernel_size = kernel_size
|
| 100 |
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
|
| 101 |
+
|
| 102 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
def forward(self, x):
|
| 104 |
# x: [Batch, Input length]
|
| 105 |
# padding on the both ends of time series
|
|
|
|
| 228 |
def forward(self, x):
|
| 229 |
# x: [Batch, Input length,Channel]
|
| 230 |
x_shape = x.shape
|
| 231 |
+
''''if x.shape[1] < self.seq_len:
|
| 232 |
#if self.zero_shot_Linear is None:
|
| 233 |
#print(F"new Lookkback : {x.shape[1]}")
|
| 234 |
|
|
|
|
| 236 |
x = self.revin_layer(x, 'norm')
|
| 237 |
x = F.linear(x, self.zero_shot_Linear)
|
| 238 |
x = self.revin_layer(x, 'denorm')
|
| 239 |
+
return x'''
|
| 240 |
|
| 241 |
|
| 242 |
if len(x_shape) == 2:
|