| """Transformer Block (a single layer).""" |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from llm_lab.config import ModelConfig |
| from .norm import RMSNorm |
| from .attention import GroupedQueryAttention |
| from .feedforward import SwiGLUFeedForward |
|
|
|
|
class TransformerBlock(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm feed-forward.

    Data flow (Pre-Norm):
        x -> RMSNorm -> Attention -> + x (residual)
          -> RMSNorm -> FFN       -> + (residual) -> out

    Pre-Norm vs Post-Norm:
        - Post-Norm (original Transformer) normalizes *after* the residual
          add, which tends to destabilize training of deep models.
        - Pre-Norm (standard since GPT-2) normalizes *before* each sublayer,
          which keeps gradient flow smooth and training stable.

    Residual connections:
        - Each sublayer's output is added back onto its own input, giving
          gradients a direct path that bypasses the sublayer.
        - This shortcut is what keeps deep stacks of blocks trainable.
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        """Build the two normalized sublayers of this block.

        Args:
            config: Model hyperparameters; ``hidden_dim`` and ``norm_eps``
                are read here, and the whole object is handed to the
                attention and feed-forward constructors.
            layer_idx: Index of this block; stored as ``self.layer_idx``.
        """
        super().__init__()
        self.layer_idx = layer_idx

        # Both norms share the same width and epsilon.
        norm_dim = config.hidden_dim
        norm_eps = config.norm_eps

        # Submodules are created in the order they run in forward().
        # Attention sublayer: normalization followed by grouped-query attention.
        self.attn_norm = RMSNorm(norm_dim, eps=norm_eps)
        self.attention = GroupedQueryAttention(config)

        # Feed-forward sublayer: normalization followed by the SwiGLU FFN.
        self.ffn_norm = RMSNorm(norm_dim, eps=norm_eps)
        self.feed_forward = SwiGLUFeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """Run the block on ``x``.

        Args:
            x: (batch_size, seq_len, hidden_dim)
            mask: Optional tensor passed unchanged to the attention module.
            position_offset: Integer passed unchanged to the attention
                module.

        Returns:
            (batch_size, seq_len, hidden_dim)
        """
        # Sublayer 1: attention over the normalized input, added back onto x.
        normed_input = self.attn_norm(x)
        attn_out = self.attention(normed_input, mask, position_offset)
        after_attn = x + attn_out

        # Sublayer 2: FFN over the normalized attention result, added back
        # onto that same attention result.
        normed_attn = self.ffn_norm(after_attn)
        ffn_out = self.feed_forward(normed_attn)
        return after_attn + ffn_out
|
|