razmars committed on
Commit
1c6415b
·
verified ·
1 Parent(s): 5e056b1

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +13 -1
modeling_super_linear.py CHANGED
@@ -548,6 +548,18 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
548
  # ------------------ restore original dimension ordering -------------------
549
  return unstack(y)
550
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  def forward(self,
552
  inputs_embeds: torch.Tensor = None,
553
  attention_mask: Optional[torch.Tensor] = None,
@@ -564,7 +576,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
564
  x_enc = inputs_embeds
565
  print(x_enc.shape)
566
  if x_enc.shape[1] < 512:
567
- x_enc = self.upsample_dim1(x_enc)
568
  print(x_enc.shape)
569
 
570
  # backbone returns (B, pred_len, C)
 
548
  # ------------------ restore original dimension ordering -------------------
549
  return unstack(y)
550
 
551
def fourier_interp_dim1(self, x, target_len: int = 512):
    """Upsample ``x`` along dim 1 to ``target_len`` samples using
    band-limited (Fourier) interpolation: zero-pad the rfft spectrum,
    then inverse-transform at the new length.

    Args:
        x: real tensor of shape ``(B, L)`` or ``(B, L, ...)``; the
           interpolation acts along dim 1. In ``forward`` this is
           ``inputs_embeds`` — presumably ``(B, seq_len, C)``; TODO confirm.
        target_len: desired output length along dim 1 (default 512).

    Returns:
        Tensor shaped like ``x`` with dim 1 resized to ``target_len``.

    Raises:
        ValueError: if ``L > target_len`` (this routine only upsamples).
    """
    L = x.size(1)
    if L > target_len:
        raise ValueError(f"dim-1 length {L} exceeds target_len {target_len}")

    # rfft along dim 1: the spectrum has L // 2 + 1 bins on that axis.
    X = torch.fft.rfft(x, dim=1)

    # Zero-pad the spectrum up to target_len // 2 + 1 bins ALONG DIM 1.
    # NOTE: the zeros must carry `pad` in dim 1 (not the last dim), or the
    # concatenation fails for 3-D input such as (B, L, C).
    pad = target_len // 2 + 1 - X.size(1)
    zeros_shape = list(X.shape)
    zeros_shape[1] = pad
    X_pad = torch.cat([X, X.new_zeros(zeros_shape)], dim=1)

    # irfft normalizes by 1/n, so rescale by target_len / L to keep the
    # interpolated signal at the original amplitude.
    y = torch.fft.irfft(X_pad, n=target_len, dim=1) * (target_len / L)
    return y
561
+
562
+
563
  def forward(self,
564
  inputs_embeds: torch.Tensor = None,
565
  attention_mask: Optional[torch.Tensor] = None,
 
576
  x_enc = inputs_embeds
577
  print(x_enc.shape)
578
  if x_enc.shape[1] < 512:
579
+ x_enc = self.fourier_interp_dim1(x_enc)
580
  print(x_enc.shape)
581
 
582
  # backbone returns (B, pred_len, C)