Update modeling_super_linear.py
Browse files- modeling_super_linear.py +17 -4
modeling_super_linear.py
CHANGED
|
@@ -592,10 +592,23 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 592 |
|
| 593 |
return y
|
| 594 |
|
| 595 |
-
def upsample_interpolate(self,x, target_len: int = 512):
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
|
| 600 |
|
| 601 |
|
|
|
|
| 592 |
|
| 593 |
return y
|
| 594 |
|
| 595 |
+
def upsample_interpolate(self, x, target_len: int = 512):
|
| 596 |
+
# Add channel dimension if input is 2D
|
| 597 |
+
if len(x.shape) == 2:
|
| 598 |
+
x = x.unsqueeze(1) # [B, 1, L]
|
| 599 |
+
|
| 600 |
+
scale_factor = 512/x.shape[-1]
|
| 601 |
+
upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
|
| 602 |
+
|
| 603 |
+
# Take last 500 timesteps and rearrange to [B, L, C]
|
| 604 |
+
upsample = upsample.permute(0, 2, 1)[:, -500:, :]
|
| 605 |
+
|
| 606 |
+
# If input was 2D, remove the channel dimension
|
| 607 |
+
if x.shape[1] == 1:
|
| 608 |
+
upsample = upsample.squeeze(-1)
|
| 609 |
+
|
| 610 |
+
print(f"Upsampled shape: {upsample.shape}")
|
| 611 |
+
return upsample
|
| 612 |
|
| 613 |
|
| 614 |
|