Johnblick187 commited on
Commit
87f6245
·
verified ·
1 Parent(s): 1603b99

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +29 -72
modeling_grok2.py CHANGED
@@ -1,34 +1,6 @@
1
  """
2
- modeling_grok2.py — Custom Grok 2 modeling code for transformers.
3
- Allows AutoModel to load Johnblick187/grok-2.
4
-
5
- Exact tensor key names:
6
- model.embed_tokens.weight [131072, 8192]
7
- model.layers.N.pre_attn_norm.weight [8192]
8
- model.layers.N.post_attn_norm.weight [8192]
9
- model.layers.N.pre_moe_norm.weight [8192]
10
- model.layers.N.post_moe_norm.weight [8192]
11
- model.layers.N.self_attn.q_proj.weight [8192, 8192]
12
- model.layers.N.self_attn.k_proj.weight [1024, 8192]
13
- model.layers.N.self_attn.v_proj.weight [1024, 8192]
14
- model.layers.N.self_attn.o_proj.weight [8192, 8192]
15
- model.layers.N.mlp.gate_proj.weight [32768, 8192]
16
- model.layers.N.mlp.up_proj.weight [32768, 8192]
17
- model.layers.N.mlp.down_proj.weight [8192, 32768]
18
- model.layers.N.block_sparse_moe.gate.weight [8, 8192]
19
- model.layers.N.block_sparse_moe.experts.E.w1.weight [16384, 8192]
20
- model.layers.N.block_sparse_moe.experts.E.w2.weight [8192, 16384]
21
- model.layers.N.block_sparse_moe.experts.E.w3.weight [16384, 8192]
22
- model.norm.weight [8192]
23
- lm_head.weight [131072, 8192]
24
-
25
- Architecture:
26
- 64 layers, hidden=8192, 64 attn heads, 8 KV heads, head_dim=128
27
- Dense residual MLP (SwiGLU): gate_proj, up_proj, down_proj
28
- Sparse MoE: 8 experts, top-2, SwiGLU (w1=gate, w3=up, w2=down)
29
- 4x RMSNorm per layer (no bias)
30
- RoPE with scaled theta
31
- KV cache disabled — forward pass only, no past_key_values
32
  """
33
 
34
  import math
@@ -40,7 +12,6 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
40
  from transformers import AutoConfig, AutoModelForCausalLM
41
 
42
 
43
- # ── Config ────────────────────────────────────────────────────────────────────
44
  class Grok2Config(PretrainedConfig):
45
  model_type = "grok2"
46
 
@@ -96,7 +67,6 @@ class Grok2Config(PretrainedConfig):
96
  )
97
 
98
 
99
- # ── RMSNorm ───────────────────────────────────────────────────────────────────
100
  class Grok2RMSNorm(nn.Module):
101
  def __init__(self, hidden_size, eps=1e-5):
102
  super().__init__()
@@ -105,15 +75,16 @@ class Grok2RMSNorm(nn.Module):
105
 
106
  def forward(self, x):
107
  variance = x.pow(2).mean(-1, keepdim=True)
108
- return self.weight * x * torch.rsqrt(variance + self.eps)
109
 
110
 
111
- # ── RoPE ──────────────────────────────────────────────────────────────────────
112
  def rotate_half(x):
113
  x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
114
  return torch.cat([-x2, x1], dim=-1)
115
 
116
  def apply_rotary_emb(q, k, cos, sin):
 
 
117
  return (q * cos) + (rotate_half(q) * sin), \
118
  (k * cos) + (rotate_half(k) * sin)
119
 
@@ -140,7 +111,6 @@ class Grok2RotaryEmbedding(nn.Module):
140
  self.sin_cached[:, :, :seq_len, :]
141
 
142
 
143
- # ── Attention ─────────────────────────────────────────────────────────────────
144
  class Grok2Attention(nn.Module):
145
  def __init__(self, config: Grok2Config):
146
  super().__init__()
@@ -158,17 +128,18 @@ class Grok2Attention(nn.Module):
158
 
159
  def forward(self, hidden_states, attention_mask=None, **kwargs):
