from transformers import PretrainedConfig


class YConfig3(PretrainedConfig):
    """Configuration object registered under the ``ynet3`` model type.

    Every option is supplied as a keyword argument. Keys recognised here
    are removed from ``kwargs`` and stored as instance attributes of the
    same name; when a key is absent, the default listed in ``__init__`` is
    stored instead. Whatever remains in ``kwargs`` is passed on unchanged
    to ``PretrainedConfig.__init__`` together with the three special-token
    ids.
    """

    model_type = "ynet3"

    def __init__(self, **kwargs):
        """Populate the config from ``kwargs``, falling back to defaults.

        Args:
            **kwargs: Option overrides. Recognised keys are consumed
                (popped) here; unrecognised keys are forwarded to the
                parent constructor.
        """

        def adopt(defaults):
            # Pop each key (or take its fallback) and store it on the
            # instance, in the order the pairs are listed.
            for key, fallback in defaults:
                setattr(self, key, kwargs.pop(key, fallback))

        # NOTE(review): newer transformers releases may also manage a
        # `dtype` attribute inside PretrainedConfig.__init__; confirm the
        # value stored here is not overwritten by the super() call below.
        adopt((
            ("dropout", 0.0),
            ("bos_token_id", 151644),
            ("eos_token_id", 151645),
            ("pad_token_id", 151643),
            ("hidden_act", "silu"),
            ("hidden_size", 768),
            ("num_hidden_layers", 8),
            ("max_position_embeddings", 8192),
            ("vocab_size", 6400),
            ("rms_norm_eps", 1e-6),
            ("rope_theta", 5e4),
            ("rope_scaling", None),
            ("dtype", "float32"),
            ("self_distill", True),
            ("intermediate_size", 1536),
        ))

        # A missing or falsy value (None, 0, ...) falls back to
        # `intermediate_size`, which was stored just above.
        expert_width = kwargs.pop("expert_intermediate_size", None)
        if not expert_width:
            expert_width = self.intermediate_size
        self.expert_intermediate_size = expert_width

        # Expert / router related keys (grouping is by key name only).
        adopt((
            ("n_routed_experts", 0),
            ("moe_topk", 2),
            ("score_func", "softmax"),
            ("n_shared_experts", 0),
            ("top_k_layer_dense", 1),
            ("aux_loss_alpha", 0.02),
            ("seq_aux", False),
            ("norm_topk_prob", True),
            ("noisy_expert", 0.0),
            ("moe_backend", "compact"),
            ("router_bias_enabled", True),
            ("router_bias_update_rate", 1e-3),
            ("router_bias_clamp", 5.0),
        ))

        # Attention related keys (grouping is by key name only).
        adopt((
            ("num_heads", 12),
            ("mla_kv_lora_rank", 64),
            ("mla_qk_nope_head_dim", 64),
            ("mla_qk_rope_head_dim", 32),
            ("mla_attn_impl", "absorb"),
            ("qkv_lora", False),
        ))

        # The token ids were popped above, so they are handed to the
        # parent explicitly; all leftover kwargs go through untouched.
        super().__init__(
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            **kwargs,
        )