| from typing import Optional, Tuple |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| from .configuration_super_linear import SuperLinearConfig |
|
|
|
|
| "-------------------------------------------------------------------------------------------------------------------" |
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    In 'norm' mode, per-instance statistics are computed and the input is
    normalized; in 'denorm' mode the stored statistics invert the transform.
    """
    def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type=None, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: optional extra scaling, "l1" or "l2" (None disables it)
        :param subtract_last: if True, subtract the last time step instead of
            standardizing by mean/std
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.norm_type = norm_type
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Apply ('norm') or invert ('denorm') the normalization."""
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # Learnable per-feature scale and shift.
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over all dims except batch (first) and features (last).
        dim2reduce = tuple(range(1, x.ndim - 1))

        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
            self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

        # BUG FIX: denom was previously computed only in the non-subtract_last
        # branch, so norm_type + subtract_last crashed in _normalize with an
        # AttributeError. Compute it unconditionally.
        if self.norm_type == "l1":
            self.denom = torch.sum(torch.abs(x), dim=dim2reduce, keepdim=True).detach()
        elif self.norm_type == "l2":
            self.denom = torch.sqrt(torch.sum(x ** 2, dim=dim2reduce, keepdim=True)).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
            x = x / self.stdev

        if self.norm_type in ["l1", "l2"]:
            x = x / self.denom

        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards division when a learned weight collapses to ~0.
            x = x / (self.affine_weight + self.eps * self.eps)
        if self.norm_type in ["l1", "l2"]:
            x = x * self.denom
        if self.subtract_last:
            x = x + self.last
        else:
            # BUG FIX: the original multiplied by self.stdev unconditionally,
            # but stdev only exists (and is only applied in _normalize) when
            # subtract_last is False — subtract_last=True raised AttributeError.
            x = x * self.stdev
            x = x + self.mean

        return x
| "-------------------------------------------------------------------------------------------------------------------" |
class Linear(nn.Module):
    """Simple linear layer expert mapping a lookback window to a forecast."""

    def __init__(self, input_len, output_len):
        super(Linear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # Clone input and output so downstream in-place ops cannot alias
        # tensors shared with other experts.
        return self.Linear(x.clone()).clone()
| |
class Naive(nn.Module):
    """Naive forecasting expert - repeats last value."""

    def __init__(self, input_len, output_len):
        super(Naive, self).__init__()
        # input_len is accepted for interface parity with the other experts.
        self.output_len = output_len

    def forward(self, x):
        # Broadcast the final observation across the whole horizon.
        last = x[:, -1:]
        return last.repeat(1, self.output_len)
| |
class Mean(nn.Module):
    """Mean forecasting expert - repeats mean value."""

    def __init__(self, input_len, output_len):
        super(Mean, self).__init__()
        # input_len is accepted for interface parity with the other experts.
        self.output_len = output_len

    def forward(self, x):
        # Broadcast the lookback-window average across the whole horizon.
        avg = x.mean(dim=1, keepdim=True)
        return avg.repeat(1, self.output_len)
|
|
class RLinear(nn.Module):
    """Linear expert wrapped in Reversible Instance Normalization (RevIN)."""

    def __init__(self, input_len, output_len):
        super(RLinear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)
        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)

    def forward(self, x):
        # RevIN expects a trailing channel dimension; remember whether we
        # need to drop it again on the way out.
        squeeze_back = (len(x.shape) == 2)
        if squeeze_back:
            x = x.unsqueeze(-1)
        x = self.revin_layer(x.clone(), 'norm')
        # The linear map acts along the time axis, hence the permutes.
        x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1).clone()
        x = self.revin_layer(x, 'denorm')
        if squeeze_back:
            x = x.squeeze(-1)
        return x
