# LLM-1B-Lab — llm_lab/model/transformer_block.py
# Commit 81a9145: refactor(model): replace single-letter vars with descriptive names for readability
"""Transformer Block (a single layer)."""
from typing import Optional
import torch
import torch.nn as nn
from llm_lab.config import ModelConfig
from .norm import RMSNorm
from .attention import GroupedQueryAttention
from .feedforward import SwiGLUFeedForward
class TransformerBlock(nn.Module):
"""A single Transformer decoder block.
Structure (Pre-Norm style):
x β†’ RMSNorm β†’ Attention β†’ + (residual) β†’ RMSNorm β†’ FFN β†’ + (residual) β†’ out
Pre-Norm vs Post-Norm:
- Post-Norm (original Transformer): LayerNorm applied after the residual
β†’ training instability in deep models
- Pre-Norm (standard since GPT-2): LayerNorm applied before the sublayer
β†’ smooth gradient flow, stable training
Role of Residual Connection:
- Adds the input to the output β†’ a "highway" that lets gradients skip layers
- The key reason training is feasible even with 22 stacked layers
"""
def __init__(self, config: ModelConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
# Pre-Norm: normalization before Attention
self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
# Self-Attention
self.attention = GroupedQueryAttention(config)
# Pre-Norm: normalization before FFN
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
# Feed-Forward Network
self.feed_forward = SwiGLUFeedForward(config)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_offset: int = 0,
) -> torch.Tensor:
"""
Args:
x: (batch_size, seq_len, hidden_dim)
Returns:
(batch_size, seq_len, hidden_dim)
"""
# ── Attention sublayer with residual ──
# h = x + Attention(RMSNorm(x))
hidden_states = x + self.attention(self.attn_norm(x), mask, position_offset)
# ── FFN sublayer with residual ──
# out = h + FFN(RMSNorm(h))
out = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
return out