# sllm / model / block.py
# geeteshcodes's picture
# Initial commit
# 7f974df verified
"""
model/block.py
Single Transformer Block (pre-norm LLaMA-style).
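Pre-Norm vs Post-Norm:
Post-norm (original Transformer, BERT): x = Norm(x + Attention(x))  <- less stable
Pre-norm  (GPT-2, LLaMA):               x = x + Attention(Norm(x))  <- more stable
In pre-norm the residual stream x is never normalized in place; only the sublayer's input is.
We use PRE-NORM with RMSNorm (as in LLaMA) for training stability at scale.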
Block structure:
x -> RMSNorm -> CausalSelfAttention -> (+residual)
-> RMSNorm -> SwiGLU MLP -> (+residual)
-> output
Note: Residual connections bypass both norm and sublayer, which allows
gradients to flow directly to earlier layers during backprop.
"""
import torch
import torch.nn as nn
from model.config import ModelConfig
from model.norm import RMSNorm
from model.attention import CausalSelfAttention
from model.mlp import SwiGLU
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention, then a SwiGLU FFN, each wrapped in a residual.

    Submodules (registered in this order, which fixes parameter / state_dict order):
        norm_attn : RMSNorm applied to the input of the attention sub-layer
        attn      : CausalSelfAttention
        norm_mlp  : RMSNorm applied to the input of the feed-forward sub-layer
        mlp       : SwiGLU feed-forward network
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        width = config.d_model
        # Attention sub-layer: normalize first, then attend.
        self.norm_attn = RMSNorm(width)
        self.attn = CausalSelfAttention(config)
        # Feed-forward sub-layer: normalize first, then SwiGLU.
        self.norm_mlp = RMSNorm(width)
        self.mlp = SwiGLU(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers.

        Args:
            x : (B, T, d_model) hidden states.
        Returns:
            (B, T, d_model) hidden states: x plus the attention update, plus the FFN update.
        """
        # Residual 1: the identity path skips both the norm and the attention.
        hidden = x + self.attn(self.norm_attn(x))
        # Residual 2: same pattern for the feed-forward network.
        return hidden + self.mlp(self.norm_mlp(hidden))
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: build one block from the SLLM_100M config, report its
    # parameter count, and confirm a forward pass keeps the (B, T, d_model) shape.
    from model.config import SLLM_100M

    torch.manual_seed(0)  # make the random input reproducible across runs
    cfg = SLLM_100M
    block = TransformerBlock(cfg)
    n = sum(p.numel() for p in block.parameters())
    print(f"Block params : {n/1e6:.3f}M")

    B, T = 2, 64
    x = torch.randn(B, T, cfg.d_model)
    # Forward-only check: no need to build an autograd graph.
    with torch.no_grad():
        out = block(x)
    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    # Explicit raises rather than `assert`, which `python -O` strips out.
    if out.shape != x.shape:
        raise AssertionError("Shape mismatch!")
    if not torch.isfinite(out).all():
        raise AssertionError("Non-finite values in block output!")
    print("PASS")