[hotfix] update gqa impl
modeling_grok1.py  +20 -0  CHANGED
@@ -74,6 +74,21 @@ def load_balancing_loss_func(
     ) * (num_experts**2)


+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class RMSNorm(nn.Module):
     def __init__(
         self,

@@ -194,6 +209,7 @@ class MultiHeadAttention(nn.Module):
         if num_key_value_heads is None:
             num_key_value_heads = num_heads
         self.num_key_value_heads = num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attn_output_multiplier = attn_output_multiplier
         self.max_attn_val = max_attn_val


@@ -259,6 +275,10 @@ class MultiHeadAttention(nn.Module):

         past_key_value = (key_states, value_states) if use_cache else None

+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
             torch.float
         )
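For reference, a minimal sketch (not part of the commit) of what the added GQA repeat does: key/value heads are expanded by num_key_value_groups so their shapes line up with the query heads before the attention matmul. The head counts and tensor sizes below are toy values, not Grok-1's actual configuration.

# Minimal sketch, assuming toy shapes; mirrors the repeat_kv helper added above.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper added in this commit.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


num_heads, num_key_value_heads = 8, 2                     # illustrative values only
num_key_value_groups = num_heads // num_key_value_heads   # 4, as computed in __init__

query_states = torch.randn(1, num_heads, 16, 64)           # (batch, heads, seq, head_dim)
key_states = torch.randn(1, num_key_value_heads, 16, 64)   # (batch, kv_heads, seq, head_dim)

# Without the repeat, the matmul below would fail with a head-count mismatch (8 vs 2).
key_states = repeat_kv(key_states, num_key_value_groups)
print(key_states.shape)  # torch.Size([1, 8, 16, 64])

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([1, 8, 16, 16])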