Johnblick187 committed on
Commit
e068954
·
verified ·
1 Parent(s): 87f6245

Update modeling_grok2.py

Browse files
Files changed (1) hide show
  1. modeling_grok2.py +56 -36
modeling_grok2.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- modeling_grok2.py — Grok 2 modeling code for transformers.
3
- Pure bf16, device-aware MoE, no dtype casting.
4
  """
5
 
6
  import math
@@ -12,6 +12,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
12
  from transformers import AutoConfig, AutoModelForCausalLM
13
 
14
 
 
15
  class Grok2Config(PretrainedConfig):
16
  model_type = "grok2"
17
 
@@ -67,6 +68,7 @@ class Grok2Config(PretrainedConfig):
67
  )
68
 
69
 
 
70
  class Grok2RMSNorm(nn.Module):
71
  def __init__(self, hidden_size, eps=1e-5):
72
  super().__init__()
@@ -74,17 +76,18 @@ class Grok2RMSNorm(nn.Module):
74
  self.eps = eps
75
 
76
  def forward(self, x):
 
77
  variance = x.pow(2).mean(-1, keepdim=True)
78
- return self.weight.to(x.device) * x * torch.rsqrt(variance + self.eps)
 
79
 
80
 
 
81
  def rotate_half(x):
82
  x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
83
  return torch.cat([-x2, x1], dim=-1)
84
 
85
  def apply_rotary_emb(q, k, cos, sin):
86
- cos = cos.to(q.device, q.dtype)
87
- sin = sin.to(q.device, q.dtype)
88
  return (q * cos) + (rotate_half(q) * sin), \
89
  (k * cos) + (rotate_half(k) * sin)
90
 
@@ -93,24 +96,25 @@ class Grok2RotaryEmbedding(nn.Module):
93
  super().__init__()
94
  base = base * scaling_factor
95
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
96
- self.register_buffer("inv_freq", inv_freq)
97
  self._cached_len = 0
98
 
99
- def _build_cache(self, seq_len, device):
100
  t = torch.arange(seq_len, device=device).float()
101
  freqs = torch.outer(t, self.inv_freq.to(device))
102
  emb = torch.cat([freqs, freqs], dim=-1)
103
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
104
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
105
  self._cached_len = seq_len
 
106
 
107
- def forward(self, seq_len, device):
108
- if seq_len > self._cached_len:
109
- self._build_cache(seq_len, device)
110
- return self.cos_cached[:, :, :seq_len, :], \
111
- self.sin_cached[:, :, :seq_len, :]
112
 
113
 
 
114
  class Grok2Attention(nn.Module):
115
  def __init__(self, config: Grok2Config):
116
  super().__init__()
@@ -128,18 +132,19 @@ class Grok2Attention(nn.Module):
128
 
129
  def forward(self, hidden_states, attention_mask=None, **kwargs):
130
  B, T, _ = hidden_states.shape
131
- dtype = hidden_states.dtype
132
  device = hidden_states.device
 
133
 
134
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
135
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
136
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
137
 
138
- cos, sin = self.rotary_emb(T, device)
139
- cos = cos[:, :, :T, :self.head_dim].to(dtype)
140
- sin = sin[:, :, :T, :self.head_dim].to(dtype)
141
  q, k = apply_rotary_emb(q, k, cos, sin)
142
 
 
143
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
144
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
145
 
@@ -147,9 +152,7 @@ class Grok2Attention(nn.Module):
147
  attn = torch.matmul(q, k.transpose(-2, -1)) / scale
148
 
149
  if self.attn_softcap > 0:
150
- attn = attn / self.attn_softcap
151
- attn = torch.tanh(attn)
152
- attn = attn * self.attn_softcap
153
 
154
  causal = torch.triu(
155
  torch.full((T, T), float("-inf"), device=device, dtype=dtype),
@@ -158,7 +161,7 @@ class Grok2Attention(nn.Module):
158
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
159
 
160
  if attention_mask is not None:
161
- attn = attn + attention_mask.to(device, dtype)
162
 
163
  attn = F.softmax(attn, dim=-1).to(dtype)
164
  out = torch.matmul(attn, v)
@@ -166,6 +169,7 @@ class Grok2Attention(nn.Module):
166
  return self.o_proj(out)
167
 
168
 
 
169
  class Grok2Expert(nn.Module):
170
  def __init__(self, hidden_size, moe_intermediate_size):
171
  super().__init__()
@@ -177,12 +181,14 @@ class Grok2Expert(nn.Module):
177
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
178
 
179
 
 
180
  class Grok2SparseMoE(nn.Module):
181
  def __init__(self, config: Grok2Config):
182
  super().__init__()
183
  self.num_experts = config.num_local_experts
184
  self.top_k = config.num_experts_per_tok
185
  self.router_softcap = config.router_logit_softcapping
 
186
  self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
187
  self.experts = nn.ModuleList([
188
  Grok2Expert(config.hidden_size, config.moe_intermediate_size)
@@ -191,14 +197,15 @@ class Grok2SparseMoE(nn.Module):
191
 
192
  def forward(self, x):
193
  B, T, H = x.shape
 
 
194
  x_flat = x.view(-1, H)
195
- dtype = x_flat.dtype
196
 
197
  router_logits = self.gate(x_flat)
198
  if self.router_softcap > 0:
199
  router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
200
 
201
- router_weights = F.softmax(router_logits, dim=-1).to(dtype)
202
  top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
203
  top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
204
 
@@ -210,15 +217,16 @@ class Grok2SparseMoE(nn.Module):
210
  mask = (expert_ids == e)
211
  if not mask.any():
212
  continue
 
213
  expert_device = next(self.experts[e].parameters()).device
214
- x_e = x_flat[mask].to(expert_device)
215
- w_e = weights[mask].to(expert_device)
216
- y_e = self.experts[e](x_e) * w_e
217
- out[mask] += y_e.to(out.device)
218
 
219
  return out.view(B, T, H)
220
 
221
 
 
222
  class Grok2MLP(nn.Module):
223
  def __init__(self, config: Grok2Config):
224
  super().__init__()
@@ -230,6 +238,7 @@ class Grok2MLP(nn.Module):
230
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
231
 
232
 
 
233
  class Grok2DecoderLayer(nn.Module):
234
  def __init__(self, config: Grok2Config, layer_idx: int):
235
  super().__init__()
@@ -243,22 +252,29 @@ class Grok2DecoderLayer(nn.Module):
243
  self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
244
 
245
  def forward(self, hidden_states, attention_mask=None, **kwargs):
 
 
 
 
246
  residual = hidden_states
247
  hidden_states = self.pre_attn_norm(hidden_states)
248
  hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
249
- hidden_states = self.post_attn_norm(hidden_states)
250
- hidden_states = residual + hidden_states
251
 
 
252
  residual = hidden_states
253
- hidden_states = self.pre_moe_norm(hidden_states)
254
- moe_out = self.block_sparse_moe(hidden_states)
255
- mlp_out = self.mlp(hidden_states)
256
- hidden_states = self.post_moe_norm(moe_out.to(mlp_out.device) + mlp_out)
257
- hidden_states = residual + hidden_states
 
258
 
259
  return hidden_states
260
 
261
 
 
262
  class Grok2Model(nn.Module):
263
  def __init__(self, config: Grok2Config):
264
  super().__init__()
@@ -276,6 +292,7 @@ class Grok2Model(nn.Module):
276
  return self.norm(hidden_states)
277
 
278
 
 
279
  class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
280
  config_class = Grok2Config
281
  base_model_prefix = "model"
@@ -303,6 +320,8 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
303
  **kwargs,
304
  ):
305
  hidden_states = self.model(input_ids, attention_mask=attention_mask)
 
 
306
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
307
 
308
  if self.final_logit_softcapping > 0:
@@ -324,5 +343,6 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
324
  return {"input_ids": input_ids}
325
 
326
 
 
327
  AutoConfig.register("grok2", Grok2Config)
328
- AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
 
1
  """
