JingzeShi committed on
Commit
0d17fef
·
verified ·
1 Parent(s): 407c69e

Upload 2 files

Files changed (2)
  1. configuration_doge.py +241 -0
  2. modeling_doge.py +696 -0
configuration_doge.py ADDED
@@ -0,0 +1,241 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/doge/modular_doge.py.
+ # Do NOT edit this file manually, as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be made, please apply the change to the
+ # modular_doge.py file directly. Our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # coding=utf-8
+ # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+ #
+ # The Doge family of small language models is trained by the SmallDoge Team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+
+
+ class DogeConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate a Doge
+     model according to the specified arguments, defining the model architecture, e.g. [SmallDoge/Doge-320M](https://huggingface.co/SmallDoge/Doge-320M).
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 32768):
+             Vocabulary size of the Doge model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`].
+         hidden_size (`int`, *optional*, defaults to 1024):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 2048):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer decoder.
+         hidden_dropout (`float`, *optional*, defaults to 0.0):
+             Dropout probability for each sequence transformation and state transformation module.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         max_position_embeddings (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model might ever be used with.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings.
+             NOTE: if you apply a new rope type and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value accordingly.
+             The Doge family of small models uses `{ 'rope_type': 'dynamic', 'factor': 4.0, 'original_max_position_embeddings': 2048 }` as the default value.
+             Expected contents:
+                 `rope_type` (`str`):
+                     The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation.
+                 `factor` (`float`, *optional*):
+                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
+                     In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length.
+                 `original_max_position_embeddings` (`int`, *optional*):
+                     Used with 'dynamic', 'longrope' and 'llama3'.
+                     The original max position embeddings used during pretraining.
+                 `attention_factor` (`float`, *optional*):
+                     Used with 'yarn' and 'longrope'. The scaling factor to be applied to the attention
+                     computation.
+                     If unspecified, it defaults to the value recommended by the implementation, using the `factor` field to infer the suggested value.
+                 `beta_fast` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 32.
+                 `beta_slow` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 1.
+                 `short_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`).
+                     Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
+                 `long_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to long contexts (>`original_max_position_embeddings`).
+                     Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
+                 `low_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                 `high_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+         num_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         num_key_value_heads (`int`, *optional*):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention.
+             If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+             `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used.
+             When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group.
+             For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf).
+             If it is not specified, will default to `num_attention_heads`.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the up_proj, down_proj and gate_proj layers in the MLP layers.
+         sliding_window (`int`, *optional*):
+             Sliding window attention window size. If not specified, defaults to `None`.
+         keep_window_size (`int`, *optional*, defaults to 2048):
+             The number of tokens that are never dynamically masked; dynamic masking is only applied when the sequence length exceeds this value.
+         is_moe (`bool`, *optional*, defaults to `False`):
+             Whether to use the Cross Domain Mixture of Experts. If `True`, the MoE layer inherits the MLP definition for initialization.
+         num_experts (`int`, *optional*, defaults to 16384):
+             Number of routed experts in the model. This is only used when `is_moe=True`.
+         num_experts_per_tok (`int`, *optional*, defaults to 64):
+             Number of selected experts to route per token.
+         norm_topk_prob (`bool`, *optional*, defaults to `False`):
+             Whether to normalize the top-k probabilities.
+         output_router_logits (`bool`, *optional*, defaults to `False`):
+             Whether or not the router logits should be returned by the model. Enabling this will also
+             allow the model to output the auxiliary loss, including the load balancing loss and the router z-loss.
+         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+             The aux loss factor for the total loss.
+
+     ```python
+     >>> from transformers import DogeConfig, DogeModel
+
+     >>> # Initializing a Doge-320M style configuration
+     >>> configuration = DogeConfig()
+
+     >>> # Initializing a model from the Doge-320M style configuration
+     >>> model = DogeModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "doge"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     # Default tensor parallel plan for base model `DogeModel`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.dt_proj": "rowwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.input_layernorm.weight": "sequence_parallel",
+         "layers.*.input_residual.weight": "sequence_parallel",
+         "layers.*.post_attention_layernorm.weight": "sequence_parallel",
+         "layers.*.post_attention_residual.weight": "sequence_parallel",
+         "norm.weight": "sequence_parallel",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+         "layers.*.mlp.router_gate": "colwise_rep",
+         "layers.*.mlp.down_embed": "rowwise_rep",
+         "layers.*.mlp.up_embed": "rowwise_rep",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=32768,
+         hidden_size=1024,
+         intermediate_size=2048,
+         num_hidden_layers=32,
+         hidden_dropout=0.0,
+         hidden_act="silu",
+         initializer_range=0.02,
+         rms_norm_eps=1e-06,
+         use_cache=True,
+         tie_word_embeddings=False,
+         max_position_embeddings=2048,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         num_attention_heads=8,
+         num_key_value_heads=None,
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         sliding_window=None,
+         keep_window_size=2048,
+         is_moe=False,
+         num_experts=16384,
+         num_experts_per_tok=64,
+         norm_topk_prob=False,
+         output_router_logits=False,
+         router_aux_loss_coef=0.001,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+
+         self.hidden_dropout = hidden_dropout
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+
+         self.max_position_embeddings = max_position_embeddings
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.mlp_bias = mlp_bias
+         self.sliding_window = sliding_window
+         self.keep_window_size = keep_window_size
+         self.is_moe = is_moe
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.norm_topk_prob = norm_topk_prob
+         self.output_router_logits = output_router_logits
+         self.router_aux_loss_coef = router_aux_loss_coef
+
+         # Validate the correctness of the rotary position embedding parameters
+         # BC: if there is a 'type' field, copy it to 'rope_type'.
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             self.num_key_value_heads = num_attention_heads
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["DogeConfig"]
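The `rope_scaling` back-compat shim in `DogeConfig.__init__` (copying a legacy `'type'` key to `'rope_type'` before validation) can be sketched in isolation, without `transformers` installed. `normalize_rope_scaling` is a hypothetical helper for illustration, not part of the library:

```python
def normalize_rope_scaling(rope_scaling):
    """Mirror the BC shim in DogeConfig.__init__: a legacy 'type' key is
    copied to 'rope_type' so downstream validation only has to look for
    one key name. Returns a new dict; None passes through unchanged."""
    if rope_scaling is not None and "type" in rope_scaling:
        rope_scaling = {**rope_scaling, "rope_type": rope_scaling["type"]}
    return rope_scaling


# the documented Doge default, written with the legacy key name
legacy = {"type": "dynamic", "factor": 4.0, "original_max_position_embeddings": 2048}
normalized = normalize_rope_scaling(legacy)
print(normalized["rope_type"])  # → dynamic
```

The real config mutates `self.rope_scaling` in place; the copy here just keeps the sketch side-effect free.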
modeling_doge.py ADDED
@@ -0,0 +1,696 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/doge/modular_doge.py.
+ # Do NOT edit this file manually, as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be made, please apply the change to the
+ # modular_doge.py file directly. Our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # coding=utf-8
+ # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+ #
+ # The Doge family of small language models is trained by the SmallDoge Team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Callable, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ from transformers.integrations import use_kernel_forward_from_hub
+ from transformers.integrations.flex_attention import compile_friendly_flex_attention
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+ from transformers.modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import AttentionInterface, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
+ from transformers.utils.deprecation import deprecate_kwarg
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
+ from .configuration_doge import DogeConfig
+
+ try:
+     from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
+ except ImportError:
+     # Define the name so the module still imports; the forward pass requires the kernel.
+     flash_dynamic_mask_attention_forward = None
+     print("Please install flash_dmattn to use this model: pip install flash-dmattn")
+
+ if is_torch_flex_attn_available():
+     from torch.nn.attention.flex_attention import BlockMask
+
+
+ @use_kernel_forward_from_hub("RMSNorm")
+ class DogeRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         DogeRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
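The normalization above can be sketched in pure Python for a single vector (no torch required). Unlike LayerNorm, RMSNorm does not subtract the mean; it only rescales by the reciprocal root-mean-square:

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Pure-Python sketch of DogeRMSNorm.forward for one vector:
    scale each element by 1/sqrt(mean(x_i^2) + eps), then apply the
    learned per-channel weight."""
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]


# with unit weights the output is just x divided by its RMS (~3.5355 here)
out = rms_norm([3.0, 4.0], [1.0, 1.0])
```

The real module does the same computation in float32 and casts back to the input dtype at the end.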
+
+ class DogeRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+     def __init__(self, config: DogeConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power users: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):  # force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
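The rotation above can be checked with a tiny pure-Python sketch for one vector at one position angle; `rotate_half_list` and `apply_rope` are illustrative helpers, not library functions:

```python
import math


def rotate_half_list(x):
    """Pure-Python analogue of rotate_half above: move the second half
    of the vector to the front, negated."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(vec, theta):
    """Apply the (cos, sin) combination from apply_rotary_pos_emb to one
    head-dim vector, with a single angle theta for all frequency pairs
    (the real code uses a different angle per frequency)."""
    cos, sin = math.cos(theta), math.sin(theta)
    rotated = rotate_half_list(vec)
    return [v * cos + r * sin for v, r in zip(vec, rotated)]


print(rotate_half_list([1.0, 2.0, 3.0, 4.0]))  # → [-3.0, -4.0, 1.0, 2.0]
# at angle 0 the rotation is the identity
print(apply_rope([1.0, 2.0, 3.0, 4.0], 0.0))  # → [1.0, 2.0, 3.0, 4.0]
```

Each element i is rotated jointly with element i + head_dim/2, which is why the cos/sin tables duplicate `freqs` along the last dimension.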
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
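The head-axis semantics of `repeat_kv` can be illustrated without torch: along the head dimension, each KV head is repeated `n_rep` times in place so that query head i reads KV head i // n_rep. `repeat_kv_lists` is a hypothetical pure-Python stand-in:

```python
def repeat_kv_lists(kv_heads, n_rep):
    """Pure-Python analogue of repeat_kv along the head axis: each of the
    num_key_value_heads entries is repeated n_rep consecutive times,
    matching torch.repeat_interleave(x, repeats=n_rep, dim=1)."""
    return [h for h in kv_heads for _ in range(n_rep)]


# 2 KV heads shared across 4 query heads → n_rep = 2
print(repeat_kv_lists(["kv0", "kv1"], 2))  # → ['kv0', 'kv0', 'kv1', 'kv1']
```

Note the interleaved order ('kv0', 'kv0', 'kv1', 'kv1'), not a tiled one ('kv0', 'kv1', 'kv0', 'kv1'); GQA grouping depends on this.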
+
+ class DogeAttention(nn.Module):
+     def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.keep_window_size = config.keep_window_size
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         # dynamic mask for the QK^T attention weights matrix
+         self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
+         self.dt_proj = nn.Linear(
+             config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+         self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_values is not None:
+             # sin and cos are specific to RoPE models; cache_position is needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         # sample dt_states from value_states to generate the attention bias
+         dt_states = self.dt_proj(
+             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
+         )
+         attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
+
+         attention_interface: Callable = flash_dynamic_mask_attention_forward
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask=attention_mask,
+             attention_bias=attn_bias,
+             scale=self.scaling,
+         )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ class DogeMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ class DogeCDMoE(nn.Module):
+     def __init__(self, config: DogeConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.act_fn = ACT2FN[config.hidden_act]
+
+         self.num_experts = config.num_experts
+         self.num_keys = math.floor(math.sqrt(self.num_experts))
+         self.top_k = config.num_experts_per_tok
+         self.norm_topk_prob = config.norm_topk_prob
+
+         # shared expert
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+
+         # router gate for retrieving experts
+         self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
+
+         # routed experts
+         self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
+         self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         **kwargs,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         bsz, seq_len, _ = hidden_states.shape
+
+         # get routing logits with the router gate
+         router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
+
+         # get the experts with the highest routing logits
+         (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
+         all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
+         all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
+         all_scores = all_scores.view(*all_scores.shape[:-2], -1)
+         all_indices = all_indices.view(*all_indices.shape[:-2], -1)
+         scores, position_indices = all_scores.topk(self.top_k, dim=-1)
+         indices = all_indices.gather(-1, position_indices)
+         routing_weights = F.softmax(scores, dim=-1)
+         if self.norm_topk_prob:
+             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+         # mix routed expert states with the shared expert states
+         down_embed = self.down_embed(indices)
+         up_embed = self.up_embed(indices)
+         experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
+         experts_weights = self.act_fn(experts_weights) * routing_weights
+         experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
+         hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+         hidden_states = hidden_states + experts_states
+         return hidden_states, router_logits
+
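The product-key routing math in `DogeCDMoE.forward` is worth unpacking: the router emits two score vectors of length `num_keys`, a candidate expert (i, j) has score `scores_x[i] + scores_y[j]` and flat index `i * num_keys + j`, and a global top-k is taken over the candidates. A pure-Python sketch over the full Cartesian product (the module first restricts each axis via `topk`, which is equivalent here since we enumerate every pair); `product_key_topk` is a hypothetical helper:

```python
import heapq


def product_key_topk(scores_x, scores_y, top_k):
    """Sketch of DogeCDMoE routing: combine per-axis router scores into
    num_keys**2 candidate experts, where candidate (i, j) scores
    scores_x[i] + scores_y[j] and maps to flat expert index
    i * num_keys + j, then keep the global top_k."""
    num_keys = len(scores_x)
    candidates = [
        (sx + sy, i * num_keys + j)
        for i, sx in enumerate(scores_x)
        for j, sy in enumerate(scores_y)
    ]
    best = heapq.nlargest(top_k, candidates)
    scores = [s for s, _ in best]
    indices = [idx for _, idx in best]
    return scores, indices


# 2 keys per axis → 4 candidate experts; keep the best 2
scores, indices = product_key_topk([0.9, 0.1], [0.5, 0.4], top_k=2)
print(indices)  # → [0, 1]
```

This is why `num_keys = floor(sqrt(num_experts))`: two length-`num_keys` score vectors address up to `num_keys**2` experts while the router itself stays small.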
311
+
312
+ class DogeDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_dropout = config.hidden_dropout

        self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
        self.input_residual = nn.Parameter(torch.ones(config.hidden_size))

        self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
        self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # sequence transformation
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        hidden_states = self.input_residual * residual + hidden_states

        # state transformation
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if isinstance(hidden_states, tuple):
            # DogeCDMoE returns (hidden_states, router_logits); keep only the hidden states here
            hidden_states, _ = hidden_states
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        hidden_states = self.post_attention_residual * residual + hidden_states

        return hidden_states


@auto_docstring
class DogePreTrainedModel(PreTrainedModel):
    config: DogeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DogeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False
    _can_compile_fullgraph = False
    _supports_attention_backend = False
    _can_record_outputs = {
        "router_logits": OutputRecorder(DogeCDMoE, index=1),
        "hidden_states": DogeDecoderLayer,
        "attentions": DogeAttention,
    }

    def _init_weights(self, module):
        """Initialize the weights."""
        super()._init_weights(module)
        if isinstance(module, DogeAttention):
            if hasattr(module, "A"):
                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, DogeCDMoE):
            if hasattr(module, "router_gate"):
                module.router_gate.weight.data.zero_()
        elif isinstance(module, DogeDecoderLayer):
            if hasattr(module, "input_residual"):
                module.input_residual.data.fill_(1.0)
            if hasattr(module, "post_attention_residual"):
                module.post_attention_residual.data.fill_(1.0)
398
+
399
+ @auto_docstring
400
+ class DogeModel(DogePreTrainedModel):
401
+ def __init__(self, config: DogeConfig):
402
+ super().__init__(config)
403
+ self.padding_idx = config.pad_token_id
404
+ self.vocab_size = config.vocab_size
405
+
406
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
407
+ self.layers = nn.ModuleList(
408
+ [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
409
+ )
410
+ self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
411
+ self.rotary_emb = DogeRotaryEmbedding(config=config)
412
+ self.gradient_checkpointing = False
413
+
414
+ # Initialize weights and apply final processing
415
+ self.post_init()
416
+
417
+ @check_model_inputs
418
+ @auto_docstring
419
+ def forward(
420
+ self,
421
+ input_ids: Optional[torch.LongTensor] = None,
422
+ attention_mask: Optional[torch.Tensor] = None,
423
+ position_ids: Optional[torch.LongTensor] = None,
424
+ past_key_values: Optional[Cache] = None,
425
+ inputs_embeds: Optional[torch.FloatTensor] = None,
426
+ use_cache: Optional[bool] = None,
427
+ cache_position: Optional[torch.LongTensor] = None,
428
+ **kwargs: Unpack[TransformersKwargs],
429
+ ) -> MoeModelOutputWithPast:
430
+ if (input_ids is None) ^ (inputs_embeds is not None):
431
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
432
+
433
+ if use_cache and past_key_values is None:
434
+ past_key_values = DynamicCache(config=self.config)
435
+
436
+ if inputs_embeds is None:
437
+ inputs_embeds = self.embed_tokens(input_ids)
438
+
439
+ if cache_position is None:
440
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
441
+ cache_position = torch.arange(
442
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
443
+ )
444
+ if position_ids is None:
445
+ position_ids = cache_position.unsqueeze(0)
446
+
447
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
448
+ causal_mask = mask_function(
449
+ config=self.config,
450
+ input_embeds=inputs_embeds,
451
+ attention_mask=attention_mask,
452
+ cache_position=cache_position,
453
+ past_key_values=past_key_values,
454
+ position_ids=position_ids,
455
+ )
456
+
457
+ hidden_states = inputs_embeds
458
+
459
+ # create position embeddings to be shared across the decoder layers
460
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
461
+
462
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
463
+ hidden_states = decoder_layer(
464
+ hidden_states,
465
+ position_embeddings=position_embeddings,
466
+ attention_mask=causal_mask,
467
+ position_ids=position_ids,
468
+ past_key_values=past_key_values,
469
+ use_cache=use_cache,
470
+ cache_position=cache_position,
471
+ **kwargs,
472
+ )
473
+
474
+ hidden_states = self.norm(hidden_states)
475
+
476
+ return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
477
+ last_hidden_state=hidden_states,
478
+ past_key_values=past_key_values,
479
+ )
480
+
481
+
482
+ def load_balancing_loss_func(
483
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
484
+ num_experts: Optional[int] = None,
485
+ num_keys: Optional[int] = None,
486
+ top_k: int = 2,
487
+ attention_mask: Optional[torch.Tensor] = None,
488
+ ) -> Union[torch.Tensor, int]:
489
+ r"""
490
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
491
+
492
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
493
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
494
+ experts is too unbalanced.
495
+
496
+ Args:
497
+ gate_logits:
498
+ Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
499
+ shape [2, batch_size * sequence_length, num_keys].
500
+ num_experts:
501
+ Number of experts
502
+ num_keys:
503
+ Number of keys
504
+ top_k:
505
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
506
+ parameter.
507
+ attention_mask (`torch.Tensor`, *optional*):
508
+ The attention_mask used in forward function
509
+ shape [batch_size X sequence_length] if not None.
510
+
511
+ Returns:
512
+ The auxiliary loss.
513
+ """
514
+ if gate_logits is None or not isinstance(gate_logits, tuple):
515
+ return 0
516
+
517
+ compute_dtype = gate_logits[0].dtype
518
+ compute_device = gate_logits[0].device
519
+ all_expert_indices = []
520
+ all_routing_weights = []
521
+
522
+ for layer_gate_logits in gate_logits:
523
+ layer_gate_logits = layer_gate_logits.to(compute_device)
524
+
525
+ (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
526
+
527
+ all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
528
+ all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
529
+ all_scores = all_scores.view(*all_scores.shape[:-2], -1)
530
+ all_indices = all_indices.view(*all_indices.shape[:-2], -1)
531
+
532
+ _, position_indices = all_scores.topk(top_k, dim=-1)
533
+ expert_indices = all_indices.gather(-1, position_indices)
534
+
535
+ routing_weights = F.softmax(all_scores, dim=-1)
536
+
537
+ all_expert_indices.append(expert_indices)
538
+ all_routing_weights.append(routing_weights)
539
+ all_expert_indices = torch.cat(all_expert_indices, dim=0)
540
+ all_routing_weights = torch.cat(all_routing_weights, dim=0)
541
+
542
+ if attention_mask is None:
543
+ # Compute the percentage of tokens routed to each experts
544
+ all_expert_indices = all_expert_indices.view(-1)
545
+ tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
546
+ pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
547
+ tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
548
+
549
+ # Compute the average probability of routing to these experts
550
+ router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
551
+ else:
552
+ batch_size, sequence_length = attention_mask.shape
553
+ num_hidden_layers = len(gate_logits)
554
+
555
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
556
+ expert_attention_mask = (
557
+ attention_mask[None, :, :, None]
558
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k))
559
+ .reshape(-1)
560
+ .to(compute_device)
561
+ )
562
+ all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
563
+
564
+ # Compute the percentage of tokens routed to each experts
565
+ tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
566
+ pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
567
+ tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
568
+ expert_attention_mask
569
+ )
570
+
571
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
572
+ router_per_expert_attention_mask = (
573
+ attention_mask[None, :, :, None]
574
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
575
+ .reshape(-1, num_experts)
576
+ .to(compute_device)
577
+ )
578
+
579
+ # Compute the average probability of routing to these experts
580
+ router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
581
+ router_per_expert_attention_mask, dim=0
582
+ )
583
+
584
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
585
+ return overall_loss * num_experts
586
+
587
+
@auto_docstring
class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DogeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_router_logits: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DogeForCausalLM

        >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
        >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                math.floor(math.sqrt(self.num_experts)),
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )


class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
    pass


__all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]
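A note for readers tracing the routing math in `load_balancing_loss_func`: the CDMoE router scores two independent sets of `num_keys` keys, and a key pair `(i, j)` maps to the flat expert id `i * num_keys + j`, so the expert grid has `num_keys ** 2` experts (which is why `DogeForCausalLM` passes `math.floor(math.sqrt(self.num_experts))` as `num_keys`). The following is an illustrative pure-Python sketch of that cross-product top-k selection, not part of the generated model file; `cross_product_topk` is a hypothetical helper name.

```python
def cross_product_topk(scores_x, scores_y, top_k):
    """Pick the top_k flat expert ids from two per-axis key score lists.

    Mirrors the index arithmetic in load_balancing_loss_func: the pair (i, j)
    gets score scores_x[i] + scores_y[j] and flat id i * num_keys + j.
    """
    num_keys = len(scores_x)
    combined = [
        (sx + sy, i * num_keys + j)
        for i, sx in enumerate(scores_x)
        for j, sy in enumerate(scores_y)
    ]
    # Sort pairs by combined score, highest first, and keep the top_k ids
    combined.sort(key=lambda t: t[0], reverse=True)
    return [expert_id for _, expert_id in combined[:top_k]]


# With 2 keys per axis there are 4 experts (ids 0..3); the best pair is (1, 1) -> id 3
print(cross_product_topk([1.0, 3.0], [0.5, 2.0], top_k=2))  # → [3, 2]
```

The model's actual implementation vectorizes this with `topk`, broadcasting, and `gather`, but the selected expert ids are the same.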