|
|
| "-------------------------------------------------------------------------------------------------------------------" |
class SparseMoE(nn.Module):
    """
    Sparse Mixture of Experts (MoE) module that routes inputs to the most relevant experts.

    This implementation uses a gating network to determine which experts should process each input.
    Only the top-k experts are used for each input, creating a sparse computation pattern.

    Args:
        configs: Configuration object containing MoE parameters
        experts: Collection of expert modules (neural networks)
    """
    def __init__(self, configs, experts=None):
        super(SparseMoE, self).__init__()
        self.noise_std = configs.noisy_gating_std     # std of gating noise (training only)
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)
        # Clamp k so top-k can never exceed the number of experts.
        self.k = min(configs.top_k_experts, self.num_experts)

        self.moe_temp = configs.moe_temp              # softmax temperature used at inference
        self.use_fft = configs.use_fft                # gate on periodogram instead of raw series
        self.fft_len = configs.fft_len
        self.moe_norm = configs.moe_norm              # optional BatchNorm on gate logits

        # The gate consumes either the positive-frequency periodogram
        # (fft_len // 2 bins) or the raw lookback window.
        if self.use_fft:
            self.gating_network = nn.Linear(self.fft_len // 2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)

        if self.moe_norm:
            self.gate_norm = nn.BatchNorm1d(self.num_experts)

    def get_periodogram(self, inputs, n=10000):
        """
        Calculate the periodogram (power spectral density) of input time series.

        The periodogram is used as a frequency-domain representation of the signal
        to help the gating network identify periodic patterns.

        Args:
            inputs: Input time series tensor of shape [batch_size, sequence_length] or [batch_size, sequence_length, features]
            n: Number of points in FFT computation

        Returns:
            Normalized periodogram of the input signals
        """
        if inputs.dim() == 2:
            x_0 = inputs.unsqueeze(2)
        else:
            x_0 = inputs
        # Remove the per-series mean so the DC component does not dominate.
        x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)

        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        dft = dft[:, :n // 2, :]          # keep positive frequencies only
        I = torch.abs(dft) ** 2

        # Normalize to a distribution over frequencies. Constant series have
        # zero spectral power after mean removal, so guard against div-by-zero.
        # BUG FIX: the original also checked `torch.any(I_sum == 0)` *after*
        # zeros were replaced by 1, so that raise was dead code; removed.
        I_sum = torch.sum(I, dim=1, keepdim=True)
        I_sum[I_sum == 0] = 1
        I = I / I_sum

        if inputs.dim() == 2:
            I = I.squeeze(2)

        return I

    def forward(self, x, get_prob=False):
        """
        Forward pass through the Mixture of Experts.

        Args:
            x: Input tensor of shape [batch_size, sequence_length]
            get_prob: Whether to return expert selection probabilities

        Returns:
            - Output tensor from the selected experts
            - (Optional) Expert selection probabilities if get_prob is True
        """
        # Gating features: periodogram or raw window.
        if self.use_fft:
            x_0 = self.get_periodogram(x, n=self.fft_len)
        else:
            x_0 = x

        self.gate_outputs = self.gating_network(x_0)

        if self.moe_norm:
            self.gate_outputs = self.gate_norm(self.gate_outputs)

        # Temperature scaling is applied only at inference time.
        if not self.training:
            self.gate_outputs = self.gate_outputs / self.moe_temp

        # PERF FIX: noise was previously sampled unconditionally even though
        # it is only used during training.
        if self.training:
            noise = torch.randn_like(self.gate_outputs) * self.noise_std
            self.topk_values, topk_indices = torch.topk(self.gate_outputs + noise, self.k, dim=1)
        else:
            self.topk_values, topk_indices = torch.topk(self.gate_outputs, self.k, dim=1)

        # Renormalize the selected logits into mixture weights.
        self.topk_gates = F.softmax(self.topk_values, dim=1)

        # Run every expert, then gather only the top-k outputs per sample.
        # (All experts execute; sparsity applies to the combination, not compute.)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)

        # Weighted sum of the selected experts' forecasts.
        output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

        if get_prob:
            expert_probs = F.softmax(self.gate_outputs, dim=1)
            return output, expert_probs

        return output
|
|
|
|
class Model(nn.Module):
    """
    Main model class that employs a Mixture of Experts for time series forecasting.

    This model can work with various types of linear layers as experts and supports
    both standard prediction and auto-regressive prediction for longer horizons.

    Args:
        configs: Configuration object containing model parameters
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.model_name = "SuperLinear"
        self.train_pred_len = configs.train_pred_len
        self.train_seq_len = configs.train_seq_len
        self.resample_long_lookback = configs.resample_long_lookback
        self.layer_type = configs.layer_type

        # Expert names come as an underscore-separated string, e.g. "24_168_naive".
        if configs.freq_experts == "":
            self.freq_experts = None
        else:
            self.freq_experts = configs.freq_experts.split('_')

        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts

        # Build one expert per configured name.
        self.experts = {}
        if self.freq_experts is None:
            raise ValueError("No frequency experts specified in configuration.")
        for expert_freq in self.freq_experts:
            if expert_freq == "naive" or expert_freq == "Naive":
                self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
            elif expert_freq == "mean" or expert_freq == "Mean":
                self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
            else:
                self.experts[expert_freq] = self._make_layer_expert()

        # Optional additional "complementary" experts sharing the layer type.
        if configs.comp_moe > 0:
            for i in range(configs.comp_moe):
                self.experts[f"comp_{i}"] = self._make_layer_expert()

        self.moe = SparseMoE(configs, experts=self.experts.values())

        print("Experts:", self.experts.keys())

    def _make_layer_expert(self):
        """Create a linear-layer expert of the configured type (RLinear fallback)."""
        # Consolidates the duplicated class-selection logic from __init__.
        expert_classes = {'Linear': Linear, 'RLinear': RLinear}
        expert_class = expert_classes.get(self.layer_type, RLinear)
        return expert_class(self.train_seq_len, self.train_pred_len)

    def add_experts(self, experts: dict):
        """
        Add new experts to the model.

        Args:
            experts: Dictionary of expert instances to add

        Returns:
            The rebuilt SparseMoE containing old and new experts.
        """
        for name, expert in experts.items():
            self.experts[name] = expert
        # Rebuild the MoE so the gating network covers the new expert set.
        self.moe = SparseMoE(self.configs, experts=self.experts.values())
        return self.moe

    def _interp_last_dim(self, x, size):
        """Linearly resample the last dimension of a 2-D or 3-D tensor to `size`."""
        # BUG FIX: F.interpolate(mode='linear') requires 3-D (N, C, L) input.
        # The original always did unsqueeze(1)/squeeze(1), which produced a
        # 4-D tensor (and a runtime error) for the 3-D tensors passed from
        # forward(). Only add the channel dim when it is missing.
        if x.dim() == 2:
            return F.interpolate(x.unsqueeze(1), size=size, mode='linear', align_corners=False).squeeze(1)
        return F.interpolate(x, size=size, mode='linear', align_corners=False)

    def resample_seq_len(self, x, pred_len, inverse=False, orig_pred_len=None):
        """
        Resample sequence length for handling inputs shorter than expected training length.

        Args:
            x: Input tensor
            pred_len: Prediction length
            inverse: If True, downsample back to original scale; if False, upsample
            orig_pred_len: Original prediction length (required for inverse=True)

        Returns:
            Tuple of (resampled_tensor, updated_pred_len, scale_factor, orig_pred_len)
            For inverse=True: returns (resampled_tensor, None, None, None)
        """
        if not inverse:
            # Stretch a too-short lookback up to the training length and scale
            # the requested horizon by the same factor.
            if x.size(-1) < self.train_seq_len:
                scale_factor = self.train_seq_len / x.size(-1)
                x_resampled = self._interp_last_dim(x, self.train_seq_len)
                pred_len_resampled = int(pred_len * scale_factor)
                return x_resampled, pred_len_resampled, scale_factor, pred_len
            else:
                return x, pred_len, None, None
        else:
            # Map the forecast back to the caller's original time scale.
            if orig_pred_len is not None:
                x_resampled = self._interp_last_dim(x, orig_pred_len)
                return x_resampled, None, None, None
            else:
                return x, None, None, None

    def forward(self, x_in, get_prob=False, pred_len=None):
        """
        Forward pass through the model.

        Args:
            x_in: Encoder input tensor
            get_prob: Whether to return expert selection probabilities
            pred_len: Override for prediction length

        Returns:
            - Prediction tensor
            - (Optional) Expert selection probabilities if get_prob is True
        """
        if pred_len is None:
            pred_len = self.train_pred_len

        x = x_in
        # Promote univariate [B, L] input to [B, L, 1].
        if x_in.dim() == 2:
            x = x.unsqueeze(-1)

        # [B, L, V] -> [B, V, L]: experts operate along the time axis.
        x = x.permute(0, 2, 1)
        B, V, L = x.shape

        scale_factor = None
        orig_pred_len = None

        # Stretch short lookbacks up to the training sequence length.
        if self.resample_long_lookback and L < self.train_seq_len:
            x, pred_len, scale_factor, orig_pred_len = self.resample_seq_len(x, pred_len, inverse=False)

        # Fold variables into the batch so the MoE sees [B*V, L].
        x = x.reshape(B * V, x.size(-1))

        if get_prob:
            out, expert_probs = self.moe(x, get_prob=True)
        else:
            out = self.moe(x)

        # Auto-regressive extension when the requested horizon exceeds the
        # native prediction length.
        # BUG FIX: the original loop referenced undefined names
        # (self.seq_len, inf_pred_len) and unpacked two values from
        # self.moe(ar_x), which returns a single tensor when get_prob=False —
        # this branch always crashed.
        if self.train_pred_len < pred_len:
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
            for _ in range(0, pred_len - self.train_pred_len, self.train_pred_len):
                ar_out = self.moe(ar_x)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
            out = torch.cat(outputs, dim=1)[:, :pred_len]

        # Unfold variables back out of the batch: [B, V, pred].
        out = out.reshape(B, V, out.size(-1))

        # Undo the lookback resampling on the forecast, if it was applied.
        if scale_factor is not None:
            out, _, _, _ = self.resample_seq_len(out, None, inverse=True, orig_pred_len=orig_pred_len)

        # [B, V, pred] -> [B, pred, V] to match the input layout.
        result = out.permute(0, 2, 1)

        if x_in.dim() == 2:
            result = result.squeeze(-1)

        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            return result, expert_probs
        return result
| "-------------------------------------------------------------------------------------------------------------------" |
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace wrapper exposing the SuperLinear backbone as a causal LM."""

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)
        # Re-materialize the HF config dict as a lightweight attribute object,
        # which is the interface the backbone Model expects.
        backbone_cfg = type("Cfg", (), config.to_dict())()
        self.args = backbone_cfg
        self.backbone = Model(backbone_cfg)
        self.post_init()

    def forward(self,
                inputs_embeds: torch.Tensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple] = None,
                use_cache: bool = True,
                labels: Optional[torch.Tensor] = None,
                **kwargs,) -> CausalLMOutputWithCrossAttentions:
        # The raw time series is carried in `inputs_embeds`; the remaining
        # standard LM arguments are accepted but unused by the backbone.
        if inputs_embeds is None:
            raise ValueError("Pass the time‑series as `inputs_embeds`")

        preds = self.backbone(inputs_embeds)
        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=preds,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
        if past_key_values is not None:
            # With a cache present, only the newest step is fed forward.
            inputs_embeds = inputs_embeds[:, -1:, :]
        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}

    def _reorder_cache(self, past, beam_idx, **kwargs):
        # No KV cache is maintained, so beam reordering is a no-op.
        return past
|
|
|
|