razmars commited on
Commit
0c42e00
·
verified ·
1 Parent(s): 8e0f9c2

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +3 -3
modeling_super_linear.py CHANGED
@@ -596,14 +596,14 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
596
  # Add channel dimension if input is 2D
597
  size = len(x.shape)
598
  if size == 2:
599
- x = x.unsqueeze(1) # [B, 1, L]
600
-
601
 
 
602
  upsample = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
603
 
604
  # If input was 2D, remove the channel dimension
605
  if size == 2:
606
- upsample = upsample.squeeze(-1)
607
 
608
  upsample = upsample.float()
609
 
 
596
  # Add channel dimension if input is 2D
597
  size = len(x.shape)
598
  if size == 2:
599
+ x = x.unsqueeze(1)
 
600
 
601
+ print(target_len)
602
  upsample = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
603
 
604
  # If input was 2D, remove the channel dimension
605
  if size == 2:
606
+ upsample = upsample.squeeze(1)
607
 
608
  upsample = upsample.float()
609