razmars committed on
Commit
5e056b1
·
verified ·
1 Parent(s): 9b041ba

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +33 -6
modeling_super_linear.py CHANGED
@@ -518,6 +518,36 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
518
  # ------------------------------------------------------------------
519
  # Forward pass expected by AutoModelForCausalLM
520
  # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  def forward(self,
522
  inputs_embeds: torch.Tensor = None,
523
  attention_mask: Optional[torch.Tensor] = None,
@@ -532,13 +562,10 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
532
 
533
  # backbone expects (B, C, L)
534
  x_enc = inputs_embeds
535
- #print(x_enc.shape)
536
  if x_enc.shape[1] < 512:
537
- if len(x_enc.shape) == 2:
538
- x_enc = F.interpolate(x_enc.unsqueeze(0),size = 512,mode="linear", align_corners=False)
539
- else:
540
- x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
541
-
542
 
543
  # backbone returns (B, pred_len, C)
544
  preds = self.backbone(x_enc)
 
518
  # ------------------------------------------------------------------
519
  # Forward pass expected by AutoModelForCausalLM
520
  # ------------------------------------------------------------------
521
+
522
+
523
def upsample_dim1(self, x, target_len: int = 512, mode: str = "linear"):
    """Upsample the "length" axis of ``x`` to ``target_len`` samples.

    Uses :func:`torch.nn.functional.interpolate` (``mode="linear"`` by
    default, ``align_corners=False``), which requires a 3-D ``(N, C, L)``
    input — so the tensor is first reshaped so the axis to be scaled sits
    last, and the original layout is restored afterwards.

    Accepted layouts:
        1-D  ``(L,)``
        2-D  ``(L, C)`` when ``x.shape[0] == 48`` (HACK: 48 is assumed to be
             the raw input length — TODO confirm against the model config),
             otherwise treated as ``(C, L)``
        3-D  ``(B, L, C)`` — the length axis is assumed to be dim 1

    Args:
        x: input tensor, 1 to 3 dimensions.
        target_len: desired length of the scaled axis (default 512).
        mode: interpolation mode passed to ``F.interpolate``.

    Returns:
        Tensor with the same layout as ``x`` but with the length axis
        resized to ``target_len``.

    Raises:
        ValueError: if ``x`` has more than 3 dimensions.  (The previous
        implementation crashed inside ``F.interpolate`` in that case,
        since linear mode only supports 3-D input.)
    """
    orig_shape = x.shape
    ndim = x.ndim

    # -- bring the target axis into the PyTorch 1-D "length" position (N, C, L)
    if ndim == 1:  # (L,)
        x_ = x.unsqueeze(0).unsqueeze(0)  # (1, 1, L)
        unstack = lambda t: t.squeeze(0).squeeze(0)
    elif ndim == 2:  # (L, C) or (C, L)
        # HACK: 48 disambiguates the two layouts; kept for compatibility.
        if orig_shape[0] == 48:  # assume (L, C)
            x_ = x.permute(1, 0).unsqueeze(0)  # (1, C, L)
            unstack = lambda t: t.squeeze(0).permute(1, 0)
        else:  # assume (C, L)
            x_ = x.unsqueeze(0)  # (1, C, L)
            unstack = lambda t: t.squeeze(0)
    elif ndim == 3:  # (B, L, C) with the length axis at dim 1
        x_ = x.transpose(1, 2)  # (B, C, L)
        unstack = lambda t: t.transpose(1, 2)
    else:
        # F.interpolate(mode="linear") only accepts 3-D input; fail early
        # with a clear message instead of a cryptic interpolate error.
        raise ValueError(
            f"upsample_dim1 supports tensors with 1-3 dims, got shape {tuple(orig_shape)}"
        )

    # -- actual interpolation along the length dimension
    y = F.interpolate(x_, size=target_len, mode=mode, align_corners=False)

    # -- restore the original dimension ordering
    return unstack(y)
550
+
551
  def forward(self,
552
  inputs_embeds: torch.Tensor = None,
553
  attention_mask: Optional[torch.Tensor] = None,
 
562
 
563
  # backbone expects (B, C, L)
564
  x_enc = inputs_embeds
565
+ print(x_enc.shape)
566
  if x_enc.shape[1] < 512:
567
+ x_enc = self.upsample_dim1(x_enc)
568
+ print(x_enc.shape)
 
 
 
569
 
570
  # backbone returns (B, pred_len, C)
571
  preds = self.backbone(x_enc)