from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    GenerationMixin,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# 1) --------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------


class SuperLinearConfig(PretrainedConfig):
    """
    Configuration for the SuperLinear MoE time–series foundation model.
    Only *model_type* must be unique inside transformers; the rest mirrors
    the __init__ arguments of your original Config object.
    """

    model_type = "super_linear"      

    def __init__(
        self,
        seq_len=512,
        pred_len=96,
        inf_pred_len=96,
        max_horizon=96,
        moe_n_experts=12,
        top_k_experts=5,
        moe=1,
        # Underscore-separated expert tags: 'mean', 'naive', and frequency
        # experts written as 1/period.
        freq_experts='mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
        auto_regressive=1,
        d_model=128,
        dropout=0.0,
        fft_len=5000,
        freeze_experts=1,
        layer_type="RLinear",
        linear_checkpoints_dir="checkpoints5",
        linear_checkpoints_path="/cs/azencot_fsas/MoE/",
        load_linear=0,
        load_weights=0,
        misc_moe=10,
        mlp_gating=0,
        moe_norm=1,
        moe_temp=1,
        noisy_gating_std=0.1,
        noisy_gating_std_decay=1,
        torch_dtype="float32",
        transformers_version="4.40.1",
        use_fft=1,
        **kwargs,  # any extra CLI args
    ):
        # Sequence / horizon settings
        self.seq_len         = seq_len
        self.pred_len        = pred_len
        self.inf_pred_len    = inf_pred_len
        self.max_horizon     = max_horizon
        self.auto_regressive = auto_regressive

        # Mixture-of-experts settings
        self.moe            = moe
        self.moe_n_experts  = moe_n_experts
        self.top_k_experts  = top_k_experts
        self.freq_experts   = freq_experts
        self.freeze_experts = freeze_experts
        self.misc_moe       = misc_moe
        self.moe_norm       = moe_norm
        self.moe_temp       = moe_temp

        # Gating settings
        self.mlp_gating             = mlp_gating
        self.noisy_gating_std       = noisy_gating_std
        self.noisy_gating_std_decay = noisy_gating_std_decay

        # Expert architecture and checkpoint loading
        self.layer_type              = layer_type
        self.d_model                 = d_model
        self.dropout                 = dropout
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir  = linear_checkpoints_dir
        self.load_linear             = load_linear
        self.load_weights            = load_weights

        # FFT settings
        self.use_fft = use_fft
        self.fft_len = fft_len

        # Hand torch_dtype / transformers_version to PretrainedConfig together
        # with any remaining kwargs so they are stored on the config.
        super().__init__(
            torch_dtype=torch_dtype,
            transformers_version=transformers_version,
            **kwargs,
        )
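
# -----------------------------------------------------------------------------
# Example usage (a minimal sketch, not from the original file): instantiating
# the config and registering it with the Auto* classes so that
# `from_pretrained` can resolve the "super_linear" model type.
# `SuperLinearModel` is a placeholder name for the PreTrainedModel subclass
# assumed to be defined further down in this file.
#
#   config = SuperLinearConfig(seq_len=512, pred_len=96, top_k_experts=5)
#   AutoConfig.register("super_linear", SuperLinearConfig)
#   AutoModelForCausalLM.register(SuperLinearConfig, SuperLinearModel)
# -----------------------------------------------------------------------------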