2
+ modeling_grok2.py — Grok 2 for transformers, full multi-GPU support.
3
+ Pure bf16 throughout. Device-aware at every operation.
4
  """
5
 
6
  import math
 
12
  from transformers import AutoConfig, AutoModelForCausalLM
13
 
14
 
15
+ # ── Config ────────────────────────────────────────────────────────────────────
16
  class Grok2Config(PretrainedConfig):
17
  model_type = "grok2"
18
 
 
68
  )
69
 
70
 
71
+ # ── RMSNorm ───────────────────────────────────────────────────────────────────
72
  class Grok2RMSNorm(nn.Module):
73
  def __init__(self, hidden_size, eps=1e-5):
74
  super().__init__()
 
76
  self.eps = eps
77
 
78
  def forward(self, x):
79
+ # Stay in input dtype throughout
80
  variance = x.pow(2).mean(-1, keepdim=True)
81
+ x = x * torch.rsqrt(variance + self.eps)
82
+ return self.weight.to(x.device, x.dtype) * x
83
 
84
 
85
+ # ── RoPE ──────────────────────────────────────────────────────────────────────
86
  def rotate_half(x):
87
  x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
88
  return torch.cat([-x2, x1], dim=-1)
89
 
90
  def apply_rotary_emb(q, k, cos, sin):
 
 
91
  return (q * cos) + (rotate_half(q) * sin), \
92
  (k * cos) + (rotate_half(k) * sin)
93
 
 
96
  super().__init__()
97
  base = base * scaling_factor
98
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
99
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
100
  self._cached_len = 0
101
 
102
+ def _build_cache(self, seq_len, device, dtype):
103
  t = torch.arange(seq_len, device=device).float()
104
  freqs = torch.outer(t, self.inv_freq.to(device))
105
  emb = torch.cat([freqs, freqs], dim=-1)
106
+ self._cos = emb.cos().to(dtype)[None, None, :, :]
107
+ self._sin = emb.sin().to(dtype)[None, None, :, :]
108
  self._cached_len = seq_len
109
+ self._cached_device = device
110
 
111
+ def forward(self, seq_len, device, dtype):
112
+ if seq_len > self._cached_len or not hasattr(self, '_cached_device') or device != self._cached_device:
113
+ self._build_cache(seq_len, device, dtype)
114
+ return self._cos[:, :, :seq_len, :], self._sin[:, :, :seq_len, :]
 
115
 
116
 
117
+ # ── Attention ─────────────────────────────────────────────────────────────────
118
  class Grok2Attention(nn.Module):
119
  def __init__(self, config: Grok2Config):
120
  super().__init__()
 
132
 
133
  def forward(self, hidden_states, attention_mask=None, **kwargs):
134
  B, T, _ = hidden_states.shape
 
135
  device = hidden_states.device
136
+ dtype = hidden_states.dtype
137
 
138
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
139
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
140
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
141
 
142
+ cos, sin = self.rotary_emb(T, device, dtype)
143
+ cos = cos[:, :, :T, :self.head_dim]
144
+ sin = sin[:, :, :T, :self.head_dim]
145
  q, k = apply_rotary_emb(q, k, cos, sin)
146
 
147
+ # GQA expand
148
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
149
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
150
 
 
152
  attn = torch.matmul(q, k.transpose(-2, -1)) / scale
153
 
154
  if self.attn_softcap > 0:
155
+ attn = torch.tanh(attn / self.attn_softcap) * self.attn_softcap
 
 
156
 
157
  causal = torch.triu(
158
  torch.full((T, T), float("-inf"), device=device, dtype=dtype),
 
161
  attn = attn + causal.unsqueeze(0).unsqueeze(0)
162
 
163
  if attention_mask is not None:
164
+ attn = attn + attention_mask.to(device=device, dtype=dtype)
165
 
166
  attn = F.softmax(attn, dim=-1).to(dtype)
167
  out = torch.matmul(attn, v)
 
169
  return self.o_proj(out)
170
 
171
 
172
+ # ── MoE Expert ────────────────────────────────────────────────────────────────
173
  class Grok2Expert(nn.Module):
174
  def __init__(self, hidden_size, moe_intermediate_size):
175
  super().__init__()
 
181
  return self.w2(F.silu(self.w1(x)) * self.w3(x))
182
 
183
 
184
+ # ── Sparse MoE ────────────────────────────────────────────────────────────────
185
  class Grok2SparseMoE(nn.Module):
186
  def __init__(self, config: Grok2Config):
187
  super().__init__()
188
  self.num_experts = config.num_local_experts
189
  self.top_k = config.num_experts_per_tok
190
  self.router_softcap = config.router_logit_softcapping
191
+
192
  self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
193
  self.experts = nn.ModuleList([
194
  Grok2Expert(config.hidden_size, config.moe_intermediate_size)
 
197
 
198
  def forward(self, x):
199
  B, T, H = x.shape
200
+ device = x.device
201
+ dtype = x.dtype
202
  x_flat = x.view(-1, H)
 
203
 
204
  router_logits = self.gate(x_flat)
205
  if self.router_softcap > 0:
206
  router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
207
 
208
+ router_weights = F.softmax(router_logits, dim=-1)
209
  top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
210
  top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
211
 
 
217
  mask = (expert_ids == e)
218
  if not mask.any():
219
  continue
220
+ # Move tokens to expert's device, compute, move result back
221
  expert_device = next(self.experts[e].parameters()).device
222
+ x_masked = x_flat[mask].to(device=expert_device, dtype=dtype)
223
+ expert_out = self.experts[e](x_masked).to(device=device, dtype=dtype)
224
+ out[mask] += weights[mask] * expert_out
 
225
 
226
  return out.view(B, T, H)
227
 
228
 
229
+ # ── Dense MLP ─────────────────────────────────────────────────────────────────
230
  class Grok2MLP(nn.Module):
231
  def __init__(self, config: Grok2Config):
232
  super().__init__()
 
238
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
239
 
240
 
241
+ # ── Decoder Layer ─────────────────────────────────────────────────────────────
242
  class Grok2DecoderLayer(nn.Module):
243
  def __init__(self, config: Grok2Config, layer_idx: int):
244
  super().__init__()
 
252
  self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
253
 
254
  def forward(self, hidden_states, attention_mask=None, **kwargs):
255
+ device = hidden_states.device
256
+ dtype = hidden_states.dtype
257
+
258
+ # Attention block
259
  residual = hidden_states
260
  hidden_states = self.pre_attn_norm(hidden_states)
261
  hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
262
+ hidden_states = self.post_attn_norm(hidden_states.to(device=device, dtype=dtype))
263
+ hidden_states = residual + hidden_states.to(device=device, dtype=dtype)
264
 
265
+ # MoE + dense residual block
266
  residual = hidden_states
267
+ normed = self.pre_moe_norm(hidden_states)
268
+ moe_out = self.block_sparse_moe(normed)
269
+ mlp_out = self.mlp(normed)
270
+ combined = moe_out.to(device=device, dtype=dtype) + mlp_out.to(device=device, dtype=dtype)
271
+ hidden_states = self.post_moe_norm(combined)
272
+ hidden_states = residual + hidden_states.to(device=device, dtype=dtype)
273
 
274
  return hidden_states
275
 
276
 
277
+ # ── Model ─────────────────────────────────────────────────────────────────────
278
  class Grok2Model(nn.Module):
279
  def __init__(self, config: Grok2Config):
280
  super().__init__()
 
292
  return self.norm(hidden_states)
293
 
294
 
295
+ # ── CausalLM ──────────────────────────────────────────────────────────────────
296
  class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
297
  config_class = Grok2Config
298
  base_model_prefix = "model"
 
320
  **kwargs,
321
  ):
322
  hidden_states = self.model(input_ids, attention_mask=attention_mask)
323
+ # Move to lm_head device
324
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
325
  logits = self.lm_head(hidden_states) * self.output_multiplier_scale
326
 
327
  if self.final_logit_softcapping > 0:
 
343
  return {"input_ids": input_ids}
344
 
345
 
346
+ # ── Register ──────────────────────────────────────────────────────────────────
347
  AutoConfig.register("grok2", Grok2Config)
348
+ AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)