SuperLinear / modeling_super_linear.py
razmars's picture
Update modeling_super_linear.py
399c990 verified
raw
history blame
26 kB
from typing import Optional, Tuple
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from .configuration_super_linear import SuperLinearConfig
import numpy as np
import matplotlib.pyplot as plt
import os
import numpy as np
"-------------------------------------------------------------------------------------------------------------------"
class RevIN(nn.Module):
    """Reversible instance normalization for time-series tensors.

    ``forward(x, 'norm')`` computes statistics of ``x`` along its time axis,
    stores them on the module and returns the normalized tensor.
    ``forward(y, 'denorm')`` re-applies the stored statistics in reverse.
    The module is stateful: 'denorm' uses whatever the last 'norm' call stored.

    Inputs with three or more dimensions are reduced over every axis except
    the first and the last (i.e. ``(batch, length, channels)`` is reduced over
    ``length``).  One- and two-dimensional inputs are reduced over their last
    axis so that each row of a ``(batch, length)`` tensor gets its own
    statistics instead of sharing one value with the whole batch.
    """

    def __init__(self, num_features: Optional[int], eps=1e-5, affine=True, norm_type=None, subtract_last=False):
        """
        :param num_features: the number of features or channels (only used to
            size the affine parameters; may be None when ``affine`` is False)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: None, or "l1" / "l2" to additionally divide by the
            L1 / L2 norm of the raw input along the reduced axes
        :param subtract_last: if True, centre on the last time step instead of
            the mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.norm_type = norm_type
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        :raises NotImplementedError: for any other ``mode``
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the centring value, standard deviation and optional norm of ``x``."""
        if x.ndim <= 2:
            # An empty reduction tuple would pool over the batch axis as well,
            # which defeats per-instance normalization; reduce over time only.
            dim2reduce = (x.ndim - 1,)
        else:
            dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            # Slice (rather than index) so the time axis is kept for any rank.
            self.last = x[:, -1:]
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
        if self.norm_type == "l1":
            self.denom = torch.sum(torch.abs(x), dim=dim2reduce, keepdim=True).detach()
        elif self.norm_type == "l2":
            self.denom = torch.sqrt(torch.sum(x**2, dim=dim2reduce, keepdim=True)).detach()

    def _normalize(self, x):
        """Centre, scale and (optionally) affine-transform ``x`` with the stored statistics."""
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.norm_type in ["l1", "l2"]:
            # NOTE(review): denom is not guarded against zero (all-zero input).
            x = x / self.denom
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Undo `_normalize` step by step, in reverse order."""
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if a learned weight reaches zero
            x = x / (self.affine_weight + self.eps*self.eps)
        if self.norm_type in ["l1", "l2"]:
            x = x * self.denom
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
"-------------------------------------------------------------------------------------------------------------------"
class moving_avg(nn.Module):
    """
    Moving-average smoother used to expose the trend of a time series.

    The series is extended on both sides by repeating its first and last
    value ``(kernel_size - 1) // 2`` times before average pooling.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x: [Batch, Input length]
        half = (self.kernel_size - 1) // 2
        # replicate the edge samples so the pooled window never runs off the series
        head = x[:, :1].repeat(1, half)
        tail = x[:, -1:].repeat(1, half)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools the last axis of a (N, C, L) tensor: add, then drop, a channel axis
        pooled = self.avg(padded.unsqueeze(1))
        return pooled.squeeze(1)
class series_decomp(nn.Module):
    """
    Split a series into a moving-average component and the remainder.

    ``forward`` returns ``(x - moving_mean, moving_mean)``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        # remainder first, smoothed component second
        return x - trend, trend
class DLinear(nn.Module):
    """
    Decomposition-linear layer: one linear map for the remainder ("seasonal")
    part and one for the moving-average ("trend") part, summed at the output.
    """

    def __init__(self, input_len, output_len, kernel_size = 25):
        super().__init__()
        self.seasonal = nn.Linear(input_len, output_len)
        self.trend = nn.Linear(input_len, output_len)
        # NOTE(review): this smoother is not used by forward(); the decomposition
        # block below owns its own instance.
        self.moving_avg = moving_avg(kernel_size, stride=1)
        self.decompsition = series_decomp(kernel_size)

    def forward(self, x):
        # x: presumably [Batch*Channel, Input length] (the linear maps act on the last axis) - TODO confirm
        residual_part, trend_part = self.decompsition(x)
        seasonal_pred = self.seasonal(residual_part)
        trend_pred = self.trend(trend_part)
        return seasonal_pred + trend_pred
class Linear(nn.Module):
    """Single linear map from ``input_len`` to ``output_len`` along the last axis."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # x: [Batch*Channel, Input length]
        # The input is copied before the projection and the result is copied
        # again, so neither tensor aliases the caller's storage.
        projected = self.Linear(x.clone())
        return projected.clone()
