Update modeling_grok2.py
Browse files- modeling_grok2.py +56 -36
modeling_grok2.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
modeling_grok2.py β Grok 2
|
| 3 |
-
Pure bf16
|
| 4 |
"""
|
| 5 |
|
| 6 |
import math
|
|
@@ -12,6 +12,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
| 12 |
from transformers import AutoConfig, AutoModelForCausalLM
|
| 13 |
|
| 14 |
|
|
|
|
| 15 |
class Grok2Config(PretrainedConfig):
|
| 16 |
model_type = "grok2"
|
| 17 |
|
|
@@ -67,6 +68,7 @@ class Grok2Config(PretrainedConfig):
|
|
| 67 |
)
|
| 68 |
|
| 69 |
|
|
|
|
| 70 |
class Grok2RMSNorm(nn.Module):
|
| 71 |
def __init__(self, hidden_size, eps=1e-5):
|
| 72 |
super().__init__()
|
|
@@ -74,17 +76,18 @@ class Grok2RMSNorm(nn.Module):
|
|
| 74 |
self.eps = eps
|
| 75 |
|
| 76 |
def forward(self, x):
|
|
|
|
| 77 |
variance = x.pow(2).mean(-1, keepdim=True)
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
|
|
|
|
| 81 |
def rotate_half(x):
|
| 82 |
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
|
| 83 |
return torch.cat([-x2, x1], dim=-1)
|
| 84 |
|
| 85 |
def apply_rotary_emb(q, k, cos, sin):
|
| 86 |
-
cos = cos.to(q.device, q.dtype)
|
| 87 |
-
sin = sin.to(q.device, q.dtype)
|
| 88 |
return (q * cos) + (rotate_half(q) * sin), \
|
| 89 |
(k * cos) + (rotate_half(k) * sin)
|
| 90 |
|
|
@@ -93,24 +96,25 @@ class Grok2RotaryEmbedding(nn.Module):
|
|
| 93 |
super().__init__()
|
| 94 |
base = base * scaling_factor
|
| 95 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 96 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 97 |
self._cached_len = 0
|
| 98 |
|
| 99 |
-
def _build_cache(self, seq_len, device):
|
| 100 |
t = torch.arange(seq_len, device=device).float()
|
| 101 |
freqs = torch.outer(t, self.inv_freq.to(device))
|
| 102 |
emb = torch.cat([freqs, freqs], dim=-1)
|
| 103 |
-
self.
|
| 104 |
-
self.
|
| 105 |
self._cached_len = seq_len
|
|
|
|
| 106 |
|
| 107 |
-
def forward(self, seq_len, device):
|
| 108 |
-
if seq_len > self._cached_len:
|
| 109 |
-
self._build_cache(seq_len, device)
|
| 110 |
-
return self.
|
| 111 |
-
self.sin_cached[:, :, :seq_len, :]
|
| 112 |
|
| 113 |
|
|
|
|
| 114 |
class Grok2Attention(nn.Module):
|
| 115 |
def __init__(self, config: Grok2Config):
|
| 116 |
super().__init__()
|
|
@@ -128,18 +132,19 @@ class Grok2Attention(nn.Module):
|
|
| 128 |
|
| 129 |
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
| 130 |
B, T, _ = hidden_states.shape
|
| 131 |
-
dtype = hidden_states.dtype
|
| 132 |
device = hidden_states.device
|
|
|
|
| 133 |
|
| 134 |
q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
|
| 135 |
k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 136 |
v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 137 |
|
| 138 |
-
cos, sin = self.rotary_emb(T, device)
|
| 139 |
-
cos = cos[:, :, :T, :self.head_dim]
|
| 140 |
-
sin = sin[:, :, :T, :self.head_dim]
|
| 141 |
q, k = apply_rotary_emb(q, k, cos, sin)
|
| 142 |
|
|
|
|
| 143 |
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
| 144 |
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
| 145 |
|
|
@@ -147,9 +152,7 @@ class Grok2Attention(nn.Module):
|
|
| 147 |
attn = torch.matmul(q, k.transpose(-2, -1)) / scale
|
| 148 |
|
| 149 |
if self.attn_softcap > 0:
|
| 150 |
-
attn = attn / self.attn_softcap
|
| 151 |
-
attn = torch.tanh(attn)
|
| 152 |
-
attn = attn * self.attn_softcap
|
| 153 |
|
| 154 |
causal = torch.triu(
|
| 155 |
torch.full((T, T), float("-inf"), device=device, dtype=dtype),
|
|
@@ -158,7 +161,7 @@ class Grok2Attention(nn.Module):
|
|
| 158 |
attn = attn + causal.unsqueeze(0).unsqueeze(0)
|
| 159 |
|
| 160 |
if attention_mask is not None:
|
| 161 |
-
attn = attn + attention_mask.to(device, dtype)
|
| 162 |
|
| 163 |
attn = F.softmax(attn, dim=-1).to(dtype)
|
| 164 |
out = torch.matmul(attn, v)
|
|
@@ -166,6 +169,7 @@ class Grok2Attention(nn.Module):
|
|
| 166 |
return self.o_proj(out)
|
| 167 |
|
| 168 |
|
|
|
|
| 169 |
class Grok2Expert(nn.Module):
|
| 170 |
def __init__(self, hidden_size, moe_intermediate_size):
|
| 171 |
super().__init__()
|
|
@@ -177,12 +181,14 @@ class Grok2Expert(nn.Module):
|
|
| 177 |
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 178 |
|
| 179 |
|
|
|
|
| 180 |
class Grok2SparseMoE(nn.Module):
|
| 181 |
def __init__(self, config: Grok2Config):
|
| 182 |
super().__init__()
|
| 183 |
self.num_experts = config.num_local_experts
|
| 184 |
self.top_k = config.num_experts_per_tok
|
| 185 |
self.router_softcap = config.router_logit_softcapping
|
|
|
|
| 186 |
self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
|
| 187 |
self.experts = nn.ModuleList([
|
| 188 |
Grok2Expert(config.hidden_size, config.moe_intermediate_size)
|
|
@@ -191,14 +197,15 @@ class Grok2SparseMoE(nn.Module):
|
|
| 191 |
|
| 192 |
def forward(self, x):
|
| 193 |
B, T, H = x.shape
|
|
|
|
|
|
|
| 194 |
x_flat = x.view(-1, H)
|
| 195 |
-
dtype = x_flat.dtype
|
| 196 |
|
| 197 |
router_logits = self.gate(x_flat)
|
| 198 |
if self.router_softcap > 0:
|
| 199 |
router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
|
| 200 |
|
| 201 |
-
router_weights = F.softmax(router_logits, dim=-1)
|
| 202 |
top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
|
| 203 |
top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
|
| 204 |
|
|
@@ -210,15 +217,16 @@ class Grok2SparseMoE(nn.Module):
|
|
| 210 |
mask = (expert_ids == e)
|
| 211 |
if not mask.any():
|
| 212 |
continue
|
|
|
|
| 213 |
expert_device = next(self.experts[e].parameters()).device
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
out[mask] += y_e.to(out.device)
|
| 218 |
|
| 219 |
return out.view(B, T, H)
|
| 220 |
|
| 221 |
|
|
|
|
| 222 |
class Grok2MLP(nn.Module):
|
| 223 |
def __init__(self, config: Grok2Config):
|
| 224 |
super().__init__()
|
|
@@ -230,6 +238,7 @@ class Grok2MLP(nn.Module):
|
|
| 230 |
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 231 |
|
| 232 |
|
|
|
|
| 233 |
class Grok2DecoderLayer(nn.Module):
|
| 234 |
def __init__(self, config: Grok2Config, layer_idx: int):
|
| 235 |
super().__init__()
|
|
@@ -243,22 +252,29 @@ class Grok2DecoderLayer(nn.Module):
|
|
| 243 |
self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 244 |
|
| 245 |
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
residual = hidden_states
|
| 247 |
hidden_states = self.pre_attn_norm(hidden_states)
|
| 248 |
hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
|
| 249 |
-
hidden_states = self.post_attn_norm(hidden_states)
|
| 250 |
-
hidden_states = residual + hidden_states
|
| 251 |
|
|
|
|
| 252 |
residual = hidden_states
|
| 253 |
-
|
| 254 |
-
moe_out = self.block_sparse_moe(
|
| 255 |
-
mlp_out = self.mlp(
|
| 256 |
-
|
| 257 |
-
hidden_states =
|
|
|
|
| 258 |
|
| 259 |
return hidden_states
|
| 260 |
|
| 261 |
|
|
|
|
| 262 |
class Grok2Model(nn.Module):
|
| 263 |
def __init__(self, config: Grok2Config):
|
| 264 |
super().__init__()
|
|
@@ -276,6 +292,7 @@ class Grok2Model(nn.Module):
|
|
| 276 |
return self.norm(hidden_states)
|
| 277 |
|
| 278 |
|
|
|
|
| 279 |
class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
| 280 |
config_class = Grok2Config
|
| 281 |
base_model_prefix = "model"
|
|
@@ -303,6 +320,8 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 303 |
**kwargs,
|
| 304 |
):
|
| 305 |
hidden_states = self.model(input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
| 306 |
logits = self.lm_head(hidden_states) * self.output_multiplier_scale
|
| 307 |
|
| 308 |
if self.final_logit_softcapping > 0:
|
|
@@ -324,5 +343,6 @@ class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 324 |
return {"input_ids": input_ids}
|
| 325 |
|
| 326 |
|
|
|
|
| 327 |
AutoConfig.register("grok2", Grok2Config)
|
| 328 |
-
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|
|
|
|
| 1 |
"""
|
| 2 |
+
modeling_grok2.py β Grok 2 for transformers, full multi-GPU support.
|
| 3 |
+
Pure bf16 throughout. Device-aware at every operation.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import math
|
|
|
|
| 12 |
from transformers import AutoConfig, AutoModelForCausalLM
|
| 13 |
|
| 14 |
|
| 15 |
+
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
class Grok2Config(PretrainedConfig):
|
| 17 |
model_type = "grok2"
|
| 18 |
|
|
|
|
| 68 |
)
|
| 69 |
|
| 70 |
|
| 71 |
+
# ββ RMSNorm βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
class Grok2RMSNorm(nn.Module):
|
| 73 |
def __init__(self, hidden_size, eps=1e-5):
|
| 74 |
super().__init__()
|
|
|
|
| 76 |
self.eps = eps
|
| 77 |
|
| 78 |
def forward(self, x):
|
| 79 |
+
# Stay in input dtype throughout
|
| 80 |
variance = x.pow(2).mean(-1, keepdim=True)
|
| 81 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 82 |
+
return self.weight.to(x.device, x.dtype) * x
|
| 83 |
|
| 84 |
|
| 85 |
+
# ββ RoPE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
def rotate_half(x):
|
| 87 |
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
|
| 88 |
return torch.cat([-x2, x1], dim=-1)
|
| 89 |
|
| 90 |
def apply_rotary_emb(q, k, cos, sin):
|
|
|
|
|
|
|
| 91 |
return (q * cos) + (rotate_half(q) * sin), \
|
| 92 |
(k * cos) + (rotate_half(k) * sin)
|
| 93 |
|
|
|
|
| 96 |
super().__init__()
|
| 97 |
base = base * scaling_factor
|
| 98 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 99 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 100 |
self._cached_len = 0
|
| 101 |
|
| 102 |
+
def _build_cache(self, seq_len, device, dtype):
|
| 103 |
t = torch.arange(seq_len, device=device).float()
|
| 104 |
freqs = torch.outer(t, self.inv_freq.to(device))
|
| 105 |
emb = torch.cat([freqs, freqs], dim=-1)
|
| 106 |
+
self._cos = emb.cos().to(dtype)[None, None, :, :]
|
| 107 |
+
self._sin = emb.sin().to(dtype)[None, None, :, :]
|
| 108 |
self._cached_len = seq_len
|
| 109 |
+
self._cached_device = device
|
| 110 |
|
| 111 |
+
def forward(self, seq_len, device, dtype):
|
| 112 |
+
if seq_len > self._cached_len or not hasattr(self, '_cached_device') or device != self._cached_device:
|
| 113 |
+
self._build_cache(seq_len, device, dtype)
|
| 114 |
+
return self._cos[:, :, :seq_len, :], self._sin[:, :, :seq_len, :]
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
+
# ββ Attention βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 118 |
class Grok2Attention(nn.Module):
|
| 119 |
def __init__(self, config: Grok2Config):
|
| 120 |
super().__init__()
|
|
|
|
| 132 |
|
| 133 |
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
| 134 |
B, T, _ = hidden_states.shape
|
|
|
|
| 135 |
device = hidden_states.device
|
| 136 |
+
dtype = hidden_states.dtype
|
| 137 |
|
| 138 |
q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
|
| 139 |
k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 140 |
v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 141 |
|
| 142 |
+
cos, sin = self.rotary_emb(T, device, dtype)
|
| 143 |
+
cos = cos[:, :, :T, :self.head_dim]
|
| 144 |
+
sin = sin[:, :, :T, :self.head_dim]
|
| 145 |
q, k = apply_rotary_emb(q, k, cos, sin)
|
| 146 |
|
| 147 |
+
# GQA expand
|
| 148 |
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
| 149 |
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
| 150 |
|
|
|
|
| 152 |
attn = torch.matmul(q, k.transpose(-2, -1)) / scale
|
| 153 |
|
| 154 |
if self.attn_softcap > 0:
|
| 155 |
+
attn = torch.tanh(attn / self.attn_softcap) * self.attn_softcap
|
|
|
|
|
|
|
| 156 |
|
| 157 |
causal = torch.triu(
|
| 158 |
torch.full((T, T), float("-inf"), device=device, dtype=dtype),
|
|
|
|
| 161 |
attn = attn + causal.unsqueeze(0).unsqueeze(0)
|
| 162 |
|
| 163 |
if attention_mask is not None:
|
| 164 |
+
attn = attn + attention_mask.to(device=device, dtype=dtype)
|
| 165 |
|
| 166 |
attn = F.softmax(attn, dim=-1).to(dtype)
|
| 167 |
out = torch.matmul(attn, v)
|
|
|
|
| 169 |
return self.o_proj(out)
|
| 170 |
|
| 171 |
|
| 172 |
+
# ββ MoE Expert ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
class Grok2Expert(nn.Module):
|
| 174 |
def __init__(self, hidden_size, moe_intermediate_size):
|
| 175 |
super().__init__()
|
|
|
|
| 181 |
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 182 |
|
| 183 |
|
| 184 |
+
# ββ Sparse MoE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
class Grok2SparseMoE(nn.Module):
|
| 186 |
def __init__(self, config: Grok2Config):
|
| 187 |
super().__init__()
|
| 188 |
self.num_experts = config.num_local_experts
|
| 189 |
self.top_k = config.num_experts_per_tok
|
| 190 |
self.router_softcap = config.router_logit_softcapping
|
| 191 |
+
|
| 192 |
self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
|
| 193 |
self.experts = nn.ModuleList([
|
| 194 |
Grok2Expert(config.hidden_size, config.moe_intermediate_size)
|
|
|
|
| 197 |
|
| 198 |
def forward(self, x):
|
| 199 |
B, T, H = x.shape
|
| 200 |
+
device = x.device
|
| 201 |
+
dtype = x.dtype
|
| 202 |
x_flat = x.view(-1, H)
|
|
|
|
| 203 |
|
| 204 |
router_logits = self.gate(x_flat)
|
| 205 |
if self.router_softcap > 0:
|
| 206 |
router_logits = torch.tanh(router_logits / self.router_softcap) * self.router_softcap
|
| 207 |
|
| 208 |
+
router_weights = F.softmax(router_logits, dim=-1)
|
| 209 |
top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
|
| 210 |
top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
|
| 211 |
|
|
|
|
| 217 |
mask = (expert_ids == e)
|
| 218 |
if not mask.any():
|
| 219 |
continue
|
| 220 |
+
# Move tokens to expert's device, compute, move result back
|
| 221 |
expert_device = next(self.experts[e].parameters()).device
|
| 222 |
+
x_masked = x_flat[mask].to(device=expert_device, dtype=dtype)
|
| 223 |
+
expert_out = self.experts[e](x_masked).to(device=device, dtype=dtype)
|
| 224 |
+
out[mask] += weights[mask] * expert_out
|
|
|
|
| 225 |
|
| 226 |
return out.view(B, T, H)
|
| 227 |
|
| 228 |
|
| 229 |
+
# ββ Dense MLP βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 230 |
class Grok2MLP(nn.Module):
|
| 231 |
def __init__(self, config: Grok2Config):
|
| 232 |
super().__init__()
|
|
|
|
| 238 |
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 239 |
|
| 240 |
|
| 241 |
+
# ββ Decoder Layer βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
class Grok2DecoderLayer(nn.Module):
|
| 243 |
def __init__(self, config: Grok2Config, layer_idx: int):
|
| 244 |
super().__init__()
|
|
|
|
| 252 |
self.post_moe_norm = Grok2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
| 253 |
|
| 254 |
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
| 255 |
+
device = hidden_states.device
|
| 256 |
+
dtype = hidden_states.dtype
|
| 257 |
+
|
| 258 |
+
# Attention block
|
| 259 |
residual = hidden_states
|
| 260 |
hidden_states = self.pre_attn_norm(hidden_states)
|
| 261 |
hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
|
| 262 |
+
hidden_states = self.post_attn_norm(hidden_states.to(device=device, dtype=dtype))
|
| 263 |
+
hidden_states = residual + hidden_states.to(device=device, dtype=dtype)
|
| 264 |
|
| 265 |
+
# MoE + dense residual block
|
| 266 |
residual = hidden_states
|
| 267 |
+
normed = self.pre_moe_norm(hidden_states)
|
| 268 |
+
moe_out = self.block_sparse_moe(normed)
|
| 269 |
+
mlp_out = self.mlp(normed)
|
| 270 |
+
combined = moe_out.to(device=device, dtype=dtype) + mlp_out.to(device=device, dtype=dtype)
|
| 271 |
+
hidden_states = self.post_moe_norm(combined)
|
| 272 |
+
hidden_states = residual + hidden_states.to(device=device, dtype=dtype)
|
| 273 |
|
| 274 |
return hidden_states
|
| 275 |
|
| 276 |
|
| 277 |
+
# ββ Model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 278 |
class Grok2Model(nn.Module):
|
| 279 |
def __init__(self, config: Grok2Config):
|
| 280 |
super().__init__()
|
|
|
|
| 292 |
return self.norm(hidden_states)
|
| 293 |
|
| 294 |
|
| 295 |
+
# ββ CausalLM ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 296 |
class Grok1ForCausalLM(PreTrainedModel, GenerationMixin):
|
| 297 |
config_class = Grok2Config
|
| 298 |
base_model_prefix = "model"
|
|
|
|
| 320 |
**kwargs,
|
| 321 |
):
|
| 322 |
hidden_states = self.model(input_ids, attention_mask=attention_mask)
|
| 323 |
+
# Move to lm_head device
|
| 324 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
| 325 |
logits = self.lm_head(hidden_states) * self.output_multiplier_scale
|
| 326 |
|
| 327 |
if self.final_logit_softcapping > 0:
|
|
|
|
| 343 |
return {"input_ids": input_ids}
|
| 344 |
|
| 345 |
|
| 346 |
+
# ββ Register ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 347 |
AutoConfig.register("grok2", Grok2Config)
|
| 348 |
+
AutoModelForCausalLM.register(Grok2Config, Grok1ForCausalLM)
|