razmars commited on
Commit
2513d6f
·
verified ·
1 Parent(s): c6c399a

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +8 -3
modeling_super_linear.py CHANGED
@@ -597,11 +597,14 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
597
  if len(x.shape) == 2:
598
  x = x.unsqueeze(1) # [B, 1, L]
599
 
600
- scale_factor = 512/x.shape[-1]
601
  upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
602
 
603
  # Take last 500 timesteps and rearrange to [B, L, C]
604
- upsample = upsample.permute(0, 2, 1)[:, -target_len:, :]
 
 
 
605
 
606
  # If input was 2D, remove the channel dimension
607
  if x.shape[1] == 1:
@@ -636,10 +639,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
636
  x_enc = self.upsample_interpolate(x_enc)
637
  pass
638
 
639
-
640
  # backbone returns (B, pred_len, C)
 
641
 
642
  preds = self.backbone(x_enc)
 
643
  #preds = self.fourier_downsample_dim1(preds,96)
644
  #preds = self.revin_layer(preds, 'denorm')
645
 
 
597
  if len(x.shape) == 2:
598
  x = x.unsqueeze(1) # [B, 1, L]
599
 
600
+ scale_factor = int(np.ceil(512/x.shape[-1]))
601
  upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
602
 
603
  # Take last 500 timesteps and rearrange to [B, L, C]
604
+ if target_len < 512:
605
+ upsample = upsample.permute(0, 2, 1)
606
+ else:
607
+ upsample = upsample.permute(0, 2, 1)[:, -target_len:, :]
608
 
609
  # If input was 2D, remove the channel dimension
610
  if x.shape[1] == 1:
 
639
  x_enc = self.upsample_interpolate(x_enc)
640
  pass
641
 
642
+ scale_factor = int(np.ceil(512/x.shape[-1]))
643
  # backbone returns (B, pred_len, C)
644
+ self.backbone.inf_pred_len = 96*scale_factor
645
 
646
  preds = self.backbone(x_enc)
647
+ x_enc = self.upsample_interpolate(x_enc,x_enc.shape[1)
648
  #preds = self.fourier_downsample_dim1(preds,96)
649
  #preds = self.revin_layer(preds, 'denorm')
650