SuperLinear / configuration_super_linear.py
from transformers import PretrainedConfig
# 1) --------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
class SuperLinearConfig(PretrainedConfig):
"""
Configuration for the SuperLinear MoE time–series foundation model.
Only *model_type* must be unique inside transformers; the rest mirrors
the __init__ arguments of your original Config object.
"""
model_type = "super_linear"
    def __init__(
        self,
        seq_len=None,
        moe=None,
        pred_len=None,
        inf_pred_len=None,
        max_horizon=None,
        auto_regressive=None,
        moe_n_experts=None,
        top_k_experts=None,
        freq_experts=None,
        freeze_experts=None,
        layer_type=None,
        linear_checkpoints_path=None,
        linear_checkpoints_dir=None,
        load_linear=None,
        manual_moe=None,
        misc_moe=None,
        noisy_gating_std=None,
        noisy_gating_std_decay=None,
        ker_len=None,
        con=None,
        d_model=None,
        mlp_gating=None,
        moe_temp=None,
        use_fft=None,
        fft_len=None,
        dropout=None,
        **kwargs,  # any extra CLI args
    ):
        # Defaults are None placeholders; the real values come from the
        # checkpoint's config.json (or explicit kwargs) at load time.
        self.seq_len = seq_len
        self.moe = moe
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.auto_regressive = auto_regressive
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.freq_experts = freq_experts
        self.freeze_experts = freeze_experts
        self.layer_type = layer_type
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.load_linear = load_linear
        self.manual_moe = manual_moe
        self.misc_moe = misc_moe
        self.noisy_gating_std = noisy_gating_std
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.ker_len = ker_len
        self.con = con
        self.d_model = d_model
        self.mlp_gating = mlp_gating
        self.moe_temp = moe_temp
        self.use_fft = use_fft
        self.fft_len = fft_len
        self.dropout = dropout
        super().__init__(**kwargs)
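

# 2) --------------------------------------------------------------------------
# USAGE SKETCH (illustrative)
# -----------------------------------------------------------------------------
# A minimal sketch, assuming the class above is used as a custom transformers
# config. The parameter values below are placeholders chosen for illustration,
# NOT the checkpoint's real defaults (those ship in its config.json).
if __name__ == "__main__":
    from transformers import AutoConfig

    # Register the custom config under its model_type so AutoConfig can
    # resolve "super_linear" without trust_remote_code.
    AutoConfig.register("super_linear", SuperLinearConfig)

    config = SuperLinearConfig(
        seq_len=512,       # placeholder context length
        pred_len=96,       # placeholder forecast horizon
        moe_n_experts=8,   # placeholder expert count
        top_k_experts=2,   # placeholder routing top-k
    )

    # Round-trip through the standard HF serialization path.
    config.save_pretrained("./super_linear_ckpt")  # writes config.json
    reloaded = AutoConfig.from_pretrained("./super_linear_ckpt")
    assert reloaded.model_type == "super_linear"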