"""
model/block.py

Single Transformer Block (pre-norm LLaMA-style).

Pre-Norm vs Post-Norm:
    Post-norm (original Transformer, BERT):  x = Norm(x + Attention(x))   <- less stable, usually needs LR warmup
    Pre-norm  (GPT-2, LLaMA):                x = x + Attention(Norm(x))   <- more stable

    We use PRE-NORM with RMSNorm (as in LLaMA) for training stability at scale.
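
Sketch (illustrative only; not used by this module): the two orderings side by side.
`ToyRMSNorm`, `post_norm_step`, `pre_norm_step` and `mean_rms` are hypothetical names.
`ToyRMSNorm` assumes the usual RMSNorm definition and is not a copy of model.norm.RMSNorm,
and a plain nn.Linear stands in for attention or the MLP. Post-norm rescales the whole
residual stream to RMS ~1 at every step; pre-norm adds the sublayer output to a stream
that is never normalized.

```python
import torch
import torch.nn as nn


class ToyRMSNorm(nn.Module):
    # Hypothetical stand-in for model.norm.RMSNorm (its code is not shown here):
    # divide each vector by its root-mean-square, then apply a learned gain.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


def post_norm_step(x, sublayer, norm):
    # Original Transformer / BERT: normalize AFTER adding the residual.
    return norm(x + sublayer(x))


def pre_norm_step(x, sublayer, norm):
    # GPT-2 / LLaMA: normalize only the sublayer input; the residual stream stays raw.
    return x + sublayer(norm(x))


def mean_rms(t: torch.Tensor) -> float:
    # Average root-mean-square over the last (feature) dimension.
    return t.pow(2).mean(dim=-1).sqrt().mean().item()


torch.manual_seed(0)
d = 16
x = 5.0 * torch.randn(2, 4, d)   # residual stream with large activations
sublayer = nn.Linear(d, d)       # stand-in for attention or the MLP
norm = ToyRMSNorm(d)

with torch.no_grad():
    post = post_norm_step(x, sublayer, norm)
    pre = pre_norm_step(x, sublayer, norm)

print(f"post-norm output RMS: {mean_rms(post):.3f}")  # close to 1: the stream is rescaled
print(f"pre-norm  output RMS: {mean_rms(pre):.3f}")   # stays on the scale of x
```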

Block structure:
    x  ->  RMSNorm  ->  CausalSelfAttention  ->  (+residual)
       ->  RMSNorm  ->  SwiGLU MLP            ->  (+residual)
       ->  output

Note: In pre-norm, the residual path bypasses both the norm and the sublayer, so each
block's Jacobian contains an identity term and gradients reach earlier layers during
backprop without passing through any normalization.
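
Sketch (illustrative only): the sublayer is assumed to be a zero-initialised nn.Linear,
so only the residual path can carry gradient, and `rms_norm` is a hypothetical
parameter-free helper, not model.norm.RMSNorm. Through four pre-norm steps the gradient
of the summed output with respect to the input is exactly 1 everywhere; the equivalent
post-norm stack routes the same gradient through every norm's Jacobian.

```python
import torch
import torch.nn as nn


def rms_norm(t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Parameter-free RMS normalization over the last dimension (illustrative).
    return t / torch.sqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
d, n_layers = 8, 4

# A sublayer that outputs zeros, so any gradient must come from the residual path alone.
sublayer = nn.Linear(d, d)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

# Pre-norm stack: the residual bypasses both the norm and the sublayer.
x_pre = torch.randn(3, d, requires_grad=True)
h = x_pre
for _ in range(n_layers):
    h = h + sublayer(rms_norm(h))
h.sum().backward()

# Post-norm stack: every residual addition is followed by a norm.
x_post = x_pre.detach().clone().requires_grad_(True)
h = x_post
for _ in range(n_layers):
    h = rms_norm(h + sublayer(h))
h.sum().backward()

print(x_pre.grad)   # all ones: the identity path passes the gradient through unchanged
print(x_post.grad)  # not all ones: each norm's Jacobian reshapes the gradient
```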
"""

import torch
import torch.nn as nn

from model.config    import ModelConfig
from model.norm      import RMSNorm
from model.attention import CausalSelfAttention
from model.mlp       import SwiGLU


class TransformerBlock(nn.Module):

    def __init__(self, config: ModelConfig):
        super().__init__()

        # Pre-attention norm
        self.norm_attn = RMSNorm(config.d_model)

        # Causal self-attention with RoPE
        self.attn      = CausalSelfAttention(config)

        # Pre-FFN norm
        self.norm_mlp  = RMSNorm(config.d_model)

        # SwiGLU feed-forward
        self.mlp       = SwiGLU(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x : (B, T, d_model)

        Returns:
            x : (B, T, d_model)
        """
        # Attention sub-layer with residual
        x = x + self.attn(self.norm_attn(x))

        # FFN sub-layer with residual
        x = x + self.mlp(self.norm_mlp(x))

        return x


# ------------------------------------------------------------------ #
#  QUICK CHECK  (run from the repo root: python -m model.block)
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    from model.config import SLLM_100M

    cfg   = SLLM_100M
    block = TransformerBlock(cfg)

    n = sum(p.numel() for p in block.parameters())
    print(f"Block params : {n/1e6:.3f}M")

    B, T = 2, 64
    x   = torch.randn(B, T, cfg.d_model)
    out = block(x)

    print(f"Input  shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    assert out.shape == x.shape, "Shape mismatch!"
    print("PASS")