Update modeling_super_linear.py
Browse files- modeling_super_linear.py +8 -3
modeling_super_linear.py
CHANGED
|
@@ -597,11 +597,14 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 597 |
if len(x.shape) == 2:
|
| 598 |
x = x.unsqueeze(1) # [B, 1, L]
|
| 599 |
|
| 600 |
-
scale_factor = 512/x.shape[-1]
|
| 601 |
upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
|
| 602 |
|
| 603 |
# Take last 500 timesteps and rearrange to [B, L, C]
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
# If input was 2D, remove the channel dimension
|
| 607 |
if x.shape[1] == 1:
|
|
@@ -636,10 +639,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 636 |
x_enc = self.upsample_interpolate(x_enc)
|
| 637 |
pass
|
| 638 |
|
| 639 |
-
|
| 640 |
# backbone returns (B, pred_len, C)
|
|
|
|
| 641 |
|
| 642 |
preds = self.backbone(x_enc)
|
|
|
|
| 643 |
#preds = self.fourier_downsample_dim1(preds,96)
|
| 644 |
#preds = self.revin_layer(preds, 'denorm')
|
| 645 |
|
|
|
|
| 597 |
if len(x.shape) == 2:
|
| 598 |
x = x.unsqueeze(1) # [B, 1, L]
|
| 599 |
|
| 600 |
+
scale_factor = int(np.ceil(512/x.shape[-1]))
|
| 601 |
upsample = interpolate(x, scale_factor=scale_factor, mode='linear') # [B, C, new_L]
|
| 602 |
|
| 603 |
# Take last 500 timesteps and rearrange to [B, L, C]
|
| 604 |
+
if target_len < 512:
|
| 605 |
+
upsample = upsample.permute(0, 2, 1)
|
| 606 |
+
else:
|
| 607 |
+
upsample = upsample.permute(0, 2, 1)[:, -target_len:, :]
|
| 608 |
|
| 609 |
# If input was 2D, remove the channel dimension
|
| 610 |
if x.shape[1] == 1:
|
|
|
|
| 639 |
x_enc = self.upsample_interpolate(x_enc)
|
| 640 |
pass
|
| 641 |
|
| 642 |
+
scale_factor = int(np.ceil(512/x.shape[-1]))
|
| 643 |
# backbone returns (B, pred_len, C)
|
| 644 |
+
self.backbone.inf_pred_len = 96*scale_factor
|
| 645 |
|
| 646 |
preds = self.backbone(x_enc)
|
| 647 |
+
x_enc = self.upsample_interpolate(x_enc,x_enc.shape[1)
|
| 648 |
#preds = self.fourier_downsample_dim1(preds,96)
|
| 649 |
#preds = self.revin_layer(preds, 'denorm')
|
| 650 |
|