160
  B, T, _ = hidden_states.shape
 
 
161
 
162
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
163
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
164
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
165
 
166
- cos, sin = self.rotary_emb(T, hidden_states.device)
167
- cos = cos[:, :, :T, :self.head_dim]
168
- sin = sin[:, :, :T, :self.head_dim]
169
  q, k = apply_rotary_emb(q, k, cos, sin)
170
 
171
- # GQA expand
172
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
173
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
174
 
@@ -181,21 +152,20 @@ class Grok2Attention(nn.Module):
181
  attn = attn * self.attn_softcap
182
 
183
  causal = torch.triu(
184
- torch.full((T, T), float("-inf"), device=q.device, dtype=q.dtype),
185
  diagonal=1
186
  )
187
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
188
 
189
  if attention_mask is not None:
190
- attn = attn + attention_mask
191
 
192
- attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
193
  out = torch.matmul(attn, v)
194
  out = out.transpose(1, 2).contiguous().view(B, T, -1)
195
  return self.o_proj(out)
196
 
197
 
198
- # ── MoE Expert ────────────────────────────────────────────────────────────────
199
  class Grok2Expert(nn.Module):
200
  def __init__(self, hidden_size, moe_intermediate_size):
201
  super().__init__()
@@ -207,14 +177,12 @@ class Grok2Expert(nn.Module):
207
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
208
 
209
 
210
- # ── Sparse MoE ────────────────────────────────────────────────────────────────
211
  class Grok2SparseMoE(nn.Module):
212
  def __init__(self, config: Grok2Config):
213
  super().__init__()
214
  self.num_experts = config.num_local_experts
215
  self.top_k = config.num_experts_per_tok
216
  self.router_softcap = config.router_logit_softcapping
217
-
218
  self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
219
  self.experts = nn.ModuleList([
220
  Grok2Expert(config.hidden_size, config.moe_intermediate_size)
@@ -224,31 +192,33 @@ class Grok2SparseMoE(nn.Module):
224
  def forward(self, x):
225
  B, T, H = x.shape
226
  x_flat = x.view(-1, H)
 
227
 
228
  router_logits = self.gate(x_flat)
229
-
230
  if self.router_softcap > 0:
231
- router_logits = router_logits / self.router_softcap
232
- router_logits = torch.tanh(router_logits)
233
- router_logits = router_logits * self.router_softcap
234
 
235
- router_weights = F.softmax(router_logits, dim=-1)
236
  top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
237
  top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
238
 
239
  out = torch.zeros_like(x_flat)
240
  for k in range(self.top_k):
241
- expert_idx = top_indices[:, k]
242
- weight = top_weights[:, k].unsqueeze(-1)
243
  for e in range(self.num_experts):
244
- mask = (expert_idx == e)
245
- if mask.any():
246
- out[mask] += weight[mask] * self.experts[e](x_flat[mask])
 
 
 
 
 
247
 
248
  return out.view(B, T, H)
249
 
250
 
251
- # ── Dense MLP ─────────────────────────────────────────────────────────────────
252
  class Grok2MLP(nn.Module):
253
  def __init__(self, config: Grok2Config):
254
  super().__init__()
@@ -260,7 +230,6 @@ class Grok2MLP(nn.Module):
260
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
261
 
262
 
263
- # ── Decoder Layer ─────────────────────────────────────────────────────────────
264
  class Grok2DecoderLayer(nn.Module):
265
  def __init__(self, config: Grok2Config, layer_idx: int):
266
  super().__init__()
@@ -274,25 +243,22 @@ class Grok2DecoderLayer(nn.Module):
274
  self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
275
 
276
  def forward(self, hidden_states, attention_mask=None, **kwargs):
277
- # Attention
278
  residual = hidden_states
279
  hidden_states = self.pre_attn_norm(hidden_states)
280
  hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
281
  hidden_states = self.post_attn_norm(hidden_states)
282
  hidden_states = residual + hidden_states
283
 
284
- # MoE + dense residual
285
  residual = hidden_states
286
  hidden_states = self.pre_moe_norm(hidden_states)
287
  moe_out = self.block_sparse_moe(hidden_states)
288
  mlp_out = self.mlp(hidden_states)
289
- hidden_states = self.post_moe_norm(moe_out + mlp_out)
290
  hidden_states = residual + hidden_states
291
 
292
  return hidden_states
293
 
294
 
295
- # ── Model ─────────────────────────────────────────────────────────────────────
296
  class Grok2Model(nn.Module):
297
  def __init__(self, config: Grok2Config):
298
  super().__init__()
@@ -310,7 +276,6 @@ class Grok2Model(nn.Module):
310
  return self.norm(hidden_states)
311
 
312
 
313
- # ── CausalLM ──────────────────────────────────────────────────────────────────
314
  class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
315
  config_class = Grok2Config
316
  base_model_prefix = "model"
@@ -338,13 +303,10 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
338
  **kwargs,
339
  ):
340
  hidden_states = self.model(input_ids, attention_mask=attention_mask)
341
-
342
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
343
 
344
  if self.final_logit_softcapping > 0:
345
- logits = logits / self.final_logit_softcapping
346
- logits = torch.tanh(logits)
347
- logits = logits * self.final_logit_softcapping
348
 
349
  loss = None
350
  if labels is not None:
@@ -356,16 +318,11 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
356
  ignore_index=-100,
357
  )
