razmars committed on
Commit
389c431
·
verified ·
1 Parent(s): e67dea9

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +15 -1
modeling_super_linear.py CHANGED
@@ -263,6 +263,7 @@ class SparseNoisyMoE(nn.Module):
263
  def __init__(self, configs, experts=None):
264
  super(SparseNoisyMoE, self).__init__()
265
  input_dim = configs.seq_len
 
266
  output_dim = configs.pred_len
267
  self.k = configs.top_k_experts
268
  self.noise_std = configs.noisy_gating_std
@@ -327,6 +328,16 @@ class SparseNoisyMoE(nn.Module):
327
 
328
  return I
329
 
 
 
 
 
 
 
 
 
 
 
330
  def forward(self, x, get_prob=False):
331
  if self.use_fft:
332
  x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
@@ -348,6 +359,8 @@ class SparseNoisyMoE(nn.Module):
348
  self.topk_gates = F.softmax(self.topk_values, dim=1)
349
 
350
  batch_size = x.size(0)
 
 
351
  expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
352
 
353
  topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
@@ -697,7 +710,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
697
 
698
  if x_enc.shape[1] < 512:
699
  #x_enc = self.revin_layer(x_enc, 'norm')
700
- x_enc = self.fourier_interp_dim1(x_enc)
 
701
 
702
  #self.backbone.inf_pred_len = 336
703
 
 
263
  def __init__(self, configs, experts=None):
264
  super(SparseNoisyMoE, self).__init__()
265
  input_dim = configs.seq_len
266
+ self.lookback = configs.seq_len
267
  output_dim = configs.pred_len
268
  self.k = configs.top_k_experts
269
  self.noise_std = configs.noisy_gating_std
 
328
 
329
  return I
330
 
331
+
332
def fourier_interp_dim1(self, x, target_len: int = 512):
    """Resample ``x`` along dim 1 to ``target_len`` samples via FFT zero-padding.

    Takes the rFFT along dim 1, pads the spectrum with zeros up to
    ``target_len // 2 + 1`` bins (or truncates high frequencies when the
    input is already longer than ``target_len``), then inverts with
    ``n=target_len``.

    NOTE(review): ``irfft`` normalizes by ``1/n``, so the output amplitude is
    scaled by ``L / target_len`` relative to the input; classic Fourier
    interpolation multiplies by ``target_len / L`` to preserve amplitude —
    confirm the unscaled behavior is intended before relying on magnitudes.

    Args:
        x: tensor whose dim 1 is the time axis, e.g. ``(batch, L)`` or
           ``(batch, L, channels)``.
        target_len: desired length along dim 1 (default 512).

    Returns:
        Tensor shaped like ``x`` with dim 1 resized to ``target_len``.
    """
    n_bins = target_len // 2 + 1
    X = torch.fft.rfft(x, dim=1)
    pad = n_bins - X.size(1)
    if pad > 0:
        # Zeros must carry `pad` on dim 1 and keep all trailing dims.
        # The original `X.new_zeros(*X.shape[:-1], pad)` put `pad` on the
        # LAST dim, which makes torch.cat(dim=1) fail for 3-D input.
        zeros = X.new_zeros(X.shape[0], pad, *X.shape[2:])
        X = torch.cat([X, zeros], dim=1)
    elif pad < 0:
        # Input longer than target: drop high-frequency bins (downsample)
        # instead of crashing on a negative-sized zero tensor.
        X = X.narrow(1, 0, n_bins)
    return torch.fft.irfft(X, n=target_len, dim=1)
340
+
341
  def forward(self, x, get_prob=False):
342
  if self.use_fft:
343
  x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
 
359
  self.topk_gates = F.softmax(self.topk_values, dim=1)
360
 
361
  batch_size = x.size(0)
362
+ if x.shape[1] < 512:
363
+ x = self.fourier_interp_dim1(x)
364
  expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
365
 
366
  topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
 
710
 
711
  if x_enc.shape[1] < 512:
712
  #x_enc = self.revin_layer(x_enc, 'norm')
713
+ #x_enc = self.fourier_interp_dim1(x_enc)
714
+ pass
715
 
716
  #self.backbone.inf_pred_len = 336
717