Update modeling_super_linear.py
Browse files- modeling_super_linear.py +13 -1
modeling_super_linear.py
CHANGED
|
@@ -548,6 +548,18 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 548 |
# ------------------ restore original dimension ordering -------------------
|
| 549 |
return unstack(y)
|
| 550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
def forward(self,
|
| 552 |
inputs_embeds: torch.Tensor = None,
|
| 553 |
attention_mask: Optional[torch.Tensor] = None,
|
|
@@ -564,7 +576,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 564 |
x_enc = inputs_embeds
|
| 565 |
print(x_enc.shape)
|
| 566 |
if x_enc.shape[1] < 512:
|
| 567 |
-
x_enc = self.
|
| 568 |
print(x_enc.shape)
|
| 569 |
|
| 570 |
# backbone returns (B, pred_len, C)
|
|
|
|
| 548 |
# ------------------ restore original dimension ordering -------------------
|
| 549 |
return unstack(y)
|
| 550 |
|
| 551 |
+
def fourier_interp_dim1(self, x, target_len: int = 512):
    """Upsample ``x`` along dim 1 to ``target_len`` samples by zero-padding its spectrum.

    The signal is transformed with a real FFT along dim 1, the missing
    high-frequency bins are treated as zero, and the result is transformed
    back at the longer length (band-limited / trigonometric interpolation).
    When ``target_len`` is an integer multiple of the input length, the
    output passes through the original samples.

    Args:
        x: Real tensor with at least two dimensions; dim 1 is the axis that
            is resampled (presumably ``(B, L, C)`` in the caller -- confirm).
        target_len: Desired length of dim 1 in the output.

    Returns:
        A tensor shaped like ``x`` except that dim 1 has size ``target_len``.
        ``x`` itself is returned when it already has that length.

    Raises:
        ValueError: If dim 1 of ``x`` is empty or longer than ``target_len``.
    """
    import torch  # local import keeps the method self-contained

    src_len = x.size(1)
    if src_len < 1:
        raise ValueError("dim-1 length must be at least 1")
    if src_len > target_len:
        raise ValueError(
            f"dim-1 length ({src_len}) must not exceed target_len ({target_len})"
        )
    if src_len == target_len:
        return x

    spectrum = torch.fft.rfft(x, dim=1)
    n_bins = spectrum.size(1)

    # irfft divides by the output length, so without this gain the amplitude
    # would shrink by src_len / target_len.
    gain = target_len / src_len
    scale = torch.full(
        (n_bins,), gain, dtype=spectrum.real.dtype, device=spectrum.device
    )
    if src_len % 2 == 0:
        # For an even-length input the last rfft bin is the Nyquist bin, which
        # stands for both +fs/2 and -fs/2. In the longer spectrum it becomes
        # an ordinary bin that is mirrored automatically, so it must be
        # halved to avoid being counted twice.
        scale[-1] = gain * 0.5

    # Broadcast the per-bin scale along dim 1 only.
    scale = scale.view(1, n_bins, *([1] * (x.dim() - 2)))
    spectrum = spectrum * scale

    # irfft zero-pads the spectrum up to target_len // 2 + 1 bins on its own,
    # so no manual concatenation of zeros is needed.
    return torch.fft.irfft(spectrum, n=target_len, dim=1)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
def forward(self,
|
| 564 |
inputs_embeds: torch.Tensor = None,
|
| 565 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 576 |
x_enc = inputs_embeds
|
| 577 |
print(x_enc.shape)
|
| 578 |
if x_enc.shape[1] < 512:
|
| 579 |
+
x_enc = self.fourier_interp_dim1(x_enc)
|
| 580 |
print(x_enc.shape)
|
| 581 |
|
| 582 |
# backbone returns (B, pred_len, C)
|