"""
Transformer Block for FrawdLLM.

A transformer block combines:
1. Multi-head self-attention (tokens gather info from each other)
2. MLP (each token processes info independently)

With two important additions:
- LayerNorm: normalizes each token's features to keep activations stable during training
- Residual connections: add the block's input back to its output ("don't lose what you had")

Structure (Pre-LN: each norm comes before its sublayer, which trains more stably than the original Post-LN ordering):

    Input
      ↓
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚  LayerNorm  β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
      ↓
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚  Attention  │───────┐
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚ (residual)
      ↓                   β”‚
      + β†β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
      ↓
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚  LayerNorm  β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
      ↓
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚     MLP     │───────┐
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚ (residual)
      ↓                   β”‚
      + β†β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
      ↓
    Output
"""

import torch
import torch.nn as nn

from .config import ModelConfig
from .attention import CausalSelfAttention
from .mlp import MLP


class TransformerBlock(nn.Module):
    """
    One transformer block = Attention + MLP with norms and residuals.

    Input:  [batch_size, seq_len, n_embd]
    Output: [batch_size, seq_len, n_embd]
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config

        # Layer norms (one before attention, one before MLP)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        # Attention and MLP
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply transformer block.

        Args:
            x: [batch_size, seq_len, n_embd]

        Returns:
            [batch_size, seq_len, n_embd]
        """
        # Attention with residual connection
        # x + attention(norm(x))
        # "Keep x, add attention's contribution"
        x = x + self.attn(self.ln1(x))

        # MLP with residual connection
        # x + mlp(norm(x))
        # "Keep x, add MLP's contribution"
        x = x + self.mlp(self.ln2(x))

        return x


if __name__ == "__main__":
    # Test the transformer block. Because of the relative imports above,
    # run this as a module (e.g. `python -m <package>.block`), not directly.
    from .config import get_config

    print("Testing TransformerBlock...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: n_embd={config.n_embd}, n_head={config.n_head}, "
          f"n_layer={config.n_layer}")

    block = TransformerBlock(config)

    # Count parameters
    num_params = sum(p.numel() for p in block.parameters())
    print(f"Block parameters: {num_params:,}")

    # Test input: [batch=2, seq=8, n_embd] (n_embd comes from the config)
    x = torch.randn(2, 8, config.n_embd)
    print(f"\nInput shape: {x.shape}")

    # Forward pass
    out = block(x)
    print(f"Output shape: {out.shape}")

    # Verify shapes match
    assert x.shape == out.shape, "Input and output shapes should match!"
    print("\nTransformerBlock working!")
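Because each block maps `[batch, seq, n_embd]` to the same shape, the full model can stack `n_layer` of them back to back. A minimal, self-contained sketch of that stacking, using a hypothetical `SimpleBlock` with `nn.Linear` stand-ins for attention and the MLP so the snippet runs without the rest of the package:

```python
import torch
import torch.nn as nn


class SimpleBlock(nn.Module):
    """Pre-LN block with nn.Linear stand-ins for attention and the MLP."""

    def __init__(self, n_embd: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = nn.Linear(n_embd, n_embd)  # stand-in for self-attention
        self.mlp = nn.Linear(n_embd, n_embd)   # stand-in for the MLP

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same residual pattern as TransformerBlock above.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


n_embd, n_layer = 64, 4

# nn.Sequential works because every block maps a tensor to one of the
# same shape, so any number of them compose.
stack = nn.Sequential(*[SimpleBlock(n_embd) for _ in range(n_layer)])

x = torch.randn(2, 8, n_embd)
out = stack(x)
print(out.shape)  # torch.Size([2, 8, 64])
```

In the real model, `TransformerBlock(config)` would take the place of `SimpleBlock`, repeated `config.n_layer` times.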