358
 
359
- return CausalLMOutputWithPast(
360
- loss=loss,
361
- logits=logits,
362
- past_key_values=None,
363
- )
364
 
365
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
366
  return {"input_ids": input_ids}
367
 
368
 
369
- # ── Register ──────────────────────────────────────────────────────────────────
370
  AutoConfig.register("grok2", Grok2Config)
371
  AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
 
1
  """
2
+ modeling_grok2.py — Grok 2 modeling code for transformers.
3
+ Pure bf16, device-aware MoE, no dtype casting.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import math
 
12
  from transformers import AutoConfig, AutoModelForCausalLM
13
 
14
 
 
15
  class Grok2Config(PretrainedConfig):
16
  model_type = "grok2"
17
 
 
67
  )
68
 
69
 
 
70
  class Grok2RMSNorm(nn.Module):
71
  def __init__(self, hidden_size, eps=1e-5):
72
  super().__init__()
 
75
 
76
  def forward(self, x):
77
  variance = x.pow(2).mean(-1, keepdim=True)
78
+ return self.weight.to(x.device) * x * torch.rsqrt(variance + self.eps)
79
 
80
 
 
81
  def rotate_half(x):
82
  x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
83
  return torch.cat([-x2, x1], dim=-1)
84
 
85
  def apply_rotary_emb(q, k, cos, sin):
86
+ cos = cos.to(q.device, q.dtype)
87
+ sin = sin.to(q.device, q.dtype)
88
  return (q * cos) + (rotate_half(q) * sin), \
89
  (k * cos) + (rotate_half(k) * sin)
90
 
 
111
  self.sin_cached[:, :, :seq_len, :]
112
 
113
 
 
114
  class Grok2Attention(nn.Module):
115
  def __init__(self, config: Grok2Config):
116
  super().__init__()
 
128
 
129
  def forward(self, hidden_states, attention_mask=None, **kwargs):
130
  B, T, _ = hidden_states.shape
131
+ dtype = hidden_states.dtype
132
+ device = hidden_states.device
133
 
134
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
135
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
136
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
137
 
138
+ cos, sin = self.rotary_emb(T, device)
139
+ cos = cos[:, :, :T, :self.head_dim].to(dtype)
140
+ sin = sin[:, :, :T, :self.head_dim].to(dtype)
141
  q, k = apply_rotary_emb(q, k, cos, sin)
142
 
 
143
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
144
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
145
 
 
152
  attn = attn * self.attn_softcap
153
 
154
  causal = torch.triu(
155
+ torch.full((T, T), float("-inf"), device=device, dtype=dtype),
156
  diagonal=1
157
  )
158
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
159
 
160
  if attention_mask is not None:
161
+ attn = attn + attention_mask.to(device, dtype)
162
 
163
+ attn = F.softmax(attn, dim=-1).to(dtype)
164
  out = torch.matmul(attn, v)
165
  out = out.transpose(1, 2).contiguous().view(B, T, -1)
166
  return self.o_proj(out)
167
 
168
 
 
169
  class Grok2Expert(nn.Module):
170
  def __init__(self, hidden_size, moe_intermediate_size):
171
  super().__init__()
 
177
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
178
 
179
 
 
180
  class Grok2SparseMoE(nn.Module):
181
  def __init__(self, config: Grok2Config):
182
  super().__init__()
183
  self.num_experts = config.num_local_experts
184
  self.top_k = config.num_experts_per_tok
185
  self.router_softcap = config.router_logit_softcapping
 
186
  self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
187
  self.experts = nn.ModuleList([
188
  Grok2Expert(config.hidden_size, config.moe_intermediate_size)
 
192
  def forward(self, x):
193
  B, T, H = x.shape
194
  x_flat = x.view(-1, H)
195
+ dtype = x_flat.dtype
196
 
197
  router_logits = self.gate(x_flat)
 
198
  if self.router_softcap > 0:
199
+ router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
 
 
200
 
201
+ router_weights = F.softmax(router_logits, dim=-1).to(dtype)
202
  top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
203
  top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
204
 
205
  out = torch.zeros_like(x_flat)
206
  for k in range(self.top_k):
207
+ expert_ids = top_indices[:, k]
208
+ weights = top_weights[:, k].unsqueeze(-1)
209
  for e in range(self.num_experts):
210
+ mask = (expert_ids == e)
211
+ if not mask.any():
212
+ continue
213
+ expert_device = next(self.experts[e].parameters()).device
214
+ x_e = x_flat[mask].to(expert_device)
215
+ w_e = weights[mask].to(expert_device)
216
+ y_e = self.experts[e](x_e) * w_e
217
+ out[mask] += y_e.to(out.device)
218
 
219
  return out.view(B, T, H)
220
 
221
 
 
222
  class Grok2MLP(nn.Module):
223
  def __init__(self, config: Grok2Config):
224
  super().__init__()
 
230
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
231
 
232
 
 
233
  class Grok2DecoderLayer(nn.Module):
234
  def __init__(self, config: Grok2Config, layer_idx: int):
235
  super().__init__()
 
243
  self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
244
 
245
  def forward(self, hidden_states, attention_mask=None, **kwargs):
 
246
  residual = hidden_states
247
  hidden_states = self.pre_attn_norm(hidden_states)
248
  hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
249
  hidden_states = self.post_attn_norm(hidden_states)
250
  hidden_states = residual + hidden_states
251
 
 
252
  residual = hidden_states
253
  hidden_states = self.pre_moe_norm(hidden_states)
254
  moe_out = self.block_sparse_moe(hidden_states)
255
  mlp_out = self.mlp(hidden_states)
256
+ hidden_states = self.post_moe_norm(moe_out.to(mlp_out.device) + mlp_out)
257
  hidden_states = residual + hidden_states
258
 
259
  return hidden_states
260
 
261
 
 
262
  class Grok2Model(nn.Module):
263
  def __init__(self, config: Grok2Config):
264
  super().__init__()
 
276
  return self.norm(hidden_states)
277
 
278
 
 
279
  class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
280
  config_class = Grok2Config
281
  base_model_prefix = "model"
 
303
  **kwargs,
304
  ):
305
  hidden_states = self.model(input_ids, attention_mask=attention_mask)
 
306
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
307
 
308
  if self.final_logit_softcapping > 0:
309
+ logits = torch.tanh(logits / self.final_logit_softcapping) * self.final_logit_softcapping
 
 
310
 
311
  loss = None
312
  if labels is not None:
 
318
  ignore_index=-100,
319
  )
320
 
321
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)
 
 
 
 
322
 
323
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
324
  return {"input_ids": input_ids}
325
 
326
 
 
327
  AutoConfig.register("grok2", Grok2Config)
328
  AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)