razmars commited on
Commit
87b73e2
·
verified ·
1 Parent(s): 6023125

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +17 -4
modeling_super_linear.py CHANGED
@@ -592,10 +592,23 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
592
 
593
  return y
594
 
595
- def upsample_interpolate(self,x, target_len: int = 512):
596
- scale_factor = 512/x.shape[1]
597
- upsample = interpolate(x, scale_factor=scale_factor, mode='linear').permute(0,2,1)[:, -500:, :]
598
- print(upsample.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
599
 
600
 
601
 
 
592
 
593
  return y
594
 
595
+ def upsample_interpolate(self, x, target_len: int = 512):
596
+ # Add channel dimension if input is 2D
597
+ if len(x.shape) == 2:
598
+ x = x.unsqueeze(1) # [B, 1, L]
599
+
600
+ scale_factor = 512/x.shape[-1]
601
+ upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
602
+
603
+ # Take last 500 timesteps and rearrange to [B, L, C]
604
+ upsample = upsample.permute(0, 2, 1)[:, -500:, :]
605
+
606
+ # If input was 2D, remove the channel dimension
607
+ if x.shape[1] == 1:
608
+ upsample = upsample.squeeze(-1)
609
+
610
+ print(f"Upsampled shape: {upsample.shape}")
611
+ return upsample
612
 
613
 
614