Update modeling_super_linear.py
modeling_super_linear.py CHANGED (+33 -6)
@@ -518,6 +518,36 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
     # ------------------------------------------------------------------
     # Forward pass expected by AutoModelForCausalLM
     # ------------------------------------------------------------------
+
+
+    def upsample_dim1(self, x, target_len: int = 512, mode: str = "linear"):
+        # -------- bring the dim-1 axis to the PyTorch 1-D "length" position -------
+        orig_shape = x.shape
+        ndim = x.ndim
+
+        # Reshape to (N, C, L) where L is the axis we want to scale
+        if ndim == 1:                              # (L,)
+            x_ = x.unsqueeze(0).unsqueeze(0)       # (1, 1, L)
+            unstack = lambda t: t.squeeze(0).squeeze(0)
+        elif ndim == 2:                            # (L, C) or (C, L)
+            if orig_shape[0] == 48:                # assume (L, C)
+                x_ = x.permute(1, 0).unsqueeze(0)  # (1, C, L)
+                unstack = lambda t: t.squeeze(0).permute(1, 0)
+            else:                                  # assume (C, L)
+                x_ = x.unsqueeze(0)                # (1, C, L)
+                unstack = lambda t: t.squeeze(0)
+        else:                                      # ≥3 dims, assume (B, L, C, ...) with L at dim 1
+            x_ = x.transpose(1, 2)                 # (B, C, L, ...)
+            new_order = list(range(ndim))
+            new_order[1], new_order[2] = 2, 1      # swap back later
+            unstack = lambda t: t.permute(*new_order)
+
+        # ------------------ actual interpolation in length dimension --------------
+        y = F.interpolate(x_, size=target_len, mode=mode, align_corners=False)
+
+        # ------------------ restore original dimension ordering -------------------
+        return unstack(y)
+
     def forward(self,
                 inputs_embeds: torch.Tensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
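For context: F.interpolate with mode="linear" only accepts 3-D (N, C, L) input and resizes the trailing L axis, which is why the helper above unsqueezes or permutes every rank into that layout and undoes the reshuffle afterwards. A quick standalone check of that constraint (the tensors here are just demonstration values):

import torch
import torch.nn.functional as F

# mode="linear" requires a 3-D (N, C, L) tensor; other ranks raise an error,
# hence the unsqueeze/permute bookkeeping in upsample_dim1 above.
x = torch.randn(7, 96)                    # (C, L): 7 channels, length 96
y = F.interpolate(x.unsqueeze(0),         # -> (1, 7, 96)
                  size=512, mode="linear", align_corners=False)
print(y.squeeze(0).shape)                 # torch.Size([7, 512])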
@@ -532,13 +562,10 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
 
         # backbone expects (B, C, L)
         x_enc = inputs_embeds
-
+        print(x_enc.shape)
         if x_enc.shape[1] < 512:
-
-
-        else:
-            x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
-
+            x_enc = self.upsample_dim1(x_enc)
+            print(x_enc.shape)
 
         # backbone returns (B, pred_len, C)
         preds = self.backbone(x_enc)
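The forward() change replaces the old logic, where F.interpolate only ran in the else branch (i.e. when dim 1 was already at least 512), with an upsample of short inputs through the new helper. A minimal sketch of the resulting pre-processing, assuming inputs_embeds arrives as (B, L, C) so that dim 1 is the sequence axis; prepare_input and context_len are illustrative names, not part of the repo:

import torch
import torch.nn.functional as F

def prepare_input(x_enc: torch.Tensor, context_len: int = 512) -> torch.Tensor:
    # Mirrors the new forward() guard for the common (B, L, C) case:
    # sequences shorter than the context window are linearly upsampled,
    # longer ones pass through untouched.
    if x_enc.shape[1] < context_len:
        x_enc = F.interpolate(x_enc.transpose(1, 2), size=context_len,
                              mode="linear", align_corners=False).transpose(1, 2)
    return x_enc

print(prepare_input(torch.randn(1, 96, 7)).shape)   # torch.Size([1, 512, 7])
print(prepare_input(torch.randn(1, 720, 7)).shape)  # torch.Size([1, 720, 7])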
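Because this forward() is the entry point AutoModelForCausalLM dispatches to, the resizing is transparent to callers. A hypothetical end-to-end call; "org/SuperLinear" is a placeholder checkpoint id, not the actual repo path:

import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint id; trust_remote_code=True loads the custom
# modeling_super_linear.py shipped with the repo.
model = AutoModelForCausalLM.from_pretrained("org/SuperLinear", trust_remote_code=True)

series = torch.randn(1, 96, 7)       # (B, L, C) history shorter than 512 steps
out = model(inputs_embeds=series)    # forward() now upsamples to length 512 internally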