TheCodeKat committed on
Commit
5129187
·
verified ·
1 Parent(s): d88c88b

Update model/transformer_explained.py

Browse files
Files changed (1) hide show
  1. model/transformer_explained.py +199 -0
model/transformer_explained.py CHANGED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model/transformer_explained.py
2
+ """
3
+ Tiny Transformer language model (educational).
4
+ Components:
5
+ - PositionalEncoding: sinusoidal positional encodings (buffered)
6
+ - MultiHeadSelfAttention: returns attn weights optionally
7
+ - FeedForward: MLP with GELU
8
+ - TransformerBlock: attention + add&norm + FFN + add&norm
9
+ - TinyTransformerLM: token embeddings, pos enc, stacked blocks, LM head
10
+ """
11
+
12
+ import math
13
+ from typing import Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as in "Attention is All You Need".

    The table is precomputed once in ``__init__`` and stored as a buffer
    (moves with the module, never trained). ``forward`` adds the first
    ``seq_len`` rows to the token embeddings.

    Args:
        d_model: embedding dimension. Even and odd widths are both
            supported (the original version raised a shape error for odd
            ``d_model`` because the cosine half-table has one fewer column).
        max_len: longest sequence length the table supports.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Geometric frequency ladder 10000^(-2i/d_model), i = 0..ceil(d_model/2)-1.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd d_model has one fewer cosine column than sine column, so trim
        # the ladder; for even d_model the slice is a no-op (identical to the
        # original behavior).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)  # not a parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        returns: x + pe[:, :seq_len, :]
        """
        seq_len = x.size(1)
        # .to(x.device) guards against a caller that moved x but not the module.
        return x + self.pe[:, :seq_len, :].to(x.device)
44
+
45
+
46
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.
    Optionally returns attention weights for visualization.

    Input shapes:
        x: (batch, seq_len, d_model)
    Output:
        out: (batch, seq_len, d_model)
    Optional:
        attn: (batch, num_heads, seq_len, seq_len)

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # single fused linear for q, k, v, split after projection
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (batch, seq_len, d_model)
        mask: positions where ``mask == 0`` are blocked. Accepts either a
            (batch, seq_len) key-padding mask or anything broadcastable to
            (batch, num_heads, seq_len, seq_len), e.g. a (1, 1, S, S) causal mask.
        return_attn: if True, also return attention weights
        """
        B, S, D = x.shape
        # project and split into q, k, v
        qkv = self.qkv_proj(x)  # (B, S, 3*D)
        qkv = qkv.view(B, S, 3, self.num_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)  # each: (B, S, num_heads, d_k)

        # transpose to (B, num_heads, S, d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # scaled dot-product attention scores: (B, num_heads, S, S)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Bug fix: a (batch, seq_len) padding mask does NOT broadcast
            # against (B, H, S, S) — align it as (B, 1, 1, S) so it masks
            # the key axis for every head and every query position.
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn = self.softmax(attn_scores)  # rows sum to 1 over the key axis
        attn = self.attn_dropout(attn)    # dropout on weights (training only)
        # attn @ v -> (B, num_heads, S, d_k)
        out = torch.matmul(attn, v)
        # transpose & merge heads -> (B, S, D)
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        out = self.out_proj(out)  # final linear

        if return_attn:
            return out, attn
        return out, None
109
+
110
+
111
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Applied independently at every sequence position; input and output
    shapes are identical.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Build the pipeline as a list first, then wrap it; layer creation
        # order is unchanged so parameter init and state-dict keys match.
        stages = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP; shape (batch, seq_len, d_model) is preserved."""
        return self.net(x)
126
+
127
+
128
class TransformerBlock(nn.Module):
    """One pre-norm Transformer block: LN -> MHSA -> residual, LN -> FFN -> residual."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Submodule creation order preserved (state-dict / seeded-init compatible).
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the block; optionally pass back this block's attention weights."""
        # Attention sub-layer: normalize, attend, residual add.
        attn_out, weights = self.attn(self.ln1(x), mask=mask, return_attn=return_attn)
        x = x + attn_out
        # Feed-forward sub-layer: normalize, transform, residual add.
        x = x + self.ff(self.ln2(x))
        return (x, weights) if return_attn else (x, None)
152
+
153
+
154
class TinyTransformerLM(nn.Module):
    """
    Tiny Transformer language model for educational training/experiments.
    Tokenizer-agnostic: consumes integer token ids, emits per-position logits.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_layers: int = 4,
        num_heads: int = 4,
        d_ff: int = 1024,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Submodule creation order preserved (state-dict / seeded-init compatible).
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList(
            TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)  # logits head

    def forward(
        self, input_ids: torch.LongTensor, mask: Optional[torch.Tensor] = None, return_attn_layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        input_ids: (B, S) integer token ids.
        mask: optional attention mask, forwarded unchanged to every block.
        return_attn_layer: if an int, also return that layer's attention
            weights (all heads); otherwise the second tuple element is None.

        Returns: (logits of shape (B, S, vocab_size), attn weights or None).
        """
        hidden = self.pos_enc(self.tok_emb(input_ids))  # (B, S, d_model)
        captured = None
        for layer_idx, block in enumerate(self.layers):
            # Comparing against None is always False, so weights are only
            # requested from the single layer the caller asked for.
            want_attn = layer_idx == return_attn_layer
            hidden, maybe_attn = block(hidden, mask=mask, return_attn=want_attn)
            if want_attn:
                captured = maybe_attn
        hidden = self.ln_f(hidden)
        return self.head(hidden), captured