Update modeling_super_linear.py
Browse files- modeling_super_linear.py +15 -1
modeling_super_linear.py
CHANGED
|
@@ -263,6 +263,7 @@ class SparseNoisyMoE(nn.Module):
|
|
| 263 |
def __init__(self, configs, experts=None):
|
| 264 |
super(SparseNoisyMoE, self).__init__()
|
| 265 |
input_dim = configs.seq_len
|
|
|
|
| 266 |
output_dim = configs.pred_len
|
| 267 |
self.k = configs.top_k_experts
|
| 268 |
self.noise_std = configs.noisy_gating_std
|
|
@@ -327,6 +328,16 @@ class SparseNoisyMoE(nn.Module):
|
|
| 327 |
|
| 328 |
return I
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def forward(self, x, get_prob=False):
|
| 331 |
if self.use_fft:
|
| 332 |
x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
|
|
@@ -348,6 +359,8 @@ class SparseNoisyMoE(nn.Module):
|
|
| 348 |
self.topk_gates = F.softmax(self.topk_values, dim=1)
|
| 349 |
|
| 350 |
batch_size = x.size(0)
|
|
|
|
|
|
|
| 351 |
expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
|
| 352 |
|
| 353 |
topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
|
|
@@ -697,7 +710,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 697 |
|
| 698 |
if x_enc.shape[1] < 512:
|
| 699 |
#x_enc = self.revin_layer(x_enc, 'norm')
|
| 700 |
-
x_enc = self.fourier_interp_dim1(x_enc)
|
|
|
|
| 701 |
|
| 702 |
#self.backbone.inf_pred_len = 336
|
| 703 |
|
|
|
|
| 263 |
def __init__(self, configs, experts=None):
|
| 264 |
super(SparseNoisyMoE, self).__init__()
|
| 265 |
input_dim = configs.seq_len
|
| 266 |
+
self.lookback = configs.seq_len
|
| 267 |
output_dim = configs.pred_len
|
| 268 |
self.k = configs.top_k_experts
|
| 269 |
self.noise_std = configs.noisy_gating_std
|
|
|
|
| 328 |
|
| 329 |
return I
|
| 330 |
|
| 331 |
+
|
| 332 |
+
def fourier_interp_dim1(self, x, target_len: int = 512):
    """Resample ``x`` along dim 1 to ``target_len`` samples via Fourier
    (spectral zero-padding) interpolation.

    Parameters
    ----------
    x : torch.Tensor
        Real tensor of shape ``(batch, L)`` or ``(batch, L, ...)``; dim 1 is
        the time/sequence axis being resampled.
    target_len : int, optional
        Desired length of dim 1 after resampling (default 512).

    Returns
    -------
    torch.Tensor
        Same shape as ``x`` except ``size(1) == target_len``.

    NOTE(review): no amplitude rescaling (``target_len / L``) is applied, so
    ``torch.fft.irfft``'s 1/n normalization scales values by ``L/target_len``
    on upsampling — presumably the model was trained with this convention;
    confirm before changing it.
    """
    n_bins = target_len // 2 + 1  # rfft bin count at the target length
    X = torch.fft.rfft(x, dim=1)  # (batch, L//2 + 1, ...)
    pad = n_bins - X.size(1)
    if pad > 0:
        # Bug fix: pad along dim 1 for tensors of ANY rank.  The previous
        # code built zeros of shape (*X.shape[:-1], pad), which only lines
        # up with dim 1 for 2-D input and raises a size-mismatch error for
        # (batch, seq, channels) input.
        zeros_shape = list(X.shape)
        zeros_shape[1] = pad
        X = torch.cat([X, X.new_zeros(zeros_shape)], dim=1)
    elif pad < 0:
        # Generalization: downsampling (target_len shorter than the input)
        # keeps only the lowest n_bins frequencies; the previous code
        # crashed on a negative pad size.
        X = X.narrow(1, 0, n_bins)
    return torch.fft.irfft(X, n=target_len, dim=1)
|
| 340 |
+
|
| 341 |
def forward(self, x, get_prob=False):
|
| 342 |
if self.use_fft:
|
| 343 |
x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
|
|
|
|
| 359 |
self.topk_gates = F.softmax(self.topk_values, dim=1)
|
| 360 |
|
| 361 |
batch_size = x.size(0)
|
| 362 |
+
if x_enc.shape[1] < 512:
|
| 363 |
+
x = self.fourier_interp_dim1(x)
|
| 364 |
expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
|
| 365 |
|
| 366 |
topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
|
|
|
|
| 710 |
|
| 711 |
if x_enc.shape[1] < 512:
|
| 712 |
#x_enc = self.revin_layer(x_enc, 'norm')
|
| 713 |
+
#x_enc = self.fourier_interp_dim1(x_enc)
|
| 714 |
+
pass
|
| 715 |
|
| 716 |
#self.backbone.inf_pred_len = 336
|
| 717 |
|