jeffkang-lunit commited on
Commit
96a436c
·
verified ·
1 Parent(s): 8e26e8a

Upload modeling_gravity_moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_gravity_moe.py +728 -0
modeling_gravity_moe.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Trillion Labs and the HuggingFace Inc. team. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from collections.abc import Callable
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from transformers import initialization as init
23
+ from transformers.activations import ACT2FN
24
+ from transformers.cache_utils import Cache, DynamicCache
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
27
+ from transformers.masking_utils import create_causal_mask
28
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
+ from transformers.modeling_layers import (
30
+ GenericForSequenceClassification,
31
+ GenericForTokenClassification,
32
+ GradientCheckpointingLayer,
33
+ )
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
39
+ from transformers.utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
40
+ from transformers.utils.output_capturing import capture_outputs
41
+ from .configuration_gravity_moe import GravityMoEConfig
42
+
43
+
44
@use_kernel_forward_from_hub("RMSNorm")
class GravityMoERMSNorm(nn.Module):
    """Root-mean-square layer normalization (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel scale; RMSNorm has no bias / mean-centering term.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Compute the RMS statistic in float32 for numerical stability.
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
+
64
+
65
class GravityMoERotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) module.

    Precomputes inverse frequencies from the config's rope parameters and, on each
    forward pass, produces the cos/sin tables consumed by `apply_rotary_pos_emb`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: GravityMoEConfig, device=None):
        super().__init__()
        # Keep both current and original max length so dynamic RoPE variants can
        # rescale the frequencies and later restore the originals.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Non-default rope types use the shared init functions from transformers.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: recomputed from config, not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: GravityMoEConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return (cos, sin) tables in `x`'s dtype, scaled by `attention_scaling`."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies so the table spans the full rotary dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
128
+
129
+
130
class GravityMoEMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Callers (e.g. shared experts) may override the intermediate width.
        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
144
+
145
+
146
class GravityMoETopkRouter(nn.Module):
    """Router producing per-expert logits for MoE gating.

    Holds the routing weight matrix and a non-learned score-correction bias buffer;
    the bias is consumed by the caller during expert selection, not applied here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_routed_experts = config.n_routed_experts

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))

    def forward(self, hidden_states):
        """Return routing logits of shape (num_tokens, n_routed_experts), in float32."""
        flat_tokens = hidden_states.view(-1, self.config.hidden_size)
        # Compute routing in float32 regardless of the model's compute dtype.
        return F.linear(flat_tokens.type(torch.float32), self.weight.type(torch.float32))
159
+
160
+
161
class GravityMoENaiveMoe(nn.Module):
    """Collection of expert weights stored as fused 3D tensors."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # Fused per-expert weights: gate and up projections concatenated along dim 1.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch tokens to their selected experts and accumulate weighted outputs.

        Assumes `hidden_states` is already flattened to (num_tokens, hidden_dim) and
        `top_k_index` / `top_k_weights` are (num_tokens, top_k) — TODO confirm with caller.
        Returns a tensor of the same shape/dtype as `hidden_states`.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot -> (num_tokens, top_k, num_experts); permute so dim 0 indexes experts.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only loop over experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): indices come from one_hot over num_experts classes, so this
            # guard looks unreachable — presumably defensive; confirm intent.
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # Gated activation on the fused gate/up projection, then project back down.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            # Scale each token's expert output by its routing weight.
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Scatter-add this expert's contribution back to the owning token rows.
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
198
+
199
+
200
class GravityMoEMoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = GravityMoENaiveMoe(config)
        self.gate = GravityMoETopkRouter(config)
        # Shared experts process every token, in addition to the routed experts.
        self.shared_experts = GravityMoEMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def route_tokens_to_experts(self, router_logits):
        """Pick top-k experts per token with group-limited routing.

        Experts are partitioned into `n_group` groups; only the best `topk_group`
        groups (each scored by the sum of its top-2 bias-corrected expert scores) stay
        eligible, and the final top-k experts are chosen among them. The correction
        bias influences *selection* only — the returned weights use the raw sigmoid
        scores.

        Returns:
            Tuple of `(topk_indices, topk_weights)`, each (num_tokens, top_k).
        """
        router_logits = router_logits.sigmoid()
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        # Score each group by the sum of its two best corrected expert scores.
        group_scores = (
            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Broadcast the group mask back to per-expert granularity.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # Zero out experts in masked-out groups before the final top-k.
        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        # Gather weights from the *uncorrected* sigmoid scores.
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        """Apply routed experts plus shared experts; output keeps the input shape."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # Shared experts always run on the original (pre-routing) activations.
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
254
+
255
+
256
def rotate_half(x):
    """Rotate the last dimension by swapping its halves and negating the second.

    Splitting the last dim at its midpoint into ``[a, b]``, returns ``[-b, a]`` —
    the rotation used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
261
+
262
+
263
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they broadcast
            against `q`/`k`. Use 1 when q/k are laid out as
            [batch, heads, seq_len, head_dim] and 2 for
            [batch, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
287
+
288
+
289
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    if n_rep == 1:
        return hidden_states
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    # expand creates a view (no copy); reshape then materializes the repeats.
    expanded = hidden_states.unsqueeze(2).expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)
299
+
300
+
301
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (pure PyTorch) scaled-dot-product attention with GQA head repetition.

    Returns `(attn_output, attn_weights)` with `attn_output` transposed to the
    [batch, seq_len, heads, head_dim] layout expected by the caller's reshape.
    """
    # Expand KV heads so every query head has a matching key/value head (GQA).
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
324
+
325
+
326
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""Apply rotary position embeddings to q/k stored in interleaved pair layout.

    The last dimension of ``q``/``k`` holds rotary pairs interleaved as
    ``[x0, y0, x1, y1, ...]``; each tensor is first de-interleaved to
    ``[x0, x1, ..., y0, y1, ...]`` so the standard half-rotation formula applies.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Unused; kept for interface compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against `q` and `k`.
    Returns:
        `tuple(torch.Tensor)` comprising the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # [x0, y0, x1, y1, ...] -> [x0, x1, ..., y0, y1, ...] on the last dim.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    def _rotate(t):
        # [a, b] -> [-b, a] over the two halves of the last dim.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q = _deinterleave(q)
    k = _deinterleave(k)
    return q * cos + _rotate(q) * sin, k * cos + _rotate(k) * sin
