refactor(model): replace single-letter vars with descriptive names for readability
Rename B/S → batch_size/seq_len and h → hidden_states across attention,
transformer_block, and llm_model modules. No functional changes.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- llm_lab/model/attention.py +19 -19
- llm_lab/model/llm_model.py +7 -7
- llm_lab/model/transformer_block.py +2 -2
llm_lab/model/attention.py
CHANGED

@@ -69,21 +69,21 @@ class GroupedQueryAttention(nn.Module):
        Returns:
            (batch_size, seq_len, hidden_dim)
        """
-       B, S, _ = x.shape
+       batch_size, seq_len, _ = x.shape

        # ──────────────────────────────────────────────
        # Step 1: Q, K, V projections
        # ──────────────────────────────────────────────
-       q = self.q_proj(x)  # (B, S, num_heads × head_dim)
-       k = self.k_proj(x)  # (B, S, num_kv_heads × head_dim)
-       v = self.v_proj(x)  # (B, S, num_kv_heads × head_dim)
+       q = self.q_proj(x)  # (batch_size, seq_len, num_heads × head_dim)
+       k = self.k_proj(x)  # (batch_size, seq_len, num_kv_heads × head_dim)
+       v = self.v_proj(x)  # (batch_size, seq_len, num_kv_heads × head_dim)

        # Reshape into multi-head form
-       q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
-       # → (B, num_heads, S, head_dim)
-       k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
-       # → (B, num_kv_heads, S, head_dim)
-       v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
+       q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+       # → (batch_size, num_heads, seq_len, head_dim)
+       k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+       # → (batch_size, num_kv_heads, seq_len, head_dim)
+       v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # ──────────────────────────────────────────────
        # Step 2: Apply RoPE (to Q and K only! Not to V)

@@ -97,7 +97,7 @@ class GroupedQueryAttention(nn.Module):
        # ──────────────────────────────────────────────
        # num_kv_heads=4 → num_heads=16: repeat each KV 4 times
        if self.num_kv_groups > 1:
-           k = self._repeat_kv(k)  # (B, num_heads, S, head_dim)
+           k = self._repeat_kv(k)  # (batch_size, num_heads, seq_len, head_dim)
            v = self._repeat_kv(v)

        # ──────────────────────────────────────────────

@@ -110,25 +110,25 @@ class GroupedQueryAttention(nn.Module):
            dropout_p=self.config.dropout if self.training else 0.0,
            is_causal=(mask is None),  # apply automatic causal masking when no mask is provided
        )
-       # → (B, num_heads, S, head_dim)
+       # → (batch_size, num_heads, seq_len, head_dim)

        # ──────────────────────────────────────────────
        # Step 5: Merge heads + output projection
        # ──────────────────────────────────────────────
-       attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
-       # → (B, S, num_heads × head_dim)
+       attn_out = attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+       # → (batch_size, seq_len, num_heads × head_dim)

-       return self.o_proj(attn_out)  # → (B, S, hidden_dim)
+       return self.o_proj(attn_out)  # → (batch_size, seq_len, hidden_dim)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match the number of Q heads.

-       (B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
+       (batch_size, num_kv_heads, seq_len, head_dim) → (batch_size, num_heads, seq_len, head_dim)

        Example: num_kv_heads=4, num_kv_groups=4
            [kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
        """
-       B, num_kv_heads, S, head_dim = x.shape
-       x = x[:, :, None, :, :]  # (B, num_kv_heads, 1, S, head_dim)
-       x = x.expand(B, num_kv_heads, self.num_kv_groups, S, head_dim)
-       return x.reshape(B, self.num_heads, S, head_dim)
+       batch_size, num_kv_heads, seq_len, head_dim = x.shape
+       x = x[:, :, None, :, :]  # (batch_size, num_kv_heads, 1, seq_len, head_dim)
+       x = x.expand(batch_size, num_kv_heads, self.num_kv_groups, seq_len, head_dim)
+       return x.reshape(batch_size, self.num_heads, seq_len, head_dim)
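The KV-repeat step in _repeat_kv is easy to sanity-check in isolation. The following is a minimal, standalone sketch, not part of this commit: the toy shapes and num_kv_groups=2 are made up. It tags each KV head with its own index and shows the expand → reshape step producing the [kv0, kv0, ..., kv1, kv1, ...] ordering described in the docstring.

import torch

batch_size, num_kv_heads, seq_len, head_dim = 1, 2, 3, 4
num_kv_groups = 2
num_heads = num_kv_heads * num_kv_groups  # 4

# Tag each KV head with its index so the ordering is visible after the repeat.
kv = torch.arange(num_kv_heads, dtype=torch.float32).view(1, num_kv_heads, 1, 1)
kv = kv.expand(batch_size, num_kv_heads, seq_len, head_dim)

# Same expand/reshape pattern as _repeat_kv above.
repeated = kv[:, :, None, :, :].expand(
    batch_size, num_kv_heads, num_kv_groups, seq_len, head_dim
).reshape(batch_size, num_heads, seq_len, head_dim)

print(repeated[0, :, 0, 0])  # tensor([0., 0., 1., 1.]) -> kv0, kv0, kv1, kv1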
llm_lab/model/llm_model.py
CHANGED

@@ -97,11 +97,11 @@ class LLMModel(nn.Module):
            logits: (batch_size, seq_len, vocab_size)
            loss: scalar (when targets are provided) or None
        """
-       B, S = input_ids.shape
+       batch_size, seq_len = input_ids.shape

        # ── Step 1: Token Embedding ──
        # Convert each token ID into a vector of dimension hidden_dim
-       h = self.token_embedding(input_ids)  # (B, S, hidden_dim)
+       hidden_states = self.token_embedding(input_ids)  # (batch_size, seq_len, hidden_dim)

        # ── Step 2: Transformer Blocks ──
        # Activation Checkpointing: saves memory during training

@@ -109,18 +109,18 @@
        for layer in self.layers:
            if self.training and torch.is_grad_enabled():
                # Apply Activation Checkpointing
-               h = torch.utils.checkpoint.checkpoint(
-                   layer, h, None, position_offset,
+               hidden_states = torch.utils.checkpoint.checkpoint(
+                   layer, hidden_states, None, position_offset,
                    use_reentrant=False,  # recommended for PyTorch >= 2.0
                )
            else:
-               h = layer(h, mask=None, position_offset=position_offset)
+               hidden_states = layer(hidden_states, mask=None, position_offset=position_offset)

        # ── Step 3: Final normalization ──
-       h = self.final_norm(h)
+       hidden_states = self.final_norm(hidden_states)

        # ── Step 4: Compute output logits ──
-       logits = self.lm_head(h)  # (B, S, vocab_size)
+       logits = self.lm_head(hidden_states)  # (batch_size, seq_len, vocab_size)

        # ── Step 5: Compute loss (during training) ──
        loss = None
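For reference, the checkpointing branch above follows the standard torch.utils.checkpoint pattern. Below is a minimal, standalone sketch, not from this repo: the nn.Linear layers and shapes are hypothetical stand-ins for the real transformer blocks. It shows the same wiring, where each checkpointed layer's activations are recomputed during backward instead of being stored in the forward pass.

import torch
import torch.nn as nn
import torch.utils.checkpoint

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # hypothetical stand-in blocks
hidden_states = torch.randn(2, 8, 16, requires_grad=True)      # (batch_size, seq_len, hidden_dim)

for layer in layers:
    if layers.training and torch.is_grad_enabled():
        # Activations inside `layer` are not stored; they are recomputed in backward.
        hidden_states = torch.utils.checkpoint.checkpoint(
            layer, hidden_states,
            use_reentrant=False,  # non-reentrant variant, recommended on PyTorch >= 2.0
        )
    else:
        hidden_states = layer(hidden_states)

hidden_states.sum().backward()  # re-runs each checkpointed forward to compute gradients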
llm_lab/model/transformer_block.py
CHANGED

@@ -56,10 +56,10 @@ class TransformerBlock(nn.Module):
        """
        # ── Attention sublayer with residual ──
        # h = x + Attention(RMSNorm(x))
-       h = x + self.attention(self.attn_norm(x), mask, position_offset)
+       hidden_states = x + self.attention(self.attn_norm(x), mask, position_offset)

        # ── FFN sublayer with residual ──
        # out = h + FFN(RMSNorm(h))
-       out = h + self.feed_forward(self.ffn_norm(h))
+       out = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))

        return out
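The two residual lines above are the usual pre-norm layout: each sublayer sees a normalized input, while the residual adds back the un-normalized stream. Below is a minimal, standalone sketch of that shape, not from this repo: LayerNorm and Identity are stand-ins for the real RMSNorm, attention, and FFN modules.

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)  # stand-in for RMSNorm
        self.attention = nn.Identity()      # stand-in for GroupedQueryAttention
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Identity()   # stand-in for the FFN

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # hidden_states = x + Attention(RMSNorm(x))
        hidden_states = x + self.attention(self.attn_norm(x))
        # out = hidden_states + FFN(RMSNorm(hidden_states))
        return hidden_states + self.feed_forward(self.ffn_norm(hidden_states))

out = PreNormBlock()(torch.randn(2, 8, 16))  # (batch_size, seq_len, hidden_dim)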