Update modeling_super_linear.py

modeling_super_linear.py  (+9 -33)
@@ -4,6 +4,7 @@ import torch, torch.nn as nn, torch.nn.functional as F
 from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from .configuration_super_linear import SuperLinearConfig
+from torch.nn.functional import interpolate

 import datetime

@@ -562,39 +563,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
         self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
         self.post_init()

-    # ------------------------------------------------------------------
-    # Forward pass expected by AutoModelForCausalLM
-    # ------------------------------------------------------------------
-
-    def upsample_dim1(self, x, target_len: int = 512, mode: str = "bicubic"):
-        # -------- bring the dim-1 axis to the PyTorch 1-D "length" position --------
-        orig_shape = x.shape
-        ndim = x.ndim
-
-        # Reshape to (N, C, L) where L is the axis we want to scale
-        if ndim == 1:                                   # (L,)
-            x_ = x.unsqueeze(0).unsqueeze(0)            # (1,1,L)
-            unstack = lambda t: t.squeeze(0).squeeze(0)
-        elif ndim == 2:                                 # (L,C) or (C,L)
-            if orig_shape[0] == 48:                     # assume (L,C)
-                x_ = x.permute(1, 0).unsqueeze(0)       # (1,C,L)
-                unstack = lambda t: t.squeeze(0).permute(1, 0)
-            else:                                       # assume (C,L)
-                x_ = x.unsqueeze(0)                     # (1,C,L)
-                unstack = lambda t: t.squeeze(0)
-        else:                                           # ≥3 dims, assume (B,L,C, …) with L at dim-1
-            x_ = x.transpose(1, 2)                      # (B,C,L,...)
-            new_order = list(range(ndim))
-            new_order[1], new_order[2] = 2, 1           # swap back later
-            unstack = lambda t: t.permute(*new_order)
-
-        # ------------------ actual interpolation in length dimension --------------
-        y = F.interpolate(x_, size=target_len, mode=mode, align_corners=False)
-
-        # ------------------ restore original dimension ordering -------------------
-        return unstack(y)
-
     def fourier_interp_dim1(self,x, target_len: int = 512):

         L = x.size(1)
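One plausible reason the old helper was dropped: `F.interpolate` only accepts `mode="bicubic"` for 4-D (N, C, H, W) inputs, while `upsample_dim1` funnels its inputs into a 3-D (N, C, L) layout, so its default mode would raise at call time. A minimal check (the shapes here are illustrative, not taken from the repo):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 7, 48)   # (N, C, L): the 3-D layout upsample_dim1 builds

    # 1-D modes work on 3-D tensors:
    y = F.interpolate(x, size=512, mode="linear", align_corners=False)
    print(y.shape)              # torch.Size([1, 7, 512])

    # mode="bicubic" requires 4-D (N, C, H, W) input, so this would raise:
    # F.interpolate(x, size=512, mode="bicubic", align_corners=False)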
@@ -621,7 +591,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):


         return y
-
+
+    def upsample_interpolate(self, x, target_len: int = 512):
+        scale_factor = 512/x.shape[1]
+        upsample = interpolate(x, scale_factor=scale_factor, mode='linear').permute(0,2,1)[:, -500:, :]
+        print(upsample.shape)
+


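As committed, `upsample_interpolate` prints the upsampled shape but never returns it, and `target_len` goes unused in favor of the hard-coded 512 and the trailing `[:, -500:, :]` slice. A minimal sketch of what the intended version might look like, assuming `x` arrives as (B, L, C) with the length axis at dim 1 (consistent with the `x_enc.shape[1] < 512` check in the next hunk):

    import torch
    from torch.nn.functional import interpolate

    def upsample_interpolate(x: torch.Tensor, target_len: int = 512) -> torch.Tensor:
        # (B, L, C) -> (B, C, L): interpolate scales the last axis of a 3-D tensor
        x_cl = x.permute(0, 2, 1)
        up = interpolate(x_cl, size=target_len, mode="linear", align_corners=False)
        # (B, C, target_len) -> (B, target_len, C), restoring the original layout
        return up.permute(0, 2, 1)

With a return value in place, the `x_enc = self.upsample_interpolate(x_enc)` assignment below yields a (B, 512, C) tensor rather than `None`; whether the committed `-500:` truncation is intentional is not clear from the diff, so this sketch omits it.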
@@ -643,7 +618,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):

         if x_enc.shape[1] < 512:
             #x_enc = self.revin_layer(x_enc, 'norm')
-
+            x_enc = self.upsample_interpolate(x_enc)
             pass

@@ -665,3 +640,4 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
     def _reorder_cache(self, past, beam_idx, **kwargs):
         return past  # backbone keeps no KV cache
+