|
|
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    GenerationMixin,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SuperLinearConfig(PretrainedConfig):
    """
    Configuration for the SuperLinear MoE time-series foundation model.

    Only ``model_type`` must be unique inside transformers; the remaining
    arguments mirror the ``__init__`` arguments of the original Config
    object and are stored, unchanged, as attributes of the same name.

    Three arguments are handled differently from the rest:

    * ``model_type`` is accepted but never assigned; the class attribute
      ``model_type`` is the only value this class sets.
    * ``torch_dtype`` is forwarded to ``PretrainedConfig.__init__`` instead
      of being stored here.
    * ``transformers_version`` is accepted and discarded.

    Any other keyword argument is passed through to
    ``PretrainedConfig.__init__``.

    NOTE(review): this class performs no validation and does not interpret
    any value (including the 0/1 integer switches and the
    underscore-separated ``freq_experts`` string). Their meaning is defined
    by the model code that reads this config - confirm there.
    """

    model_type = "super_linear"

    def __init__(
        self,
        seq_len: int = 512,
        pred_len: int = 96,
        inf_pred_len: int = 96,
        max_horizon: int = 96,
        moe_n_experts: int = 8,
        top_k_experts: int = 5,
        moe: int = 1,
        freq_experts: str = 'mean_naive_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
        auto_regressive: int = 1,
        con: int = 0,
        d_model: int = 128,
        dropout: float = 0.0,
        fft_len: int = 10000,
        freeze_experts: int = 1,
        ker_len: int = 50,
        layer_type: str = "RLinear",
        linear_checkpoints_dir: str = "checkpoints5",
        # NOTE(review): the default is an absolute filesystem path; callers
        # on another machine presumably need to override it - confirm.
        linear_checkpoints_path: str = "/cs/azencot_fsas/MoE/",
        load_linear: int = 1,
        manual_moe: int = 0,
        misc_moe: int = 1,
        mlp_gating: int = 1,
        model_type: str = "super_linear",
        moe_temp: float = 1,
        noisy_gating_std: float = 0.1,
        noisy_gating_std_decay: float = 1,
        torch_dtype: str = "float32",
        transformers_version: str = "4.40.1",
        use_fft: int = 1,
        **kwargs,
    ) -> None:
        """Store every hyper-parameter on the instance, then initialise the base config.

        All named arguments except ``model_type``, ``torch_dtype`` and
        ``transformers_version`` are copied verbatim to ``self.<name>``.
        ``**kwargs`` (plus ``torch_dtype``) are handed to
        ``PretrainedConfig.__init__``.
        """
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.moe = moe
        self.freq_experts = freq_experts
        self.auto_regressive = auto_regressive
        self.con = con
        self.d_model = d_model
        self.dropout = dropout
        self.fft_len = fft_len
        self.freeze_experts = freeze_experts
        self.ker_len = ker_len
        self.layer_type = layer_type
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.linear_checkpoints_path = linear_checkpoints_path
        self.load_linear = load_linear
        self.manual_moe = manual_moe
        self.misc_moe = misc_moe
        self.mlp_gating = mlp_gating
        self.moe_temp = moe_temp
        self.noisy_gating_std = noisy_gating_std
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.use_fft = use_fft

        # `model_type` and `transformers_version` are intentionally not
        # stored: the class attribute above supplies the model type, and the
        # version default is a fixed string that would not describe the
        # running library.
        #
        # `torch_dtype` is a named parameter, so it is absent from `kwargs`;
        # forward it explicitly so the base class receives it instead of the
        # value being silently dropped.
        super().__init__(torch_dtype=torch_dtype, **kwargs)
|
|
|