class Naive(nn.Module):
    """Parameter-free forecaster that repeats the last observed value."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.output_len = output_len

    def forward(self, x):
        # x: [Batch*Channel, Input length]
        last_value = x[:, -1:]
        # tile the final sample across the whole forecast horizon
        return last_value.repeat(1, self.output_len)
class Mean(nn.Module):
    """Parameter-free forecaster that repeats the mean of the input window."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.output_len = output_len

    def forward(self, x):
        # x: [Batch*Channel, Input length]
        window_mean = x.mean(dim=1, keepdim=True)
        # tile the window average across the whole forecast horizon
        return window_mean.repeat(1, self.output_len)
class NLinear(nn.Module):
    """Linear map applied after subtracting the last observed value, which is added back afterwards."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # x: indexed along axis 1; the linear map acts on the last axis
        # (presumably [Batch*Channel, Input length] as for the sibling experts - TODO confirm)
        anchor = x[:, -1:].detach()  # no gradient flows through the offset
        shifted = x - anchor
        prediction = self.Linear(shifted)
        return prediction + anchor
class RLinear(nn.Module):
    """Linear forecaster wrapped in reversible instance normalization.

    For inputs at least ``input_len`` long the series is normalized, projected
    with ``self.Linear`` and de-normalized.  For shorter inputs a "zero-shot"
    weight matrix is derived from the trained one by `transform_model`.
    """

    def __init__(self, input_len, output_len):
        super(RLinear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)
        self.seq_len = input_len
        self.horizon = output_len
        self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
        # weight derived by transform_model() for shorter look-back windows
        self.zero_shot_Linear = None

    def transform_model(self, new_lookback, mode):
        """Derive ``self.zero_shot_Linear`` from the trained weight matrix.

        :param new_lookback: length of the shorter look-back window
        :param mode:
            * ``1`` - keep the last ``new_lookback`` columns and rescale them so
              the matrix keeps its original L2 norm;
              result shape ``(out_features, new_lookback)``
            * ``2`` - bilinearly resize the whole matrix to
              ``(self.horizon, new_lookback)``
            * anything else - concatenate the last ``new_lookback`` columns with
              the same columns bilinearly stretched to ``self.seq_len``;
              result shape ``(out_features, new_lookback + self.seq_len)``
        """
        W = self.Linear.weight.detach()  # (out_features, seq_len)
        if mode == 1:
            new_W = W[:, -new_lookback:]
            original_norm = torch.norm(W, p=2)
            new_norm = torch.norm(new_W, p=2)
            # avoid dividing by zero when the kept columns are all zero
            final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
            new_W = new_W * final_scaling
        elif mode == 2:
            # F.interpolate works on (N, C, H, W); add and drop two singleton dims
            new_W = F.interpolate(
                W.unsqueeze(0).unsqueeze(0),
                size=(self.horizon, new_lookback),
                mode='bilinear',
                align_corners=False
            )[0, 0]
        else:
            out_features = W.shape[0]  # was an undefined name (NameError)
            # 1) keep the last `new_lookback` columns
            W_tail = W[:, -new_lookback:]
            # 2) stretch those columns back to `seq_len`, rows unchanged
            stretched = F.interpolate(
                W_tail.unsqueeze(0).unsqueeze(0),
                size=(out_features, self.seq_len),
                mode='bilinear',
                align_corners=False
            )[0, 0]
            # 3) concatenate on the column axis
            new_W = torch.cat((W_tail, stretched), dim=1)
        self.zero_shot_Linear = new_W

    def forward(self, x):
        # x: [Batch, Input length, Channel] or [Batch, Input length]
        x_shape = x.shape
        if x.shape[1] < self.seq_len:
            # NOTE(review): mode 3 yields a weight with
            # `new_lookback + seq_len` input columns while `x` only has
            # `new_lookback` entries on its last axis, so F.linear below
            # cannot accept it as written.  Modes 1 and 2 produce a matching
            # width - confirm which behaviour is intended.
            self.transform_model(x.shape[1], 3)
            x = x.clone()
            x = self.revin_layer(x, 'norm')
            x = F.linear(x, self.zero_shot_Linear)
            x = self.revin_layer(x, 'denorm')
            return x
        if len(x_shape) == 2:
            # give 2-D input a trailing channel axis for the 3-D code path
            x = x.unsqueeze(-1)
        x = x.clone()
        x = self.revin_layer(x, 'norm')
        # nn.Linear acts on the last axis, so move the time axis there and back
        x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
        x = self.revin_layer(x, 'denorm')
        if len(x_shape) == 2:
            x = x.squeeze(-1)
        return x
"-------------------------------------------------------------------------------------------------------------------"
class SparseNoisyMoE(nn.Module):
    """Top-k mixture of experts with a (noisy, during training) linear gate.

    The gate sees either the raw input or its normalized periodogram
    (``configs.use_fft``).  Every expert is evaluated on the full batch; the
    outputs of the ``k`` highest-scoring experts are combined with softmax
    weights computed over those ``k`` scores.

    Attributes written by ``forward`` for later inspection: ``gate_outputs``,
    ``topk_values``, ``topk_gates`` and (via the auxiliary loss) ``D``.
    """

    def __init__(self, configs, experts=None):
        """
        :param configs: object exposing ``seq_len``, ``top_k_experts``,
            ``noisy_gating_std``, ``noisy_gating_std_decay``, ``ker_len``,
            ``con``, ``d_model``, ``mlp_gating``, ``moe_temp``, ``use_fft``
            and ``fft_len``
        :param experts: iterable of expert modules
        """
        super(SparseNoisyMoE, self).__init__()
        input_dim = configs.seq_len
        self.k = configs.top_k_experts
        self.noise_std = configs.noisy_gating_std
        self.noise_std_decay = configs.noisy_gating_std_decay
        self.experts = nn.ModuleList(experts)
        # count the registered modules so any iterable (not only sized ones) works
        self.num_experts = len(self.experts)
        self.ker_len = configs.ker_len
        self.con = configs.con
        self.d_model = configs.d_model
        self.mlp_gating = configs.mlp_gating
        self.moe_temp = configs.moe_temp
        self.use_fft = configs.use_fft
        self.fft_len = configs.fft_len
        if self.use_fft:
            # the periodogram keeps fft_len // 2 frequency bins
            if self.mlp_gating:
                self.gating_network = nn.Sequential(
                    nn.Linear(self.fft_len//2, self.d_model),
                    nn.ReLU(),
                    nn.Linear(self.d_model, self.num_experts)
                )
            else:
                self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)

    def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
        """Return the normalized periodogram of ``inputs`` along axis 1.

        :param inputs: ``(batch, length)`` or ``(batch, length, 1)`` tensor
            (the smoothing kernel has a single input channel)
        :param ker_len: length of the box filter used for detrending;
            ``None`` selects ``min(n // 4, 50)``
        :param con: if truthy, subtract a box-filtered copy of the series
            before the transform
        :param n: FFT length; the first ``n // 2`` bins are kept
        :return: non-negative spectrum that sums to 1 over the frequency axis
            (all zeros for a constant series), same rank as ``inputs``
        """
        x_0 = inputs.unsqueeze(2) if inputs.dim() == 2 else inputs
        x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)
        if con:
            if ker_len is None:
                # NOTE(review): the cap of 50 is applied to the derived default only - confirm
                ker_len = n // 4
                ker_len = min(ker_len, 50)
            x_0 = x_0.permute(0, 2, 1)
            # Build the kernel in the input's dtype/device; a fixed float32
            # kernel makes conv1d fail on double or half inputs.
            ker = torch.ones(1, 1, ker_len, dtype=x_0.dtype, device=x_0.device) / ker_len
            x_c = F.conv1d(x_0, ker, padding="same")
            # overwrite the zero-padded borders with the nearest fully covered value
            x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
            x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
            x_0 = x_0 - x_c
            x_0 = x_0.permute(0, 2, 1)
        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        dft = dft[:, :n//2, :]
        I = torch.abs(dft) ** 2
        I_sum = torch.sum(I, dim=1, keepdim=True)
        # a constant series has zero power everywhere; divide by 1 instead of 0
        I_sum[I_sum == 0] = 1
        I = I / I_sum
        if inputs.dim() == 2:
            I = I.squeeze(2)
        return I

    def forward(self, x, get_prob=False):
        """Route ``x`` through the top-k experts.

        :param x: ``(batch, seq_len)`` tensor
        :param get_prob: if True, also return the softmax over all gate logits
        :return: ``(output, load_balancing_loss)`` or
            ``(output, load_balancing_loss, expert_probs)``
        """
        if self.use_fft:
            x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
        else:
            x_0 = x
        self.gate_outputs = self.gating_network(x_0)
        if self.training:
            # exploration noise is only drawn while training, so inference
            # neither pays for it nor advances the random generator
            noise = torch.randn_like(self.gate_outputs) * self.noise_std
            selection_logits = self.gate_outputs + noise
        else:
            self.gate_outputs = self.gate_outputs / self.moe_temp
            selection_logits = self.gate_outputs
        self.topk_values, topk_indices = torch.topk(selection_logits, self.k, dim=1)
        self.topk_gates = F.softmax(self.topk_values, dim=1)
        batch_size = x.size(0)
        # (batch, num_experts, out_len): every expert runs on the whole batch
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)
        output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
        load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
        if get_prob:
            expert_probs = F.softmax(self.gate_outputs, dim=1)
            return output, load_balancing_loss, expert_probs
        return output, load_balancing_loss

    def calculate_load_balancing_loss(self, gate_outputs, batch_size):
        """Return ``num_experts * sum(D * P)``.

        ``D`` is the fraction of samples whose arg-max gate is each expert
        (stored on ``self.D``); ``P`` is the batch mean of the gate softmax.
        """
        gate_probs = F.softmax(gate_outputs, dim=1)
        assignments = torch.argmax(gate_outputs, dim=1)
        # one-hot counts replace a Python loop over the experts
        counts = F.one_hot(assignments, self.num_experts).sum(dim=0)
        self.D = counts.float() / batch_size
        P = torch.mean(gate_probs, dim=0)
        load_balancing_loss = torch.sum(self.D * P) * self.num_experts
        return load_balancing_loss
class superLinear(nn.Module):
    """Mixture of linear forecasting experts behind a `SparseNoisyMoE` gate.

    Experts are either named by ``configs.freq_experts`` (an underscore
    separated list; "naive"/"mean" select the parameter-free experts, any
    other entry creates a ``configs.layer_type`` expert) or, when that string
    is empty, ``configs.moe_n_experts`` anonymous experts are created.
    """

    def __init__(self, configs):
        super(superLinear, self).__init__()
        self.configs = configs
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.inf_pred_len = configs.inf_pred_len
        self.max_horizon = configs.max_horizon
        self.auto_regressive = configs.auto_regressive
        self.n_experts = configs.moe_n_experts
        self.moe = configs.moe
        if configs.freq_experts == "":
            self.freq_experts = None
        else:
            self.freq_experts = configs.freq_experts.split('_')
        self.moe_loss = None
        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts
        self.layer_type = configs.layer_type
        self.model_name = "SuperLinear"
        self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
        path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir + "/"
        # Only touch the file system when checkpoints are actually going to
        # be loaded; the listing is not needed otherwise.
        if configs.load_linear:
            checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in os.listdir(path)]
        else:
            checkpoints_paths = []
        # NOTE(review): a former `freq_experts == "all"` branch was removed.
        # It compared the already-split list with a string, so it could never
        # run, and its body produced no experts.
        self.experts = {}
        if self.freq_experts is not None:
            for expert_freq in self.freq_experts:
                if expert_freq in ("naive", "Naive"):
                    self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
                elif expert_freq in ("mean", "Mean"):
                    self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
                else:
                    self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
                    if configs.load_linear:
                        cycle = self.map_to_cycle(expert_freq)
                        cycle_str = f'cycle_{cycle}/'
                        cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
                        if len(cycle_checkpoint_path) > 0:
                            print()
                            print(cycle_str)
                            cycle_checkpoint_path = cycle_checkpoint_path[0]
                            print(cycle_checkpoint_path)
                            # NOTE(review): torch.load unpickles the file; only
                            # point this at trusted checkpoints.
                            self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
                        else:
                            print(f"Checkpoint for {cycle_str} not found in {path}")
                            raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
                if configs.freeze_experts:
                    for param in self.experts[expert_freq].parameters():
                        param.requires_grad = False
            self.n_experts = len(self.experts)
        else:
            for i in range(self.n_experts):
                print(f"creating expert {i}")
                self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
        self.manual_moe = configs.manual_moe
        if configs.misc_moe == 1:
            self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
        # the gate registers the experts as sub-modules; this replaces the
        # plain `configs.moe` value stored above
        self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
        self.dropout = nn.Dropout(configs.dropout)

    def map_to_cycle(self, freq):
        """Translate a frequency label into its cycle length.

        ``"a/b"`` maps to ``int(b)``.  Otherwise the first token of the table
        below that occurs in ``freq`` decides.  More specific tokens are listed
        before the generic ones they contain ("2h" before "h", "DM" before
        "D", "15min" before "5min" before "min") so each label reaches its own
        entry.  Anything else is parsed as an integer.

        :raises ValueError: if ``freq`` matches nothing and is not an integer
        """
        if "/" in freq:
            return int(freq.split("/")[1])
        token_cycles = (
            ("2h", 12),
            ("3h", 8),
            ("4h", 6),
            ("h", 24),
            ("DM", 30),
            ("D", 7),
            ("W", 52),
            ("M", 12),
            ("15min", 96),
            ("10min", 144),
            ("30min", 48),
            ("5min", 288),
            ("min", 1440),
        )
        for token, cycle in token_cycles:
            if token in freq:
                return cycle
        return int(freq)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
        """Forecast from ``x_enc``.

        :param x_enc: ``(batch, length, variates)`` or ``(batch, length)``;
            every variate is forecast independently
        :param get_prob: if True, also return the gate probabilities reshaped
            to ``(batch, variates, n_experts)``
        :return: forecast with the same layout as ``x_enc`` (time on axis 1),
            optionally followed by the expert probabilities.  The remaining
            parameters are accepted but not used here.
        """
        if len(x_enc.shape) > 2:
            x = x_enc.permute(0, 2, 1)
            B, V, L = x.shape
        else:
            x = x_enc
            B, L = x.shape
            V = 1
        # fold the variates into the batch axis: one univariate series per row
        x = x.reshape(B * V, L)
        expert_probs = None
        if get_prob:
            out, self.moe_loss, expert_probs = self.moe(x, get_prob=True)
        else:
            out, self.moe_loss = self.moe(x)
        if self.auto_regressive and self.max_horizon < self.inf_pred_len:
            # roll the model forward on its own predictions until enough
            # steps are available, then cut to the requested horizon
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
            for _ in range(0, self.inf_pred_len, self.max_horizon):
                ar_out, _ = self.moe(ar_x)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
            out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
        if len(x_enc.shape) > 2:
            out = out.reshape(B, V, out.shape[-1])
            result = out.permute(0, 2, 1)
        else:
            result = out
        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            return result, expert_probs
        return result
"-------------------------------------------------------------------------------------------------------------------"
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face wrapper that exposes the `superLinear` backbone through
    the causal-LM interface: the series goes in as ``inputs_embeds`` and the
    forecast comes back in the ``logits`` field of the output."""

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)
        # the backbone reads plain attributes, so turn the config dict into
        # an ad-hoc object whose attributes are the config entries
        backbone_cfg = type("Cfg", (), config.to_dict())()
        self.args = backbone_cfg
        self.backbone = superLinear(backbone_cfg)
        self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
        self.post_init()

    def upsample_dim1(self, x, target_len: int = 512, mode: str = "bicubic"):
        """Resample ``x`` to ``target_len`` samples with ``F.interpolate``.

        NOTE(review): not called anywhere in this file.  The 2-D branch
        decides the layout by comparing the first dimension with the literal
        48, and the default ``mode`` is handed to ``F.interpolate`` together
        with a 3-D tensor - verify against the PyTorch documentation that the
        chosen mode accepts 3-D input before using this helper.
        """
        orig_shape = x.shape
        ndim = x.ndim
        # Reshape to (N, C, L) where L is the axis we want to scale
        if ndim == 1:  # (L,)
            x_ = x.unsqueeze(0).unsqueeze(0)  # (1,1,L)
            unstack = lambda t: t.squeeze(0).squeeze(0)
        elif ndim == 2:  # (L,C) or (C,L)
            if orig_shape[0] == 48:  # assume (L,C)
                x_ = x.permute(1, 0).unsqueeze(0)  # (1,C,L)
                unstack = lambda t: t.squeeze(0).permute(1, 0)
            else:  # assume (C,L)
                x_ = x.unsqueeze(0)  # (1,C,L)
                unstack = lambda t: t.squeeze(0)
        else:  # >=3 dims, assume (B,L,C, ...) with L at dim-1
            x_ = x.transpose(1, 2)  # (B,C,L,...)
            new_order = list(range(ndim))
            new_order[1], new_order[2] = 2, 1  # swap back later
            unstack = lambda t: t.permute(*new_order)
        # actual interpolation in the length dimension
        y = F.interpolate(x_, size=target_len, mode=mode, align_corners=False)
        # restore the original dimension ordering
        return unstack(y)

    def fourier_interp_dim1(self, x, target_len: int = 512):
        """Band-limited resampling of ``x`` along axis 1 to ``target_len`` samples.

        The spectrum is zero-padded (or trimmed) by ``irfft`` itself, which
        operates on axis 1 for inputs of any rank.  The result is rescaled by
        ``target_len / L`` because ``rfft`` sums over ``L`` samples while
        ``irfft`` divides by ``target_len``; without it the amplitude of the
        output would not match the input.
        """
        L = x.size(1)
        X = torch.fft.rfft(x, dim=1)
        y = torch.fft.irfft(X, n=target_len, dim=1)
        return y * (target_len / L)

    def fourier_downsample_dim1(self, x, target_len: int):
        """Low-pass ``x`` along axis 1 and resample it to ``target_len`` samples."""
        L = x.size(1)
        # 1. Forward real FFT along dim-1
        X = torch.fft.rfft(x, dim=1)
        # 2. Keep only the low-frequency bins needed for the shorter series.
        #    The crop must act on axis 1 (the frequency axis), not the last one.
        keep = target_len // 2 + 1
        X_crop = X[:, :keep]
        # 3. Inverse FFT to the shorter grid
        y = torch.fft.irfft(X_crop, n=target_len, dim=1)
        # 4. Renormalise amplitudes: irfft divides by `target_len`, whereas the
        #    forward rfft summed over `L` samples.
        return y * (target_len / L)

    def forward(self,
                inputs_embeds: torch.Tensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple] = None,
                use_cache: bool = True,
                labels: Optional[torch.Tensor] = None,
                **kwargs,) -> CausalLMOutputWithCrossAttentions:
        """Forecast from the series passed as ``inputs_embeds``.

        :param inputs_embeds: the input series with time on axis 1
            (presumably ``(batch, length, channels)`` - TODO confirm)
        :return: output object whose ``logits`` field holds the forecast; all
            other fields are None.  The remaining arguments are accepted for
            interface compatibility and ignored.
        :raises ValueError: if ``inputs_embeds`` is None
        """
        if inputs_embeds is None:
            raise ValueError("Pass the time‑series as `inputs_embeds`")
        x_enc = inputs_embeds
        # Inputs shorter than 512 steps are instance-normalized before the
        # backbone.  Remember the decision so the forecast is only mapped
        # back when statistics were actually computed for this call;
        # de-normalizing unconditionally would use missing or stale statistics.
        normalized = x_enc.shape[1] < 512
        if normalized:
            x_enc = self.revin_layer(x_enc, 'norm')
        preds = self.backbone(x_enc)
        if normalized:
            preds = self.revin_layer(preds, 'denorm')
        return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)

    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
        """Build the keyword arguments for the next `forward` call during generation."""
        if past_key_values is not None:
            # only feed the last new step
            inputs_embeds = inputs_embeds[:, -1:, :]
        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}

    def _reorder_cache(self, past, beam_idx, **kwargs):
        return past  # backbone keeps no KV cache