File size: 2,602 Bytes
a19f1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from transformers import PretrainedConfig


class YConfig3(PretrainedConfig):
    """Configuration for the ``ynet3`` model type.

    Every model-specific setting is read from ``**kwargs`` (falling back to the
    default shown below) and stored as an attribute of the same name. Any
    keyword not consumed here is forwarded unchanged to
    ``transformers.PretrainedConfig.__init__``.

    The grouping comments below are based on attribute names only; how each
    value is interpreted lives in the model code, not in this file.
    """

    model_type = "ynet3"

    def __init__(self, **kwargs):
        # --- general / tokenizer ---
        self.dropout = kwargs.pop("dropout", 0.0)
        # Token ids are stored here and also passed to the base class below.
        self.bos_token_id = kwargs.pop("bos_token_id", 151644)
        self.eos_token_id = kwargs.pop("eos_token_id", 151645)
        self.pad_token_id = kwargs.pop("pad_token_id", 151643)

        # --- backbone ---
        self.hidden_act = kwargs.pop("hidden_act", "silu")
        self.hidden_size = kwargs.pop("hidden_size", 768)
        self.num_hidden_layers = kwargs.pop("num_hidden_layers", 8)
        self.max_position_embeddings = kwargs.pop("max_position_embeddings", 8192)
        self.vocab_size = kwargs.pop("vocab_size", 6400)
        self.rms_norm_eps = kwargs.pop("rms_norm_eps", 1e-6)
        self.rope_theta = kwargs.pop("rope_theta", 5e4)
        self.rope_scaling = kwargs.pop("rope_scaling", None)
        # Kept in a local so it can be re-applied after the base-class
        # __init__ (see the note at the end of this method).
        dtype = kwargs.pop("dtype", "float32")
        self.self_distill = kwargs.pop("self_distill", True)
        self.intermediate_size = kwargs.pop("intermediate_size", 1536)
        # Any falsy value (None or 0) falls back to intermediate_size.
        self.expert_intermediate_size = kwargs.pop("expert_intermediate_size", None) or self.intermediate_size

        # --- expert routing (names suggest mixture-of-experts; confirm in model code) ---
        self.n_routed_experts = kwargs.pop("n_routed_experts", 0)
        self.moe_topk = kwargs.pop("moe_topk", 2)
        self.score_func = kwargs.pop("score_func", "softmax")
        self.n_shared_experts = kwargs.pop("n_shared_experts", 0)
        self.top_k_layer_dense = kwargs.pop("top_k_layer_dense", 1)
        self.aux_loss_alpha = kwargs.pop("aux_loss_alpha", 0.02)
        self.seq_aux = kwargs.pop("seq_aux", False)
        self.norm_topk_prob = kwargs.pop("norm_topk_prob", True)
        self.noisy_expert = kwargs.pop("noisy_expert", 0.0)
        self.moe_backend = kwargs.pop("moe_backend", "compact")
        self.router_bias_enabled = kwargs.pop("router_bias_enabled", True)
        self.router_bias_update_rate = kwargs.pop("router_bias_update_rate", 1e-3)
        self.router_bias_clamp = kwargs.pop("router_bias_clamp", 5.0)

        # --- attention (mla_* names presumably refer to a latent-attention variant) ---
        self.num_heads = kwargs.pop("num_heads", 12)
        self.mla_kv_lora_rank = kwargs.pop("mla_kv_lora_rank", 64)
        self.mla_qk_nope_head_dim = kwargs.pop("mla_qk_nope_head_dim", 64)
        self.mla_qk_rope_head_dim = kwargs.pop("mla_qk_rope_head_dim", 32)
        self.mla_attn_impl = kwargs.pop("mla_attn_impl", "absorb")
        self.qkv_lora = kwargs.pop("qkv_lora", False)

        # Remaining kwargs become generic PretrainedConfig attributes.
        super().__init__(
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            **kwargs,
        )
        # NOTE: recent transformers releases (after the ``torch_dtype`` ->
        # ``dtype`` rename) assign ``self.dtype`` inside
        # ``PretrainedConfig.__init__`` from its own ``dtype`` argument. Since
        # ``dtype`` was popped above, that assignment would discard this
        # config's value. Re-apply it here; on older releases, where the base
        # class never sets ``self.dtype``, this is simply the same assignment.
        self.dtype = dtype