bmeyer2025 committed on
Commit
cc64e8a
·
verified ·
1 Parent(s): 1e86f73

Upload src/transformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/transformer.py +107 -0
src/transformer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Milestone 3: MultiHeadAttention, FeedForward, and Transformer Block.
3
+
4
+ Architecture uses pre-norm (LayerNorm before attention/FFN, not after).
5
+ This is what modern models like LLaMA/Qwen do — it trains more stably.
6
+
7
+ Block layout:
8
+ x -> LayerNorm -> MultiHeadAttention -> + (residual) -> LayerNorm -> FeedForward -> + (residual)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from attention import Head
16
+
17
+
18
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side; their outputs are concatenated and linearly mixed."""

    def __init__(self, n_heads: int, head_size: int, n_embd: int, block_size: int, dropout: float):
        super().__init__()
        # One independent causal Head per attention head.
        self.heads = nn.ModuleList(
            Head(head_size=head_size, n_embd=n_embd, block_size=block_size, dropout=dropout)
            for _ in range(n_heads)
        )
        # Output projection maps the concatenated head channels back to the model width.
        self.proj = nn.Linear(n_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every head sees the full input; stack results along the channel axis.
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)  # (B, T, n_heads * head_size)
        # Mix head channels with the projection, then regularize.
        return self.dropout(self.proj(combined))  # (B, T, n_embd)
36
+
37
+
38
+ class FeedForward(nn.Module):
39
+ """Position-wise feed-forward network: Linear -> ReLU -> Linear.
40
+
41
+ Standard GPT uses a 4x expansion of n_embd in the hidden layer.
42
+ We'll swap ReLU for SwiGLU in the modernization phase.
43
+ """
44
+
45
+ def __init__(self, n_embd: int, dropout: float):
46
+ super().__init__()
47
+ self.net = nn.Sequential(
48
+ nn.Linear(n_embd, 4 * n_embd),
49
+ nn.ReLU(),
50
+ nn.Linear(4 * n_embd, n_embd),
51
+ nn.Dropout(dropout),
52
+ )
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ return self.net(x)
56
+
57
+
58
class Block(nn.Module):
    """A single pre-norm transformer block.

    LayerNorm is applied *before* each sublayer (attention / feed-forward),
    with a residual connection around each — more stable to train than the
    post-norm layout of the original Transformer.
    """

    def __init__(self, n_embd: int, n_heads: int, block_size: int, dropout: float):
        super().__init__()
        # Split the embedding width evenly across the heads.
        head_size = n_embd // n_heads
        self.attn = MultiHeadAttention(
            n_heads=n_heads,
            head_size=head_size,
            n_embd=n_embd,
            block_size=block_size,
            dropout=dropout,
        )
        self.ffn = FeedForward(n_embd=n_embd, dropout=dropout)
        self.ln1 = nn.LayerNorm(n_embd)  # pre-attention norm
        self.ln2 = nn.LayerNorm(n_embd)  # pre-FFN norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: normalize, attend, add the residual back.
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        # Feed-forward sublayer: normalize, transform, add the residual back.
        ffn_out = self.ffn(self.ln2(x))
        return x + ffn_out
85
+
86
+
87
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    from tokenizer import DEVICE, BLOCK_SIZE

    # Smoke-test hyperparameters (matching the training config).
    n_embd, n_heads = 384, 6
    dropout = 0.1
    batch_size = 4

    block = Block(
        n_embd=n_embd, n_heads=n_heads, block_size=BLOCK_SIZE, dropout=dropout
    ).to(DEVICE)

    # Random activations stand in for token embeddings.
    x = torch.randn(batch_size, BLOCK_SIZE, n_embd, device=DEVICE)
    out = block(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape} (expected [4, {BLOCK_SIZE}, {n_embd}])")

    # Parameter count is a quick check that the block has the expected size.
    n_params = sum(p.numel() for p in block.parameters())
    print(f"Block params : {n_params:,}")
    print("\nMilestone 3 OK: transformer block works.")