razmars committed on
Commit
6023125
·
verified ·
1 Parent(s): a0f4ed0

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +9 -33
modeling_super_linear.py CHANGED
@@ -4,6 +4,7 @@ import torch, torch.nn as nn, torch.nn.functional as F
4
  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
5
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
6
  from .configuration_super_linear import SuperLinearConfig
 
7
 
8
  import datetime
9
 
@@ -562,39 +563,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
562
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
563
  self.post_init()
564
 
565
- # ------------------------------------------------------------------
566
- # Forward pass expected by AutoModelForCausalLM
567
- # ------------------------------------------------------------------
568
 
569
 
570
- def upsample_dim1(self, x, target_len: int = 512, mode: str = "bicubic"):
571
- # -------- bring the dim-1 axis to the PyTorch 1-D “length” position --------
572
- orig_shape = x.shape
573
- ndim = x.ndim
574
-
575
- # Reshape to (N, C, L) where L is the axis we want to scale
576
- if ndim == 1: # (L,)
577
- x_ = x.unsqueeze(0).unsqueeze(0) # (1,1,L)
578
- unstack = lambda t: t.squeeze(0).squeeze(0)
579
- elif ndim == 2: # (L,C) or (C,L)
580
- if orig_shape[0] == 48: # assume (L,C)
581
- x_ = x.permute(1, 0).unsqueeze(0) # (1,C,L)
582
- unstack = lambda t: t.squeeze(0).permute(1, 0)
583
- else: # assume (C,L)
584
- x_ = x.unsqueeze(0) # (1,C,L)
585
- unstack = lambda t: t.squeeze(0)
586
- else: # ≥3 dims, assume (B,L,C, …) with L at dim-1
587
- x_ = x.transpose(1, 2) # (B,C,L,...)
588
- new_order = list(range(ndim))
589
- new_order[1], new_order[2] = 2, 1 # swap back later
590
- unstack = lambda t: t.permute(*new_order)
591
-
592
- # ------------------ actual interpolation in length dimension --------------
593
- y = F.interpolate(x_, size=target_len, mode=mode, align_corners=False)
594
-
595
- # ------------------ restore original dimension ordering -------------------
596
- return unstack(y)
597
-
598
  def fourier_interp_dim1(self,x, target_len: int = 512):
599
 
600
  L = x.size(1)
@@ -621,7 +591,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
621
 
622
 
623
  return y
624
-
 
 
 
 
 
625
 
626
 
627
 
@@ -643,7 +618,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
643
 
644
  if x_enc.shape[1] < 512:
645
  #x_enc = self.revin_layer(x_enc, 'norm')
646
- #x_enc = self.fourier_interp_dim1(x_enc)
647
  pass
648
 
649
 
@@ -665,3 +640,4 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
665
  def _reorder_cache(self, past, beam_idx, **kwargs):
666
  return past # backbone keeps no KV cache
667
 
 
 
4
  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
5
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
6
  from .configuration_super_linear import SuperLinearConfig
7
+ from torch.nn.functional import interpolate
8
 
9
  import datetime
10
 
 
563
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
564
  self.post_init()
565
 
 
 
 
566
 
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  def fourier_interp_dim1(self,x, target_len: int = 512):
569
 
570
  L = x.size(1)
 
591
 
592
 
593
  return y
594
+
595
def upsample_interpolate(self, x, target_len: int = 512):
    """Linearly upsample ``x`` and return the permuted, truncated result.

    The scale factor is ``target_len / x.shape[1]``. ``interpolate`` in
    ``'linear'`` mode applies that factor to the last axis of a 3-D
    input. The interpolated tensor is then permuted with ``(0, 2, 1)``
    and only the final 500 entries along the new dim 1 are kept.

    Args:
        x: 3-D floating-point tensor (``'linear'`` interpolation only
            accepts 3-D input).
        target_len: Numerator of the scale factor. Defaults to 512, the
            value that was previously hard-coded.

    Returns:
        The upsampled tensor after the ``(0, 2, 1)`` permutation, with at
        most 500 entries along dim 1.
    """
    # Honour the ``target_len`` argument instead of a hard-coded 512; the
    # default keeps existing callers unchanged.
    scale_factor = target_len / x.shape[1]

    # NOTE(review): the scale is derived from dim 1, but ``interpolate``
    # resizes the *last* axis of a 3-D tensor. Confirm the intended layout
    # of ``x`` against the caller, which tests ``x_enc.shape[1] < 512``.
    stretched = interpolate(x, scale_factor=scale_factor, mode='linear')

    # NOTE(review): the result is truncated to 500 steps although the
    # caller compares lengths against 512 -- confirm this is intended.
    upsample = stretched.permute(0, 2, 1)[:, -500:, :]

    # The result must be returned: the caller rebinds its input to this
    # value and would otherwise continue with ``None``.
    return upsample
599
+
600
 
601
 
602
 
 
618
 
619
  if x_enc.shape[1] < 512:
620
  #x_enc = self.revin_layer(x_enc, 'norm')
621
+ x_enc = self.upsample_interpolate(x_enc)
622
  pass
623
 
624
 
 
640
  def _reorder_cache(self, past, beam_idx, **kwargs):
641
  return past # backbone keeps no KV cache
642
 
643
+