362
+
363
+
364
def yarn_get_mscale(scale=1, mscale=1):
    """Return the YaRN attention magnitude-scaling factor for a context `scale`.

    Identity (1.0) when no upscaling is applied; grows logarithmically otherwise.
    """
    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)
368
+
369
+
370
class GravityMoEAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GravityMoEConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        # Low-rank compression ranks and the per-head dimension split:
        # each head's query/key is [non-rotary ("nope") dims | rotary ("rope") dims].
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        # Queries: either a direct projection or a low-rank (a -> norm -> b) factorization.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = GravityMoERMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Keys/values are compressed into a kv_lora_rank latent; the rotary key part
        # is produced alongside it and shared across heads.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = GravityMoERMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)
        # YaRN-style mscale correction of the softmax scaling when rope scaling is on.
        if self.config.rope_parameters.get("rope_type", "default") != "default":
            mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_parameters["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        """Compute attention output (and optional weights) for one decoder layer."""
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(query_shape).transpose(1, 2)
        # Split each query head into its non-rotary and rotary parts.
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # k_pass: compressed KV latent; k_rot: single rotary key shared across heads.
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # Broadcast the single rotary key head across all attention heads.
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Flash attention requires matching q/k/v head dims; pad values here and
        # slice the padding back off after attention.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
486
+
487
+
488
class GravityMoEDecoderLayer(GradientCheckpointingLayer):
    """One transformer block: pre-norm self-attention and a pre-norm (MoE or dense) MLP."""

    def __init__(self, config: GravityMoEConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GravityMoEAttention(config=config, layer_idx=layer_idx)

        # The first `first_k_dense_replace` layers use a dense MLP; later layers use MoE.
        self.mlp = GravityMoEMoE(config) if layer_idx >= config.first_k_dense_replace else GravityMoEMLP(config)

        self.input_layernorm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block with residual connection.
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
533
+
534
+
535
@auto_docstring
class GravityMoEPreTrainedModel(PreTrainedModel):
    """Base class wiring GravityMoE into the transformers loading/init machinery."""

    config: GravityMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GravityMoEDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": GravityMoEDecoderLayer,
        "attentions": GravityMoEAttention,
    }
    # Keep the routing correction bias in fp32 even when the model is down-cast.
    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
    # NOTE(review): silently ignores checkpoint weights matching layer 28 —
    # presumably the checkpoint carries an extra layer this model does not use;
    # confirm against the released checkpoint.
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.28.*"]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize weights, extending the base init for MoE-specific parameters."""
        super()._init_weights(module)
        if isinstance(module, GravityMoETopkRouter):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            init.zeros_(module.e_score_correction_bias)
        elif isinstance(module, GravityMoENaiveMoe):
            # Fused expert tensors are raw nn.Parameters, so they are not covered
            # by the base class's Linear initialization.
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
564
+
565
+
566
@auto_docstring
class GravityMoEModel(GravityMoEPreTrainedModel):
    """Bare GravityMoE transformer: token embeddings, decoder layers, final RMSNorm."""

    def __init__(self, config: GravityMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GravityMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GravityMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GravityMoERotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue position numbering after whatever is already in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # RoPE cos/sin tables are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
638
+
639
+
640
@auto_docstring
class GravityMoEForCausalLM(GravityMoEPreTrainedModel, GenerationMixin):
    """GravityMoE model with a causal language-modeling head (tied to input embeddings)."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GravityMoEModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, GravityMoEForCausalLM

        >>> model = GravityMoEForCausalLM.from_pretrained("trillion-labs/Gravity-MoE-16.2B-A3.2B")
        >>> tokenizer = AutoTokenizer.from_pretrained("trillion-labs/Gravity-MoE-16.2B-A3.2B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
712
+
713
+
714
class GravityMoEForSequenceClassification(GenericForSequenceClassification, GravityMoEPreTrainedModel):
    """Sequence-classification head on the GravityMoE backbone (generic implementation)."""

    pass
716
+
717
+
718
class GravityMoEForTokenClassification(GenericForTokenClassification, GravityMoEPreTrainedModel):
    """Token-classification head on the GravityMoE backbone (generic implementation)."""

    pass
720
+
721
+
722
# Public symbols re-exported by this module.
__all__ = [
    "GravityMoEPreTrainedModel",
    "GravityMoEModel",
    "GravityMoEForCausalLM",
    "GravityMoEForSequenceClassification",
    "GravityMoEForTokenClassification",
]