# SuperLinear/configuration_super_linear.py
from transformers import PretrainedConfig
# 1) --------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
class SuperLinearConfig(PretrainedConfig):
    """
    Configuration for the SuperLinear MoE time-series foundation model.

    Only ``model_type`` must be unique within transformers; the remaining
    attributes mirror the __init__ arguments of the original training
    Config object.
    """

    model_type = "super_linear"
    def __init__(
        self,
        # Model architecture parameters
        train_seq_len=512,
        train_pred_len=96,
        seq_len=512,
        pred_len=96,
        inf_pred_len=96,
        max_horizon=96,
        auto_regressive=1,
        # MoE parameters
        moe_n_experts=4,
        top_k_experts=12,
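        # NOTE (annotation added for clarity, not in the original file): the
        # default top_k_experts (12) exceeds moe_n_experts (4); the effective
        # expert count is presumably derived from the freq_experts list below
        # rather than from moe_n_experts.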
        noisy_gating_std=0.1,
        moe_temp=1.0,
        moe_norm=False,
        layer_type='RLinear',
        n_experts=4,
        comp_moe=12,
        freeze_experts=True,
        moe=1,
        # FFT-based gating parameters
        use_fft=True,
        fft_len=5000,
        # Expert configuration
        freq_experts='mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
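        # NOTE (annotation added for clarity, not in the original file):
        # freq_experts appears to be an underscore-separated expert list:
        # named baseline experts ("mean", "naive") followed by frequency
        # experts "1/P", where P looks like a seasonal period in time steps
        # (e.g. 1/24 and 1/168 would match daily and weekly cycles in hourly
        # data). See parse_freq_experts below for a sketch of this reading.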
        # Model loading and saving
        load_linear=True,
        load_weights_full=True,
        linear_freq_weights_path='./weights/linear_freq_weights/',
        full_weights_path='./weights/full_weights/checkpoint.pth',
        # Training parameters
        resample_long_lookback=False,
        # Legacy parameters for backward compatibility
        linear_checkpoints_path='/cs/azencot_fsas/MoE/',
        linear_checkpoints_dir="checkpoints5",
        manual_moe=0,
        misc_moe=1,
        noisy_gating_std_decay=1,
        ker_len=50,
        con=0,
        d_model=512,
        mlp_gating=1,
        dropout=0.0,
        **kwargs,
    ):
        # Model architecture parameters
        self.train_seq_len = train_seq_len
        self.train_pred_len = train_pred_len
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.auto_regressive = auto_regressive
        # MoE parameters
        self.moe = moe
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.noisy_gating_std = noisy_gating_std
        self.moe_temp = moe_temp
        self.moe_norm = moe_norm
        self.layer_type = layer_type
        self.n_experts = n_experts
        self.comp_moe = comp_moe
        self.freeze_experts = freeze_experts
        # FFT-based gating parameters
        self.use_fft = use_fft
        self.fft_len = fft_len
        # Expert configuration
        self.freq_experts = freq_experts
        # Model loading and saving
        self.load_linear = load_linear
        self.load_weights_full = load_weights_full
        self.linear_freq_weights_path = linear_freq_weights_path
        self.full_weights_path = full_weights_path
        # Training parameters
        self.resample_long_lookback = resample_long_lookback
        # Legacy parameters for backward compatibility
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.manual_moe = manual_moe
        self.misc_moe = misc_moe
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.ker_len = ker_len
        self.con = con
        self.d_model = d_model
        self.mlp_gating = mlp_gating
        self.dropout = dropout

        super().__init__(**kwargs)
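

# ------------------------------------------------------------------------------
# Sketch (not part of the original file): one way the freq_experts string could
# be decoded. The "1/P -> period P" reading is an assumption inferred from the
# default value above, not confirmed by the modeling code.
# ------------------------------------------------------------------------------
def parse_freq_experts(spec):
    """Split 'mean_naive_1/4_...' into named experts and integer periods."""
    named, periods = [], []
    for token in spec.split("_"):
        if "/" in token:
            # "1/P": assumed to denote an expert for seasonal period P
            periods.append(int(token.split("/")[1]))
        else:
            # plain tokens such as "mean" and "naive" name baseline experts
            named.append(token)
    return named, periods


# ------------------------------------------------------------------------------
# Usage sketch (not part of the original file): register the config with
# transformers and round-trip it through save_pretrained/from_pretrained.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig

    # Map the "super_linear" model_type to this class so AutoConfig can
    # resolve it from a saved config.json.
    AutoConfig.register("super_linear", SuperLinearConfig)

    config = SuperLinearConfig(seq_len=1024, pred_len=192)
    config.save_pretrained("./super_linear_config")  # writes config.json
    reloaded = AutoConfig.from_pretrained("./super_linear_config")
    assert reloaded.seq_len == 1024 and reloaded.pred_len == 192

    named, periods = parse_freq_experts(config.freq_experts)
    print(named, periods[:5])  # ['mean', 'naive'] [4, 6, 7, 8, 12]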