from typing import Optional, Tuple

from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    GenerationMixin,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# 1) --------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------


class SuperLinearConfig(PretrainedConfig):
    """
    Configuration for the SuperLinear MoE time–series foundation model.
    Only *model_type* must be unique inside transformers; the rest mirrors
    the __init__ arguments of your original Config object.
    """

    model_type = "super_linear"      

    def __init__(
        self,
        # Model architecture parameters
        train_seq_len=512,
        train_pred_len=96,
        seq_len=512,
        pred_len=96,
        inf_pred_len=96,
        max_horizon=96,
        auto_regressive=1,
        
        # MoE parameters
        moe_n_experts=4,
        top_k_experts=12,
        noisy_gating_std=0.1,
        moe_temp=1.0,
        moe_norm=False,
        layer_type='RLinear',
        n_experts=4,
        comp_moe=12,
        freeze_experts=True,
        moe=1,
        
        # FFT-based gating parameters
        use_fft=True,
        fft_len=5000,
        
        # Expert configuration
        freq_experts='mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
        
        # Model loading and saving
        load_linear=True,
        load_weights_full=True,
        linear_freq_weights_path='./weights/linear_freq_weights/',
        full_weights_path='./weights/full_weights/checkpoint.pth',
        
        # Training parameters
        resample_long_lookback=False,
        
        # Legacy parameters for backward compatibility
        linear_checkpoints_path='/cs/azencot_fsas/MoE/',
        linear_checkpoints_dir="checkpoints5",
        manual_moe=0,
        misc_moe=1,
        noisy_gating_std_decay=1,
        ker_len=50,
        con=0,
        d_model=512,
        mlp_gating=1,
        dropout=0.0,
        **kwargs,
    ):
        # Model architecture parameters
        self.train_seq_len = train_seq_len
        self.train_pred_len = train_pred_len
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.auto_regressive = auto_regressive
        
        # MoE parameters
        self.moe = moe
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.noisy_gating_std = noisy_gating_std
        self.moe_temp = moe_temp
        self.moe_norm = moe_norm
        self.layer_type = layer_type
        self.n_experts = n_experts
        self.comp_moe = comp_moe
        self.freeze_experts = freeze_experts
        
        # FFT-based gating parameters
        self.use_fft = use_fft
        self.fft_len = fft_len
        
        # Expert configuration
        self.freq_experts = freq_experts
        
        # Model loading and saving
        self.load_linear = load_linear
        self.load_weights_full = load_weights_full
        self.linear_freq_weights_path = linear_freq_weights_path
        self.full_weights_path = full_weights_path
        
        # Training parameters
        self.resample_long_lookback = resample_long_lookback
        
        # Legacy parameters for backward compatibility
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.manual_moe = manual_moe
        self.misc_moe = misc_moe
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.ker_len = ker_len
        self.con = con
        self.d_model = d_model
        self.mlp_gating = mlp_gating
        self.dropout = dropout
        
        super().__init__(**kwargs)
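

# -----------------------------------------------------------------------------
# USAGE SKETCH (illustrative, not part of the original file)
# -----------------------------------------------------------------------------
# A minimal sketch, assuming only the SuperLinearConfig defined above. The
# AutoConfig.register / save_pretrained / from_pretrained calls are standard
# transformers APIs; the demo directory path and the _parse_freq_experts
# helper are hypothetical examples, not the repo's actual code.

from fractions import Fraction


def _parse_freq_experts(spec):
    """Hypothetical parser for the underscore-separated *freq_experts* spec.

    The default spec mixes named experts ("mean", "naive") with fractional
    frequencies ("1/24", "1/168", ...); how the real model consumes these
    tokens may differ.
    """
    tokens = spec.split("_")
    named = [t for t in tokens if "/" not in t]
    freqs = [Fraction(t) for t in tokens if "/" in t]
    return named, freqs


if __name__ == "__main__":
    # Register so AutoConfig can resolve model_type "super_linear" to this
    # class (call once per process).
    AutoConfig.register("super_linear", SuperLinearConfig)

    # Unspecified arguments keep the defaults declared in __init__ above.
    config = SuperLinearConfig(seq_len=1024, pred_len=192)
    named, freqs = _parse_freq_experts(config.freq_experts)
    print(named, freqs[:3])  # ['mean', 'naive'] [Fraction(1, 4), ...]

    # PretrainedConfig provides JSON round-tripping out of the box.
    config.save_pretrained("./super_linear_config_demo")
    reloaded = AutoConfig.from_pretrained("./super_linear_config_demo")
    assert reloaded.pred_len == 192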