Vjeong Claude Sonnet 4.6 committed on
Commit
81a9145
·
1 Parent(s): 99c1b85

refactor(model): replace single-letter vars with descriptive names for readability

Browse files

Rename B/S → batch_size/seq_len and h → hidden_states across attention,
transformer_block, and llm_model modules. No functional changes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

llm_lab/model/attention.py CHANGED
@@ -69,21 +69,21 @@ class GroupedQueryAttention(nn.Module):
69
  Returns:
70
  (batch_size, seq_len, hidden_dim)
71
  """
72
- B, S, _ = x.shape
73
 
74
  # ──────────────────────────────────────────────
75
  # Step 1: Q, K, V projections
76
  # ──────────────────────────────────────────────
77
- q = self.q_proj(x) # (B, S, num_heads × head_dim)
78
- k = self.k_proj(x) # (B, S, num_kv_heads × head_dim)
79
- v = self.v_proj(x) # (B, S, num_kv_heads × head_dim)
80
 
81
  # Reshape into multi-head form
82
- q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
83
- # → (B, num_heads, S, head_dim)
84
- k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
85
- # → (B, num_kv_heads, S, head_dim)
86
- v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
87
 
88
  # ──────────────────────────────────────────────
89
  # Step 2: Apply RoPE (to Q and K only! Not to V)
@@ -97,7 +97,7 @@ class GroupedQueryAttention(nn.Module):
97
  # ──────────────────────────────────────────────
98
  # num_kv_heads=4 → num_heads=16: repeat each KV 4 times
99
  if self.num_kv_groups > 1:
100
- k = self._repeat_kv(k) # (B, num_heads, S, head_dim)
101
  v = self._repeat_kv(v)
102
 
103
  # ──────────────────────────────────────────────
@@ -110,25 +110,25 @@ class GroupedQueryAttention(nn.Module):
110
  dropout_p=self.config.dropout if self.training else 0.0,
111
  is_causal=(mask is None), # apply automatic causal masking when no mask is provided
112
  )
113
- # → (B, num_heads, S, head_dim)
114
 
115
  # ──────────────────────────────────────────────
116
  # Step 5: Merge heads + output projection
117
  # ──────────────────────────────────────────────
118
- attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
119
- # → (B, S, num_heads × head_dim)
120
 
121
- return self.o_proj(attn_out) # → (B, S, hidden_dim)
122
 
123
  def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
124
  """Repeat KV heads to match the number of Q heads.
125
 
