| from typing import Optional, Tuple |
| import torch, torch.nn as nn, torch.nn.functional as F |
| from transformers import ( |
| PretrainedConfig, |
| PreTrainedModel, |
| GenerationMixin, |
| AutoConfig, |
| AutoModelForCausalLM, |
| ) |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
|
| |
| |
| |
|
|
|
|
class SuperLinearConfig(PretrainedConfig):
    """Configuration object for the ``super_linear`` model type.

    The constructor stores every hyperparameter it receives as an instance
    attribute of the same name, with three exceptions:

    * ``model_type`` is accepted but ignored; the class attribute
      ``model_type = "super_linear"`` is what identifies this config.
    * ``transformers_version`` is accepted but ignored.
    * ``torch_dtype`` is not stored here directly; it is handed to
      ``PretrainedConfig.__init__`` together with ``**kwargs``.

    The meaning of each hyperparameter is defined by the model code that
    consumes this config, which is not part of this class.
    """

    model_type = "super_linear"

    def __init__(
        self,
        seq_len: int = 512,
        pred_len: int = 96,
        inf_pred_len: int = 96,
        max_horizon: int = 96,
        moe_n_experts: int = 12,
        top_k_experts: int = 5,
        moe: int = 1,
        freq_experts: str = 'mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
        auto_regressive: int = 1,
        d_model: int = 128,
        dropout: float = 0.0,
        fft_len: int = 5000,
        freeze_experts: int = 1,
        layer_type: str = "RLinear",
        linear_checkpoints_dir: str = "checkpoints5",
        linear_checkpoints_path: str = "/cs/azencot_fsas/MoE/",
        load_linear: int = 0,
        load_weights: int = 0,
        misc_moe: int = 10,
        mlp_gating: int = 0,
        moe_norm: int = 0,
        model_type: str = "super_linear",
        moe_temp: float = 1,
        noisy_gating_std: float = 0.1,
        noisy_gating_std_decay: float = 1,
        torch_dtype: str = "float32",
        transformers_version: str = "4.40.1",
        use_fft: int = 1,
        train_epochs: int = 30,
        patience: int = 5,
        lradj: str = "constant",
        learning_rate: float = 0.05,
        channel_ind: int = 0,
        full_size: int = 0,
        **kwargs,
    ) -> None:
        """Store the hyperparameters and initialise the base config.

        Args:
            model_type: Ignored. NOTE(review): presumably present so that a
                serialized config dict can be passed back as keyword
                arguments -- confirm before removing.
            transformers_version: Ignored, for the same presumed reason.
            torch_dtype: Forwarded to ``PretrainedConfig.__init__``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        All remaining arguments are copied verbatim onto ``self`` under
        their own names.
        """
        self.seq_len = seq_len
        self.moe = moe
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.auto_regressive = auto_regressive
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.freq_experts = freq_experts
        self.freeze_experts = freeze_experts
        self.layer_type = layer_type
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.load_linear = load_linear
        self.load_weights = load_weights
        self.misc_moe = misc_moe
        self.noisy_gating_std = noisy_gating_std
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.d_model = d_model
        self.mlp_gating = mlp_gating
        self.moe_norm = moe_norm
        self.moe_temp = moe_temp
        self.use_fft = use_fft
        self.fft_len = fft_len
        self.dropout = dropout
        self.train_epochs = train_epochs
        self.patience = patience
        self.lradj = lradj
        self.learning_rate = learning_rate
        self.channel_ind = channel_ind
        self.full_size = full_size
        # The base class picks ``torch_dtype`` out of its keyword arguments.
        # Because this signature captures it by name, it never reaches
        # ``kwargs`` and must be passed on explicitly; otherwise a dtype
        # given by the caller or read from config.json is silently dropped.
        super().__init__(torch_dtype=torch_dtype, **kwargs)
|
|