jeffkang-lunit committed on
Commit
4ba7bc8
·
verified ·
1 Parent(s): 97cd573

Upload modeling_gravity_moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_gravity_moe.py +25 -697
modeling_gravity_moe.py CHANGED
@@ -11,718 +11,46 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- import math
15
- from collections.abc import Callable
16
- from typing import Optional
17
-
18
- import torch
19
- import torch.nn.functional as F
20
- from torch import nn
21
-
22
- from transformers import initialization as init
23
- from transformers.activations import ACT2FN
24
- from transformers.cache_utils import Cache, DynamicCache
25
- from transformers.generation import GenerationMixin
26
- from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
27
- from transformers.masking_utils import create_causal_mask
28
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
- from transformers.modeling_layers import (
30
- GenericForSequenceClassification,
31
- GenericForTokenClassification,
32
- GradientCheckpointingLayer,
33
  )
34
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
35
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
- from transformers.processing_utils import Unpack
38
- from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
39
- from transformers.utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
40
- from transformers.utils.output_capturing import capture_outputs
41
- from .configuration_gravity_moe import GravityMoEConfig
42
-
43
-
44
- @use_kernel_forward_from_hub("RMSNorm")
45
- class GravityMoERMSNorm(nn.Module):
46
- def __init__(self, hidden_size, eps: float = 1e-6) -> None:
47
- """
48
- GravityMoERMSNorm is equivalent to T5LayerNorm
49
- """
50
- super().__init__()
51
- self.weight = nn.Parameter(torch.ones(hidden_size))
52
- self.variance_epsilon = eps
53
-
54
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
55
- input_dtype = hidden_states.dtype
56
- hidden_states = hidden_states.to(torch.float32)
57
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
- return self.weight * hidden_states.to(input_dtype)
60
-
61
- def extra_repr(self):
62
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
-
64
-
65
- class GravityMoERotaryEmbedding(nn.Module):
66
- inv_freq: torch.Tensor # fix linting for `register_buffer`
67
-
68
- def __init__(self, config: GravityMoEConfig, device=None):
69
- super().__init__()
70
- self.max_seq_len_cached = config.max_position_embeddings
71
- self.original_max_seq_len = config.max_position_embeddings
72
-
73
- self.config = config
74
-
75
- self.rope_type = self.config.rope_parameters["rope_type"]
76
- rope_init_fn: Callable = self.compute_default_rope_parameters
77
- if self.rope_type != "default":
78
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
79
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
80
-
81
- self.register_buffer("inv_freq", inv_freq, persistent=False)
82
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
83
-
84
- @staticmethod
85
- def compute_default_rope_parameters(
86
- config: GravityMoEConfig | None = None,
87
- device: Optional["torch.device"] = None,
88
- seq_len: int | None = None,
89
- ) -> tuple["torch.Tensor", float]:
90
- """
91
- Computes the inverse frequencies according to the original RoPE implementation
92
- Args:
93
- config ([`~transformers.PreTrainedConfig`]):
94
- The model configuration.
95
- device (`torch.device`):
96
- The device to use for initialization of the inverse frequencies.
97
- seq_len (`int`, *optional*):
98
- The current sequence length. Unused for this type of RoPE.
99
- Returns:
100
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
101
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
102
- """
103
- base = config.rope_parameters["rope_theta"]
104
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
105
-
106
- attention_factor = 1.0 # Unused in this type of RoPE
107
-
108
- # Compute the inverse frequencies
109
- inv_freq = 1.0 / (
110
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
111
- )
112
- return inv_freq, attention_factor
113
-
114
- @torch.no_grad()
115
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
116
- def forward(self, x, position_ids):
117
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
118
- position_ids_expanded = position_ids[:, None, :].float()
119
-
120
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
121
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
122
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
123
- emb = torch.cat((freqs, freqs), dim=-1)
124
- cos = emb.cos() * self.attention_scaling
125
- sin = emb.sin() * self.attention_scaling
126
-
127
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
128
-
129
-
130
- class GravityMoEMLP(nn.Module):
131
- def __init__(self, config, intermediate_size=None):
132
- super().__init__()
133
- self.config = config
134
- self.hidden_size = config.hidden_size
135
- self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
136
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
137
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
138
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
139
- self.act_fn = ACT2FN[config.hidden_act]
140
-
141
- def forward(self, x):
142
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
143
- return down_proj
144
-
145
-
146
- class GravityMoETopkRouter(nn.Module):
147
- def __init__(self, config):
148
- super().__init__()
149
- self.config = config
150
- self.n_routed_experts = config.n_routed_experts
151
-
152
- self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
153
- self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
154
-
155
- def forward(self, hidden_states):
156
- hidden_states = hidden_states.view(-1, self.config.hidden_size)
157
- router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
158
- return router_logits
159
-
160
-
161
- class GravityMoENaiveMoe(nn.Module):
162
- """Collection of expert weights stored as fused 3D tensors."""
163
-
164
- def __init__(self, config):
165
- super().__init__()
166
- self.num_experts = config.num_local_experts
167
- self.hidden_dim = config.hidden_size
168
- self.intermediate_dim = config.moe_intermediate_size
169
- self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
170
- self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
171
- self.act_fn = ACT2FN[config.hidden_act]
172
-
173
- def forward(
174
- self,
175
- hidden_states: torch.Tensor,
176
- top_k_index: torch.Tensor,
177
- top_k_weights: torch.Tensor,
178
- ) -> torch.Tensor:
179
- final_hidden_states = torch.zeros_like(hidden_states)
180
- with torch.no_grad():
181
- expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
182
- expert_mask = expert_mask.permute(2, 1, 0)
183
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
184
-
185
- for expert_idx in expert_hit:
186
- expert_idx = expert_idx[0]
187
- if expert_idx == self.num_experts:
188
- continue
189
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
190
- current_state = hidden_states[token_idx]
191
- gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
192
- current_hidden_states = self.act_fn(gate) * up
193
- current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
194
- current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
195
- final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
196
-
197
- return final_hidden_states
198
-
199
-
200
- class GravityMoEMoE(nn.Module):
201
- """
202
- A mixed expert module containing shared experts.
203
- """
204
-
205
- def __init__(self, config):
206
- super().__init__()
207
- self.config = config
208
- self.experts = GravityMoENaiveMoe(config)
209
- self.gate = GravityMoETopkRouter(config)
210
- self.shared_experts = GravityMoEMLP(
211
- config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
212
- )
213
- self.n_routed_experts = config.n_routed_experts
214
- self.n_group = config.n_group
215
- self.topk_group = config.topk_group
216
- self.norm_topk_prob = config.norm_topk_prob
217
- self.routed_scaling_factor = config.routed_scaling_factor
218
- self.top_k = config.num_experts_per_tok
219
-
220
- def route_tokens_to_experts(self, router_logits):
221
- router_logits = router_logits.sigmoid()
222
- router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
223
- group_scores = (
224
- router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
225
- .topk(2, dim=-1)[0]
226
- .sum(dim=-1)
227
- )
228
- group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
229
- group_mask = torch.zeros_like(group_scores)
230
- group_mask.scatter_(1, group_idx, 1)
231
- score_mask = (
232
- group_mask.unsqueeze(-1)
233
- .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
234
- .reshape(-1, self.n_routed_experts)
235
- )
236
- scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
237
- topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
238
- topk_weights = router_logits.gather(1, topk_indices)
239
- if self.norm_topk_prob:
240
- denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
241
- topk_weights /= denominator
242
- topk_weights = topk_weights * self.routed_scaling_factor
243
- return topk_indices, topk_weights
244
-
245
- def forward(self, hidden_states):
246
- residuals = hidden_states
247
- orig_shape = hidden_states.shape
248
- router_logits = self.gate(hidden_states)
249
- topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
250
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
251
- hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
252
- hidden_states = hidden_states + self.shared_experts(residuals)
253
- return hidden_states
254
-
255
-
256
- def rotate_half(x):
257
- """Rotates half the hidden dims of the input."""
258
- x1 = x[..., : x.shape[-1] // 2]
259
- x2 = x[..., x.shape[-1] // 2 :]
260
- return torch.cat((-x2, x1), dim=-1)
261
-
262
-
263
- @use_kernel_func_from_hub("rotary_pos_emb")
264
- def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
265
- """Applies Rotary Position Embedding to the query and key tensors.
266
-
267
- Args:
268
- q (`torch.Tensor`): The query tensor.
269
- k (`torch.Tensor`): The key tensor.
270
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
271
- sin (`torch.Tensor`): The sine part of the rotary embedding.
272
- unsqueeze_dim (`int`, *optional*, defaults to 1):
273
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
274
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
275
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
276
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
277
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
278
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
279
- Returns:
280
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
281
- """
282
- cos = cos.unsqueeze(unsqueeze_dim)
283
- sin = sin.unsqueeze(unsqueeze_dim)
284
- q_embed = (q * cos) + (rotate_half(q) * sin)
285
- k_embed = (k * cos) + (rotate_half(k) * sin)
286
- return q_embed, k_embed
287
-
288
-
289
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
290
- """
291
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
292
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
293
- """
294
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
295
- if n_rep == 1:
296
- return hidden_states
297
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
298
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
299
-
300
-
301
- def eager_attention_forward(
302
- module: nn.Module,
303
- query: torch.Tensor,
304
- key: torch.Tensor,
305
- value: torch.Tensor,
306
- attention_mask: torch.Tensor | None,
307
- scaling: float,
308
- dropout: float = 0.0,
309
- **kwargs: Unpack[TransformersKwargs],
310
- ):
311
- key_states = repeat_kv(key, module.num_key_value_groups)
312
- value_states = repeat_kv(value, module.num_key_value_groups)
313
-
314
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
315
- if attention_mask is not None:
316
- attn_weights = attn_weights + attention_mask
317
-
318
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
319
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
320
- attn_output = torch.matmul(attn_weights, value_states)
321
- attn_output = attn_output.transpose(1, 2).contiguous()
322
-
323
- return attn_output, attn_weights
324
-
325
-
326
- def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
327
- r"""
328
- TODO let's just use the original freqcis computation to not have the view
329
- transpose + reshape! This is not optimized!
330
- Applies Rotary Position Embedding to the query and key tensors.
331
-
332
- Args:
333
- q (`torch.Tensor`): The query tensor.
334
- k (`torch.Tensor`): The key tensor.
335
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
336
- sin (`torch.Tensor`): The sine part of the rotary embedding.
337
- position_ids (`torch.Tensor`):
338
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
339
- used to pass offsetted position ids when working with a KV-cache.
340
- unsqueeze_dim (`int`, *optional*, defaults to 1):
341
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
342
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
343
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
344
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
345
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
346
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
347
- Returns:
348
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
349
- """
350
- cos = cos.unsqueeze(unsqueeze_dim)
351
- sin = sin.unsqueeze(unsqueeze_dim)
352
-
353
- b, h, s, d = q.shape
354
- q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
355
-
356
- b, h, s, d = k.shape
357
- k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
358
-
359
- q_embed = (q * cos) + (rotate_half(q) * sin)
360
- k_embed = (k * cos) + (rotate_half(k) * sin)
361
- return q_embed, k_embed
362
 
 
363
 
364
- def yarn_get_mscale(scale=1, mscale=1):
365
- if scale <= 1:
366
- return 1.0
367
- return 0.1 * mscale * math.log(scale) + 1.0
368
-
369
-
370
- class GravityMoEAttention(nn.Module):
371
- """Multi-headed attention from 'Attention Is All You Need' paper"""
372
-
373
- def __init__(self, config: GravityMoEConfig, layer_idx: int):
374
- super().__init__()
375
- self.config = config
376
- self.layer_idx = layer_idx
377
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
378
- self.attention_dropout = config.attention_dropout
379
- self.num_heads = config.num_attention_heads
380
-
381
- self.q_lora_rank = config.q_lora_rank
382
- self.qk_rope_head_dim = config.qk_rope_head_dim
383
- self.kv_lora_rank = config.kv_lora_rank
384
- self.v_head_dim = config.v_head_dim
385
- self.qk_nope_head_dim = config.qk_nope_head_dim
386
- self.qk_head_dim = config.qk_head_dim
387
-
388
- self.is_causal = True
389
- if self.q_lora_rank is None:
390
- self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
391
- else:
392
- self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
393
- self.q_a_layernorm = GravityMoERMSNorm(config.q_lora_rank)
394
- self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
395
-
396
- self.kv_a_proj_with_mqa = nn.Linear(
397
- config.hidden_size,
398
- self.kv_lora_rank + self.qk_rope_head_dim,
399
- bias=config.attention_bias,
400
- )
401
- self.kv_a_layernorm = GravityMoERMSNorm(self.kv_lora_rank)
402
- self.kv_b_proj = nn.Linear(
403
- self.kv_lora_rank,
404
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
405
- bias=False,
406
- )
407
-
408
- self.o_proj = nn.Linear(
409
- self.num_heads * self.v_head_dim,
410
- config.hidden_size,
411
- bias=config.attention_bias,
412
- )
413
-
414
- self.scaling = self.qk_head_dim ** (-0.5)
415
- if self.config.rope_parameters.get("rope_type", "default") != "default":
416
- mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
417
- scaling_factor = self.config.rope_parameters["factor"]
418
- if mscale_all_dim:
419
- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
420
- self.scaling = self.scaling * mscale * mscale
421
-
422
- def forward(
423
- self,
424
- hidden_states: torch.Tensor,
425
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
426
- attention_mask: torch.Tensor | None,
427
- past_key_values: Cache | None = None,
428
- **kwargs: Unpack[FlashAttentionKwargs],
429
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
430
- batch_size, seq_length = hidden_states.shape[:-1]
431
- query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
432
- key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
433
-
434
- if self.q_lora_rank is None:
435
- q_states = self.q_proj(hidden_states)
436
- else:
437
- q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
438
- q_states = q_states.view(query_shape).transpose(1, 2)
439
- q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
440
-
441
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
442
- k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
443
-
444
- k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
445
- k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
446
-
447
- k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
448
-
449
- cos, sin = position_embeddings
450
- if self.config.rope_interleave: # support using interleaved weights for efficiency
451
- q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
452
- else:
453
- q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
454
- k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
455
-
456
- query_states = torch.cat((q_pass, q_rot), dim=-1)
457
- key_states = torch.cat((k_pass, k_rot), dim=-1)
458
-
459
- if past_key_values is not None:
460
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
461
-
462
- if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
463
- value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
464
-
465
- attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
466
- self.config._attn_implementation, eager_attention_forward
467
- )
468
-
469
- attn_output, attn_weights = attention_interface(
470
- self,
471
- query_states,
472
- key_states,
473
- value_states,
474
- attention_mask,
475
- dropout=0.0 if not self.training else self.attention_dropout,
476
- scaling=self.scaling,
477
- **kwargs,
478
- )
479
-
480
- if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
481
- attn_output = attn_output[:, :, :, : self.v_head_dim]
482
-
483
- attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
484
- attn_output = self.o_proj(attn_output)
485
- return attn_output, attn_weights
486
-
487
-
488
- class GravityMoEDecoderLayer(GradientCheckpointingLayer):
489
- def __init__(self, config: GravityMoEConfig, layer_idx: int):
490
- super().__init__()
491
- self.hidden_size = config.hidden_size
492
-
493
- self.self_attn = GravityMoEAttention(config=config, layer_idx=layer_idx)
494
-
495
- if layer_idx >= config.first_k_dense_replace:
496
- self.mlp = GravityMoEMoE(config)
497
- else:
498
- self.mlp = GravityMoEMLP(config)
499
-
500
- self.input_layernorm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
501
- self.post_attention_layernorm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
502
-
503
- def forward(
504
- self,
505
- hidden_states: torch.Tensor,
506
- attention_mask: torch.Tensor | None = None,
507
- position_ids: torch.LongTensor | None = None,
508
- past_key_values: Cache | None = None,
509
- use_cache: bool | None = False,
510
- position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
511
- **kwargs: Unpack[TransformersKwargs],
512
- ) -> torch.Tensor:
513
- residual = hidden_states
514
- hidden_states = self.input_layernorm(hidden_states)
515
- # Self Attention
516
- hidden_states, _ = self.self_attn(
517
- hidden_states=hidden_states,
518
- attention_mask=attention_mask,
519
- position_ids=position_ids,
520
- past_key_values=past_key_values,
521
- use_cache=use_cache,
522
- position_embeddings=position_embeddings,
523
- **kwargs,
524
- )
525
- hidden_states = residual + hidden_states
526
-
527
- # Fully Connected
528
- residual = hidden_states
529
- hidden_states = self.post_attention_layernorm(hidden_states)
530
- hidden_states = self.mlp(hidden_states)
531
- hidden_states = residual + hidden_states
532
- return hidden_states
533
-
534
 
535
- @auto_docstring
536
- class GravityMoEPreTrainedModel(PreTrainedModel):
537
- config: GravityMoEConfig
538
- base_model_prefix = "model"
539
- supports_gradient_checkpointing = True
540
- _no_split_modules = ["GravityMoEDecoderLayer"]
541
- _skip_keys_device_placement = ["past_key_values"]
542
- _supports_flash_attn = True
543
- _supports_sdpa = True
544
- _supports_flex_attn = True
545
 
546
- _can_compile_fullgraph = True
547
- _supports_attention_backend = True
548
- _can_record_outputs = {
549
- "hidden_states": GravityMoEDecoderLayer,
550
- "attentions": GravityMoEAttention,
551
- }
552
  _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
553
  _keys_to_ignore_on_load_unexpected = [r"model\.layers\.28.*"]
554
 
555
- @torch.no_grad()
556
- def _init_weights(self, module):
557
- super()._init_weights(module)
558
- if isinstance(module, GravityMoETopkRouter):
559
- init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
560
- init.zeros_(module.e_score_correction_bias)
561
- elif isinstance(module, GravityMoENaiveMoe):
562
- init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
563
- init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
564
-
565
-
566
- @auto_docstring
567
- class GravityMoEModel(GravityMoEPreTrainedModel):
568
- def __init__(self, config: GravityMoEConfig):
569
- super().__init__(config)
570
- self.padding_idx = config.pad_token_id
571
- self.vocab_size = config.vocab_size
572
-
573
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
574
- self.layers = nn.ModuleList(
575
- [GravityMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
576
- )
577
- self.norm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
578
- self.rotary_emb = GravityMoERotaryEmbedding(config=config)
579
- self.gradient_checkpointing = False
580
-
581
- # Initialize weights and apply final processing
582
- self.post_init()
583
-
584
- @merge_with_config_defaults
585
- @capture_outputs
586
- @auto_docstring
587
- def forward(
588
- self,
589
- input_ids: torch.LongTensor | None = None,
590
- attention_mask: torch.Tensor | None = None,
591
- position_ids: torch.LongTensor | None = None,
592
- past_key_values: Cache | None = None,
593
- inputs_embeds: torch.FloatTensor | None = None,
594
- use_cache: bool | None = None,
595
- **kwargs: Unpack[TransformersKwargs],
596
- ) -> BaseModelOutputWithPast:
597
- if (input_ids is None) ^ (inputs_embeds is not None):
598
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
599
-
600
- if inputs_embeds is None:
601
- inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
602
-
603
- if use_cache and past_key_values is None:
604
- past_key_values = DynamicCache(config=self.config)
605
-
606
- if position_ids is None:
607
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
608
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
609
- position_ids = position_ids.unsqueeze(0)
610
-
611
- causal_mask = create_causal_mask(
612
- config=self.config,
613
- inputs_embeds=inputs_embeds,
614
- attention_mask=attention_mask,
615
- past_key_values=past_key_values,
616
- position_ids=position_ids,
617
- )
618
-
619
- hidden_states = inputs_embeds
620
- position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
621
-
622
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
623
- hidden_states = decoder_layer(
624
- hidden_states,
625
- attention_mask=causal_mask,
626
- position_embeddings=position_embeddings,
627
- position_ids=position_ids,
628
- past_key_values=past_key_values,
629
- use_cache=use_cache,
630
- **kwargs,
631
- )
632
-
633
- hidden_states = self.norm(hidden_states)
634
- return BaseModelOutputWithPast(
635
- last_hidden_state=hidden_states,
636
- past_key_values=past_key_values,
637
- )
638
-
639
-
640
- @auto_docstring
641
- class GravityMoEForCausalLM(GravityMoEPreTrainedModel, GenerationMixin):
642
- _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
643
- _tp_plan = {"lm_head": "colwise_gather_output"}
644
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
645
-
646
- def __init__(self, config):
647
- super().__init__(config)
648
- self.model = GravityMoEModel(config)
649
- self.vocab_size = config.vocab_size
650
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
651
-
652
- # Initialize weights and apply final processing
653
- self.post_init()
654
-
655
- @can_return_tuple
656
- @auto_docstring
657
- def forward(
658
- self,
659
- input_ids: torch.LongTensor | None = None,
660
- attention_mask: torch.Tensor | None = None,
661
- position_ids: torch.LongTensor | None = None,
662
- past_key_values: Cache | None = None,
663
- inputs_embeds: torch.FloatTensor | None = None,
664
- labels: torch.LongTensor | None = None,
665
- use_cache: bool | None = None,
666
- logits_to_keep: int | torch.Tensor = 0,
667
- **kwargs: Unpack[TransformersKwargs],
668
- ) -> CausalLMOutputWithPast:
669
- r"""
670
- Example:
671
-
672
- ```python
673
- >>> from transformers import AutoTokenizer, GravityMoEForCausalLM
674
-
675
- >>> model = GravityMoEForCausalLM.from_pretrained("trillion-labs/Gravity-MoE-16.2B-A3.2B")
676
- >>> tokenizer = AutoTokenizer.from_pretrained("trillion-labs/Gravity-MoE-16.2B-A3.2B")
677
-
678
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
679
- >>> inputs = tokenizer(prompt, return_tensors="pt")
680
-
681
- >>> # Generate
682
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
683
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
684
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
685
- ```"""
686
- outputs: BaseModelOutputWithPast = self.model(
687
- input_ids=input_ids,
688
- attention_mask=attention_mask,
689
- position_ids=position_ids,
690
- past_key_values=past_key_values,
691
- inputs_embeds=inputs_embeds,
692
- use_cache=use_cache,
693
- **kwargs,
694
- )
695
-
696
- hidden_states = outputs.last_hidden_state
697
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
698
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
699
- logits = self.lm_head(hidden_states[:, slice_indices, :])
700
-
701
- loss = None
702
- if labels is not None:
703
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
704
-
705
- return CausalLMOutputWithPast(
706
- loss=loss,
707
- logits=logits,
708
- past_key_values=outputs.past_key_values,
709
- hidden_states=outputs.hidden_states,
710
- attentions=outputs.attentions,
711
- )
712
-
713
 
714
- class GravityMoEForSequenceClassification(GenericForSequenceClassification, GravityMoEPreTrainedModel):
715
- pass
716
 
717
 
718
- class GravityMoEForTokenClassification(GenericForTokenClassification, GravityMoEPreTrainedModel):
719
- pass
720
 
721
 
722
  __all__ = [
723
  "GravityMoEPreTrainedModel",
724
  "GravityMoEModel",
725
  "GravityMoEForCausalLM",
726
- "GravityMoEForSequenceClassification",
727
- "GravityMoEForTokenClassification",
728
  ]
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ """
15
+ GravityMoE model inherits from DeepSeek V3.
16
+
17
+ GravityMoE shares the same sparse Mixture-of-Experts architecture as DeepSeek V3
18
+ (MLA attention, sigmoid routing with bias correction, shared + routed experts)
19
+ but with different model hyperparameters. All modeling logic is inherited from
20
+ the DeepSeek V3 implementation in `transformers`.
21
+ """
22
+
23
+ from transformers.conversion_mapping import _MODEL_TO_CONVERSION_PATTERN
24
+ from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
25
+ DeepseekV3ForCausalLM,
26
+ DeepseekV3Model,
27
+ DeepseekV3PreTrainedModel,
 
 
 
 
 
28
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ from .configuration_gravity_moe import GravityMoEConfig
31
 
32
# Register a weight-conversion pattern so that `from_pretrained` fuses the
# per-expert checkpoint weights (experts.*.gate_proj, experts.*.up_proj,
# experts.*.down_proj) into fused 3D tensors (experts.gate_up_proj,
# experts.down_proj) expected by the inherited modeling code.
# NOTE(review): the original comment says "same as DeepSeek V3", but the
# registered pattern key is "qwen2_moe" — confirm the qwen2_moe conversion
# pattern produces the fused layout the DeepseekV3 expert modules expect.
_MODEL_TO_CONVERSION_PATTERN["gravity_moe"] = "qwen2_moe"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
37
 
38
class GravityMoEPreTrainedModel(DeepseekV3PreTrainedModel):
    """Base pretrained-model class for GravityMoE.

    All modeling logic (weight init, attention/MoE implementation support
    flags, etc.) is inherited unchanged from the DeepSeek V3 implementation;
    only the configuration class and a few loading attributes are overridden.
    """

    # Bind the inherited implementation to the GravityMoE configuration.
    config_class = GravityMoEConfig
    # Keep the router's score-correction bias strictly in fp32 even when the
    # rest of the model is loaded in lower precision.
    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
    # Silently ignore any unexpected checkpoint keys under model.layers.28.*.
    # NOTE(review): presumably the shipped checkpoint contains weights for a
    # layer 28 that this model does not instantiate — confirm against the
    # checkpoint's num_hidden_layers.
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.28.*"]
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
class GravityMoEModel(DeepseekV3Model):
    """Bare GravityMoE decoder model; architecture inherited from DeepseekV3Model."""

    # Bind the inherited implementation to the GravityMoE configuration.
    config_class = GravityMoEConfig
46
 
47
 
48
class GravityMoEForCausalLM(DeepseekV3ForCausalLM):
    """GravityMoE with a causal language-modeling head; inherited from DeepseekV3ForCausalLM."""

    # Bind the inherited implementation to the GravityMoE configuration.
    config_class = GravityMoEConfig
50
 
51
 
52
# Public API of this module.
__all__ = [
    "GravityMoEPreTrainedModel",
    "GravityMoEModel",
    "GravityMoEForCausalLM",
]