126
- (B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
127
 
128
  Example: num_kv_heads=4, num_kv_groups=4
129
  [kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
130
  """
131
- B, H_kv, S, D = x.shape
132
- x = x[:, :, None, :, :] # (B, H_kv, 1, S, D)
133
- x = x.expand(B, H_kv, self.num_kv_groups, S, D) # (B, H_kv, groups, S, D)
134
- return x.reshape(B, self.num_heads, S, D)
 
69
  Returns:
70
  (batch_size, seq_len, hidden_dim)
71
  """
72
+ batch_size, seq_len, _ = x.shape
73
 
74
  # ──────────────────────────────────────────────
75
  # Step 1: Q, K, V projections
76
  # ──────────────────────────────────────────────
77
+ q = self.q_proj(x) # (batch_size, seq_len, num_heads × head_dim)
78
+ k = self.k_proj(x) # (batch_size, seq_len, num_kv_heads × head_dim)
79
+ v = self.v_proj(x) # (batch_size, seq_len, num_kv_heads × head_dim)
80
 
81
  # Reshape into multi-head form
82
+ q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
83
+ # → (batch_size, num_heads, seq_len, head_dim)
84
+ k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
85
+ # → (batch_size, num_kv_heads, seq_len, head_dim)
86
+ v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
87
 
88
  # ──────────────────────────────────────────────
89
  # Step 2: Apply RoPE (to Q and K only! Not to V)
 
97
  # ──────────────────────────────────────────────
98
  # num_kv_heads=4 → num_heads=16: repeat each KV 4 times
99
  if self.num_kv_groups > 1:
100
+ k = self._repeat_kv(k) # (batch_size, num_heads, seq_len, head_dim)
101
  v = self._repeat_kv(v)
102
 
103
  # ──────────────────────────────────────────────
 
110
  dropout_p=self.config.dropout if self.training else 0.0,
111
  is_causal=(mask is None), # apply automatic causal masking when no mask is provided
112
  )
113
+ # → (batch_size, num_heads, seq_len, head_dim)
114
 
115
  # ──────────────────────────────────────────────
116
  # Step 5: Merge heads + output projection
117
  # ──────────────────────────────────────────────
118
+ attn_out = attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
119
+ # → (batch_size, seq_len, num_heads × head_dim)
120
 
121
+ return self.o_proj(attn_out) # → (batch_size, seq_len, hidden_dim)
122
 
123
  def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
124
  """Repeat KV heads to match the number of Q heads.
125
 
126
+ (batch_size, num_kv_heads, seq_len, head_dim) → (batch_size, num_heads, seq_len, head_dim)
127
 
128
  Example: num_kv_heads=4, num_kv_groups=4
129
  [kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
130
  """
131
+ batch_size, num_kv_heads, seq_len, head_dim = x.shape
132
+ x = x[:, :, None, :, :] # (batch_size, num_kv_heads, 1, seq_len, head_dim)
133
+ x = x.expand(batch_size, num_kv_heads, self.num_kv_groups, seq_len, head_dim)
134
+ return x.reshape(batch_size, self.num_heads, seq_len, head_dim)
llm_lab/model/llm_model.py CHANGED
@@ -97,11 +97,11 @@ class LLMModel(nn.Module):
97
  logits: (batch_size, seq_len, vocab_size)
98
  loss: scalar (when targets are provided) or None
99
  """
100
- B, S = input_ids.shape
101
 
102
  # ── Step 1: Token Embedding ──
103
  # Convert each token ID into a vector of dimension hidden_dim
104
- h = self.token_embedding(input_ids) # (B, S, hidden_dim)
105
 
106
  # ── Step 2: Transformer Blocks ──
107
  # Activation Checkpointing: saves memory during training
@@ -109,18 +109,18 @@ class LLMModel(nn.Module):
109
  for layer in self.layers:
110
  if self.training and torch.is_grad_enabled():
111
  # Apply Activation Checkpointing
112
- h = torch.utils.checkpoint.checkpoint(
113
- layer, h, None, position_offset,
114
  use_reentrant=False, # recommended for PyTorch >= 2.0
115
  )
116
  else:
117
- h = layer(h, mask=None, position_offset=position_offset)
118
 
119
  # ── Step 3: Final normalization ──
120
- h = self.final_norm(h)
121
 
122
  # ── Step 4: Compute output logits ──
123
- logits = self.lm_head(h) # (B, S, vocab_size)
124
 
125
  # ── Step 5: Compute loss (during training) ──
126
  loss = None
 
97
  logits: (batch_size, seq_len, vocab_size)
98
  loss: scalar (when targets are provided) or None
99
  """
100
+ batch_size, seq_len = input_ids.shape
101
 
102
  # ── Step 1: Token Embedding ──
103
  # Convert each token ID into a vector of dimension hidden_dim
104
+ hidden_states = self.token_embedding(input_ids) # (batch_size, seq_len, hidden_dim)
105
 
106
  # ── Step 2: Transformer Blocks ──
107
  # Activation Checkpointing: saves memory during training
 
109
  for layer in self.layers:
110
  if self.training and torch.is_grad_enabled():
111
  # Apply Activation Checkpointing
112
+ hidden_states = torch.utils.checkpoint.checkpoint(
113
+ layer, hidden_states, None, position_offset,
114
  use_reentrant=False, # recommended for PyTorch >= 2.0
115
  )
116
  else:
117
+ hidden_states = layer(hidden_states, mask=None, position_offset=position_offset)
118
 
119
  # ── Step 3: Final normalization ──
120
+ hidden_states = self.final_norm(hidden_states)
121
 
122
  # ── Step 4: Compute output logits ──
123
+ logits = self.lm_head(hidden_states) # (batch_size, seq_len, vocab_size)
124
 
125
  # ── Step 5: Compute loss (during training) ──
126
  loss = None
llm_lab/model/transformer_block.py CHANGED
@@ -56,10 +56,10 @@ class TransformerBlock(nn.Module):
56
  """
57
  # ── Attention sublayer with residual ──
58
  # h = x + Attention(RMSNorm(x))
59
- h = x + self.attention(self.attn_norm(x), mask, position_offset)
60
 
61
  # ── FFN sublayer with residual ──
62
  # out = h + FFN(RMSNorm(h))
63
- out = h + self.feed_forward(self.ffn_norm(h))
64
 
65
  return out
 
56
  """
57
  # ── Attention sublayer with residual ──
58
  # h = x + Attention(RMSNorm(x))
59
+ hidden_states = x + self.attention(self.attn_norm(x), mask, position_offset)
60
 
61
  # ── FFN sublayer with residual ──
62
  # out = h + FFN(RMSNorm(h))
63
+ out = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
64
 
65
  return out