# SuperLinear / modeling_super_linear.py
# Latest commit bd04f66 (verified) by lirannoc: "Update modeling_super_linear.py" (file size 19.6 kB).
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from .configuration_super_linear import SuperLinearConfig
"-------------------------------------------------------------------------------------------------------------------"
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    ``forward(x, 'norm')`` records per-instance statistics of ``x`` and returns the
    normalized tensor; ``forward(y, 'denorm')`` applies the inverse transform using the
    statistics recorded by the most recent ``'norm'`` call (they are stored on ``self``).
    Statistics are reduced over every axis except the first and the last, so a
    ``[batch, length, channels]`` input is normalized along ``length``.
    """

    def __init__(self, num_features: Optional[int], eps=1e-5, affine=True, norm_type=None, subtract_last=False):
        """
        :param num_features: the number of features or channels; only used to size the
            affine parameters, so it may be ``None`` when ``affine=False``
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: optional extra scaling, ``"l1"`` or ``"l2"``; any other value disables it
        :param subtract_last: if True, center on the last time step instead of the mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.norm_type = norm_type
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        Raises:
            NotImplementedError: if ``mode`` is neither ``'norm'`` nor ``'denorm'``.
        """
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError(f"RevIN mode must be 'norm' or 'denorm', got {mode!r}")

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the (detached) centering, scale and optional l1/l2 statistics of ``x``."""
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
        if self.norm_type == "l1":
            denom = torch.sum(torch.abs(x), dim=dim2reduce, keepdim=True)
        elif self.norm_type == "l2":
            denom = torch.sqrt(torch.sum(x ** 2, dim=dim2reduce, keepdim=True))
        else:
            return
        # Clamp so an all-zero window does not turn into 0/0 = NaN in _normalize.
        self.denom = denom.clamp_min(self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.norm_type in ["l1", "l2"]:
            x = x / self.denom
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        # Undo _normalize step by step in reverse order.
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if a learned weight reaches zero
            x = x / (self.affine_weight + self.eps * self.eps)
        if self.norm_type in ["l1", "l2"]:
            x = x * self.denom
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
"-------------------------------------------------------------------------------------------------------------------"
class Linear(nn.Module):
    """Learnable linear expert: one ``nn.Linear`` mapping ``input_len`` steps to ``output_len`` steps."""

    def __init__(self, input_len, output_len):
        super(Linear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # [Batch*Channel, input_len] -> [Batch*Channel, output_len]; input and output are copied.
        projected = self.Linear(x.clone())
        return projected.clone()
class Naive(nn.Module):
    """Persistence expert: every forecast step equals the last observed value."""

    def __init__(self, input_len, output_len):
        super(Naive, self).__init__()
        self.output_len = output_len

    def forward(self, x):
        # [Batch*Channel, input_len] -> [Batch*Channel, output_len]
        last_value = x[:, -1:]
        return last_value.repeat(1, self.output_len)
class Mean(nn.Module):
    """Mean expert: every forecast step equals the mean of the input window."""

    def __init__(self, input_len, output_len):
        super(Mean, self).__init__()
        self.output_len = output_len

    def forward(self, x):
        # [Batch*Channel, input_len] -> [Batch*Channel, output_len]
        window_mean = x.mean(dim=1, keepdim=True)
        return window_mean.repeat(1, self.output_len)
class RLinear(nn.Module):
    """Linear expert wrapped in non-affine RevIN: normalize the window, project it, de-normalize."""

    def __init__(self, input_len, output_len):
        super(RLinear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)
        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)

    def forward(self, x):
        # Accepts [Batch, input_len, Channel] or 2-D [Batch, input_len]; 2-D input gets a temporary channel axis.
        was_2d = len(x.shape) == 2
        series = x.unsqueeze(-1) if was_2d else x
        series = self.revin_layer(series.clone(), 'norm')
        # nn.Linear acts on the last axis, so move time there and back.
        projected = self.Linear(series.permute(0, 2, 1)).permute(0, 2, 1).clone()
        projected = self.revin_layer(projected, 'denorm')
        return projected.squeeze(-1) if was_2d else projected
"-------------------------------------------------------------------------------------------------------------------"
class SparseMoE(nn.Module):
    """
    Sparse Mixture of Experts (MoE) module that routes inputs to the most relevant experts.

    A linear gating network scores every expert, either from the raw input window or
    (``configs.use_fft``) from its normalized periodogram. The top-k scores are
    softmax-normalized and used to weight the corresponding expert outputs. Gaussian
    noise (std ``configs.noisy_gating_std``) is added to the scores only in training mode;
    in eval mode the scores are divided by ``configs.moe_temp`` instead.
    All experts are evaluated, and only the top-k outputs are combined.

    Args:
        configs: Configuration object containing MoE parameters
        experts: Iterable of expert modules; each maps ``[N, L]`` to ``[N, P]``
    """

    def __init__(self, configs, experts=None):
        super(SparseMoE, self).__init__()
        self.noise_std = configs.noisy_gating_std
        # Store experts in a ModuleList so their parameters are registered with this module.
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(self.experts)
        # topk cannot select more experts than exist.
        self.k = min(configs.top_k_experts, self.num_experts)
        self.moe_temp = configs.moe_temp
        self.use_fft = configs.use_fft
        self.fft_len = configs.fft_len
        self.moe_norm = configs.moe_norm
        # The gate sees either the positive-frequency half of the periodogram or the raw window.
        if self.use_fft:
            self.gating_network = nn.Linear(self.fft_len // 2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)
        if self.moe_norm:
            self.gate_norm = nn.BatchNorm1d(self.num_experts)

    def get_periodogram(self, inputs, n=10000):
        """
        Calculate the normalized periodogram (power spectral density) of input time series.

        Args:
            inputs: Tensor of shape [batch_size, sequence_length] or
                [batch_size, sequence_length, features]; time runs along dim 1
            n: Number of points in the FFT (the input is zero-padded or truncated to n)
        Returns:
            The periodogram over the first ``n // 2`` frequency bins, normalized to sum to 1
            along the frequency axis. A constant series (zero power) yields all zeros.
            The shape matches ``inputs`` with the time axis replaced by ``n // 2``.
        """
        x_0 = inputs.unsqueeze(2) if inputs.dim() == 2 else inputs
        x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)  # Remove mean (DC component)
        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        dft = dft[:, :n // 2, :]  # Keep only positive frequencies
        power = torch.abs(dft) ** 2
        power_sum = torch.sum(power, dim=1, keepdim=True)
        # Divide by 1 instead of 0 for zero-power series (out-of-place, autograd-safe).
        power_sum = torch.where(power_sum == 0, torch.ones_like(power_sum), power_sum)
        power = power / power_sum
        if inputs.dim() == 2:
            power = power.squeeze(2)
        return power

    def forward(self, x, get_prob=False):
        """
        Forward pass through the Mixture of Experts.

        Side effects: stores ``gate_outputs``, ``topk_values`` and ``topk_gates`` on ``self``.

        Args:
            x: Input tensor of shape [batch_size, sequence_length]
            get_prob: Whether to return expert selection probabilities
        Returns:
            - Output tensor from the selected experts
            - (Optional) Softmax over all gate scores if get_prob is True
        """
        x_0 = self.get_periodogram(x, n=self.fft_len) if self.use_fft else x
        # Raw gating scores
        self.gate_outputs = self.gating_network(x_0)
        if self.moe_norm:
            self.gate_outputs = self.gate_norm(self.gate_outputs)

        if self.training:
            # Exploration noise is sampled only when it is used, so inference never
            # advances the global RNG.
            noise = torch.randn_like(self.gate_outputs) * self.noise_std
            self.topk_values, topk_indices = torch.topk(self.gate_outputs + noise, self.k, dim=1)
        else:
            # Temperature scaling applies during inference only.
            self.gate_outputs = self.gate_outputs / self.moe_temp
            self.topk_values, topk_indices = torch.topk(self.gate_outputs, self.k, dim=1)

        # Normalize the selected gate values with softmax
        self.topk_gates = F.softmax(self.topk_values, dim=1)

        # [N, num_experts, P]: every expert runs, then the top-k outputs are gathered.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, gather_index)
        output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

        if get_prob:
            return output, F.softmax(self.gate_outputs, dim=1)
        return output
class Model(nn.Module):
    """
    SuperLinear forecaster: a sparse Mixture of Experts over simple per-series experts.

    Each channel is forecast independently. The input is reshaped to
    ``[batch * channels, length]`` and routed through a ``SparseMoE``. Its experts come
    from the underscore-separated ``configs.freq_experts`` list plus ``configs.comp_moe``
    complementary learnable experts. Horizons longer than ``train_pred_len`` are rolled
    out autoregressively. When ``configs.resample_long_lookback`` is set, lookbacks
    shorter than ``train_seq_len`` are linearly upsampled to ``train_seq_len``, and the
    forecast is mapped back to the original time scale afterwards.

    Args:
        configs: Configuration object containing model parameters
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.model_name = "SuperLinear"
        self.train_pred_len = configs.train_pred_len
        self.train_seq_len = configs.train_seq_len
        self.resample_long_lookback = configs.resample_long_lookback
        self.layer_type = configs.layer_type
        # Parse frequency experts from configuration (underscore-separated names)
        if configs.freq_experts == "":
            self.freq_experts = None
        else:
            self.freq_experts = configs.freq_experts.split('_')
        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts

        # name -> expert. The modules themselves are registered through self.moe.experts.
        self.experts = {}
        if self.freq_experts is None:
            raise ValueError("No frequency experts specified in configuration.")
        for expert_freq in self.freq_experts:
            if expert_freq in ("naive", "Naive"):
                self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
            elif expert_freq in ("mean", "Mean"):
                self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
            else:
                self.experts[expert_freq] = self._make_learnable_expert()
        # Create additional complementary experts if specified
        if configs.comp_moe > 0:
            for i in range(configs.comp_moe):
                self.experts[f"comp_{i}"] = self._make_learnable_expert()

        # Initialize the MoE layer
        self.moe = SparseMoE(configs, experts=self.experts.values())
        print("Experts:", self.experts.keys())

    def _make_learnable_expert(self):
        """Build one expert of ``self.layer_type`` ('Linear' or 'RLinear'; unknown types fall back to RLinear)."""
        expert_classes = {'Linear': Linear, 'RLinear': RLinear}
        expert_class = expert_classes[self.layer_type] if self.layer_type in expert_classes else RLinear
        return expert_class(self.train_seq_len, self.train_pred_len)

    def add_experts(self, experts: dict):
        """
        Add new experts to the model.

        The MoE layer is rebuilt from scratch, so the gating network is re-initialized.

        Args:
            experts: Dictionary of expert instances to add
        Returns:
            The newly built MoE layer.
        """
        for name, expert in experts.items():
            self.experts[name] = expert
        self.moe = SparseMoE(self.configs, experts=self.experts.values())
        return self.moe

    @staticmethod
    def _interpolate_last_dim(x, size):
        """Linearly resample the last axis of ``x`` (any rank) to ``size`` points."""
        leading_shape = x.shape[:-1]
        # F.interpolate(mode='linear') needs exactly 3-D [N, C, L] input.
        flat = x.reshape(-1, 1, x.size(-1))
        resampled = F.interpolate(flat, size=size, mode='linear', align_corners=False)
        return resampled.reshape(*leading_shape, size)

    def resample_seq_len(self, x, pred_len, inverse=False, orig_pred_len=None):
        """
        Resample sequence length for handling inputs shorter than expected training length.

        Resampling is applied along the last axis; ``x`` may be ``[N, L]`` or ``[B, V, L]``.

        Args:
            x: Input tensor
            pred_len: Prediction length
            inverse: If True, downsample back to original scale; if False, upsample
            orig_pred_len: Original prediction length (required for inverse=True)
        Returns:
            Tuple of (resampled_tensor, updated_pred_len, scale_factor, orig_pred_len)
            (``x, pred_len, None, None`` when no upsampling is needed).
            For inverse=True: returns (resampled_tensor, None, None, None)
        """
        if inverse:
            if orig_pred_len is None:
                return x, None, None, None
            return self._interpolate_last_dim(x, orig_pred_len), None, None, None

        if x.size(-1) >= self.train_seq_len:
            return x, pred_len, None, None
        scale_factor = self.train_seq_len / x.size(-1)
        x_resampled = self._interpolate_last_dim(x, self.train_seq_len)
        pred_len_resampled = int(pred_len * scale_factor)
        return x_resampled, pred_len_resampled, scale_factor, pred_len

    def _autoregressive_rollout(self, x, first_out, pred_len):
        """Extend ``first_out`` to ``pred_len`` steps by feeding forecasts back as input.

        Exactly as many extra MoE calls as needed are made. Any extra call would also
        update gate/normalization state in training mode.
        """
        outputs = [first_out]
        ar_x = torch.cat([x, first_out], dim=1)[:, -self.train_seq_len:]
        for _ in range(self.train_pred_len, pred_len, self.train_pred_len):
            ar_out = self.moe(ar_x)
            outputs.append(ar_out)
            ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
        return torch.cat(outputs, dim=1)[:, :pred_len]

    def forward(self, x_in, get_prob=False, pred_len=None):
        """
        Forward pass through the model.

        Args:
            x_in: Input series, [batch, length, channels] or [batch, length]
            get_prob: Whether to return expert selection probabilities
            pred_len: Override for prediction length (defaults to ``train_pred_len``).
                Longer horizons are produced autoregressively. Without resampling, a
                ``pred_len`` below ``train_pred_len`` still returns ``train_pred_len`` steps.
        Returns:
            - Prediction tensor, [batch, horizon, channels] ([batch, horizon] for 2-D input)
            - (Optional) Expert probabilities of the first MoE call, reshaped to
              [batch, channels, num_experts], if get_prob is True
        """
        if pred_len is None:
            pred_len = self.train_pred_len

        x = x_in.unsqueeze(-1) if x_in.dim() == 2 else x_in
        # [batch, length, features] -> [batch, features, length]
        x = x.permute(0, 2, 1)
        B, V, L = x.shape

        scale_factor = None
        orig_pred_len = None
        if self.resample_long_lookback and L < self.train_seq_len:
            x, pred_len, scale_factor, orig_pred_len = self.resample_seq_len(x, pred_len, inverse=False)

        # One row per (batch, channel) series for the MoE
        x = x.reshape(B * V, x.size(-1))
        if get_prob:
            out, expert_probs = self.moe(x, get_prob=True)
        else:
            out = self.moe(x)

        if self.train_pred_len < pred_len:
            out = self._autoregressive_rollout(x, out, pred_len)

        out = out.reshape(B, V, out.size(-1))

        if scale_factor is not None:
            # Keep only the requested horizon (in resampled time) before mapping back.
            # Otherwise a full train_pred_len output would be squeezed into orig_pred_len steps.
            out = out[..., :pred_len]
            out, _, _, _ = self.resample_seq_len(out, None, inverse=True, orig_pred_len=orig_pred_len)

        result = out.permute(0, 2, 1)
        if x_in.dim() == 2:
            result = result.squeeze(-1)
        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            return result, expert_probs
        return result
"-------------------------------------------------------------------------------------------------------------------"
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face ``PreTrainedModel`` wrapper around the SuperLinear ``Model`` backbone.

    The time series is passed as ``inputs_embeds`` and the forecast comes back as
    ``logits``. No loss is computed and no key/value cache is kept.
    """

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)
        # The backbone reads settings as plain attributes, so every config entry becomes a
        # class attribute of an ad-hoc "Cfg" type, which is then instantiated.
        cfg_type = type("Cfg", (), config.to_dict())
        self.args = cfg_type()
        self.backbone = Model(self.args)
        self.post_init()

    # ------------------------------------------------------------------
    # Forward pass expected by AutoModelForCausalLM
    # ------------------------------------------------------------------
    def forward(self,
                inputs_embeds: torch.Tensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple] = None,
                use_cache: bool = True,
                labels: Optional[torch.Tensor] = None,
                **kwargs,) -> CausalLMOutputWithCrossAttentions:
        """Forecast from ``inputs_embeds``.

        ``inputs_embeds`` is handed to the backbone unchanged. The backbone takes
        ``(B, L, C)`` (or ``(B, L)``) and returns the forecast with the horizon on axis 1.
        An optional ``pred_len`` keyword overrides the forecast horizon.
        ``attention_mask``, ``past_key_values``, ``use_cache`` and ``labels`` are accepted
        for API compatibility but not used.
        """
        if inputs_embeds is None:
            raise ValueError("Pass the time‑series as `inputs_embeds`")
        forecast = self.backbone(inputs_embeds, pred_len=kwargs.get("pred_len"))
        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=forecast,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
        """Build forward() kwargs for generation; with a cache present, only the newest step is fed."""
        step_inputs = inputs_embeds if past_key_values is None else inputs_embeds[:, -1:, :]
        return {"inputs_embeds": step_inputs, "past_key_values": past_key_values}

    def _reorder_cache(self, past, beam_idx, **kwargs):
        # The backbone keeps no KV cache, so there is nothing to reorder.
        return past