razmars commited on
Commit
97b2995
·
verified ·
1 Parent(s): 1ebef58

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -9
modeling_super_linear.py CHANGED
@@ -594,20 +594,15 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
594
 
595
  def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
596
  # Add channel dimension if input is 2D
597
- if len(x.shape) == 2:
 
598
  x = x.unsqueeze(1) # [B, 1, L]
599
 
600
 
601
- upsample = interpolate(x,size =target_len,mode='linear') # [B, C, new_L]
602
-
603
- # Take last 500 timesteps and rearrange to [B, L, C]
604
- if target_len < 512:
605
- upsample = upsample.permute(0, 2, 1)
606
- else:
607
- upsample = upsample.permute(0, 2, 1)
608
 
609
  # If input was 2D, remove the channel dimension
610
- if x.shape[1] == 1:
611
  upsample = upsample.squeeze(-1)
612
 
613
  upsample = upsample.float()
 
594
 
595
  def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
596
  # Add channel dimension if input is 2D
597
+ size = len(x.shape)
598
+ if size == 2:
599
  x = x.unsqueeze(1) # [B, 1, L]
600
 
601
 
602
+ upsample = = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
 
 
 
 
 
 
603
 
604
  # If input was 2D, remove the channel dimension
605
+ if size == 2:
606
  upsample = upsample.squeeze(-1)
607
 
608
  upsample = upsample.float()