Update modeling_super_linear.py
Browse files- modeling_super_linear.py +4 -9
modeling_super_linear.py
CHANGED
|
@@ -594,20 +594,15 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 594 |
|
| 595 |
def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
|
| 596 |
# Add channel dimension if input is 2D
|
| 597 |
-
|
|
|
|
| 598 |
x = x.unsqueeze(1) # [B, 1, L]
|
| 599 |
|
| 600 |
|
| 601 |
-
upsample = interpolate(x,size
|
| 602 |
-
|
| 603 |
-
# Take last 500 timesteps and rearrange to [B, L, C]
|
| 604 |
-
if target_len < 512:
|
| 605 |
-
upsample = upsample.permute(0, 2, 1)
|
| 606 |
-
else:
|
| 607 |
-
upsample = upsample.permute(0, 2, 1)
|
| 608 |
|
| 609 |
# If input was 2D, remove the channel dimension
|
| 610 |
-
if
|
| 611 |
upsample = upsample.squeeze(-1)
|
| 612 |
|
| 613 |
upsample = upsample.float()
|
|
|
|
| 594 |
|
| 595 |
def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
|
| 596 |
# Add channel dimension if input is 2D
|
| 597 |
+
size = len(x.shape)
|
| 598 |
+
if size == 2:
|
| 599 |
x = x.unsqueeze(1) # [B, 1, L]
|
| 600 |
|
| 601 |
|
| 602 |
+
upsample = = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
# If input was 2D, remove the channel dimension
|
| 605 |
+
if size == 2:
|
| 606 |
upsample = upsample.squeeze(-1)
|
| 607 |
|
| 608 |
upsample = upsample.float()
|