"""
nanoGPT: A tiny GPT model built from scratch in pure PyTorch.

This is a step-by-step tutorial implementation following Andrej Karpathy's
build-nanogpt approach. Every piece is explicit and commented.
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass


# ---------------------------------------------------------------------------
# Step 1: Configuration
# ---------------------------------------------------------------------------
# We define all hyperparameters in a single dataclass so they are easy to
# tweak without hunting through the code.

@dataclass
class GPTConfig:
    block_size: int = 256      # maximum sequence length (context length)
    vocab_size: int = 65       # number of unique characters in our dataset
    n_layer: int = 4           # number of transformer blocks
    n_head: int = 4            # number of attention heads per block
    n_embd: int = 256          # embedding dimension (hidden size)
    dropout: float = 0.0       # dropout probability (0.0 here; raise it if the model starts to overfit)
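
# Example (illustrative values, not part of the original walkthrough): the
# defaults above target a small character-level dataset; a slightly larger
# run could override them like
#
#   bigger = GPTConfig(n_layer=6, n_head=8, n_embd=384, dropout=0.1)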


# ---------------------------------------------------------------------------
# Step 2: Causal Self-Attention
# ---------------------------------------------------------------------------
# This is the heart of the transformer. For each token we compute three
# vectors: Query, Key, and Value.
#
#   Query: "What am I looking for?"
#   Key:   "What do I contain that others might be looking for?"
#   Value: "What information do I actually pass along if someone attends to me?"
#
# We then compute attention scores = Q @ K.T, mask future tokens so the
# model cannot "cheat" by looking ahead, and take a weighted sum of Values.
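#
# Tiny worked example: with T = 3 the causal mask is
#
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]
#
# so position 0 attends only to itself, position 1 to positions 0-1, and
# position 2 to everything before it. Masked entries are set to -inf before
# the softmax, which turns them into exact zeros in the attention weights.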

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"

        # One linear layer projects input into Q, K, V concatenated together.
        # Output shape: (B, T, 3 * n_embd)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)

        # Output projection back to n_embd
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout  # stored for reference; attention dropout is omitted in this minimal version

        # Register a causal mask (lower-triangular) so we never attend to future tokens.
        # We do this once at init instead of recomputing every forward pass.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # 1. Compute Q, K, V
        qkv = self.c_attn(x)                     # (B, T, 3*C)
        q, k, v = qkv.split(self.n_embd, dim=2)  # each (B, T, C)

        # 2. Reshape into (B, n_head, T, head_size) for multi-head attention
        head_size = C // self.n_head
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)   # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)   # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)   # (B, nh, T, hs)

        # 3. Compute attention scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        # We scale by 1/sqrt(head_size) so the dot products don't grow with the
        # head dimension and saturate the softmax (which would squash gradients).
        att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5))

        # 4. Apply causal mask: set future positions to -inf so softmax gives 0
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))

        # 5. Softmax to get probability distribution over past tokens
        att = F.softmax(att, dim=-1)

        # 6. Weighted sum of values: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        y = att @ v
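        # (Side note, an assumption about your environment: on PyTorch >= 2.0,
        #  steps 3-6 above can be replaced by the fused kernel
        #  F.scaled_dot_product_attention(q, k, v, is_causal=True), which is
        #  faster and more memory-efficient; the explicit version is kept here
        #  so every step stays visible.)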

        # 7. Concatenate heads back together: (B, nh, T, hs) -> (B, T, nh*hs) = (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # 8. Final output projection
        y = self.c_proj(y)
        return y


# ---------------------------------------------------------------------------
# Step 3: Feed-Forward Network (MLP)
# ---------------------------------------------------------------------------
# After attention, each token gets its own private "thinking" step through
# a simple two-layer MLP with a GELU non-linearity.
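# Shape-wise: (B, T, n_embd) -> (B, T, 4*n_embd) -> GELU -> (B, T, n_embd).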

class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        # Expand by 4x (common in transformers) then project back down
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu   = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


# ---------------------------------------------------------------------------
# Step 4: Transformer Block
# ---------------------------------------------------------------------------
# A block = Attention -> Add & Norm -> MLP -> Add & Norm
# We use **pre-norm**: normalize BEFORE applying attention/MLP.
# This is what modern models (GPT-2, GPT-3, Llama, etc.) use.
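# Why pre-norm? The residual path (the plain "x +" additions below) is left
# untouched by normalization, so gradients can flow straight through the depth
# of the network, which makes deeper stacks noticeably easier to train than
# the original post-norm arrangement.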

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp  = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections
        x = x + self.attn(self.ln_1(x))   # attention branch
        x = x + self.mlp(self.ln_2(x))    # MLP branch
        return x


# ---------------------------------------------------------------------------
# Step 5: Full GPT Model
# ---------------------------------------------------------------------------
# Putting it all together:
#   1. Token embedding table (wte): maps character index -> vector
#   2. Position embedding table (wpe): maps position index -> vector
#   3. Stack of N transformer blocks
#   4. Final layer norm
#   5. Language model head: projects back to vocab_size logits
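#
# With the default config above this works out to roughly 3.2M parameters
# (a rough count; almost all of them sit in the 4 transformer blocks, since
# the token and position embedding tables are comparatively tiny).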

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),      # token embeddings
            "wpe": nn.Embedding(config.block_size, config.n_embd),      # position embeddings
            "h":   nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: share the token embedding weights with the output projection.
        # This saves parameters and often improves training.
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        idx:    (B, T)  integer token indices
        targets:(B, T)  integer targets for next-token prediction (optional)
        returns: logits (B, T, vocab_size), loss (scalar or None)
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block_size {self.config.block_size}"

        # 1. Token + position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)          # (T,)
        tok_emb = self.transformer.wte(idx)                                     # (B, T, C)
        pos_emb = self.transformer.wpe(pos)                                     # (T, C)
        x = tok_emb + pos_emb                                                   # (B, T, C)

        # 2. Pass through transformer blocks
        for block in self.transformer.h:
            x = block(x)

        # 3. Final layer norm
        x = self.transformer.ln_f(x)

        # 4. Project to vocabulary logits
        logits = self.lm_head(x)                                                # (B, T, vocab_size)

        # 5. Compute cross-entropy loss if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,  # positions labeled -1 (e.g. padding) are skipped in the loss
            )

        return logits, loss

    @torch.no_grad()  # generation is pure inference; no need to build an autograd graph
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.
        idx: (B, T) starting token indices
        """
        for _ in range(max_new_tokens):
            # Crop to block_size so we never exceed context length
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # take logits for the last token only: (B, vocab_size)

            # Optional top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, top_k, dim=-1)
                logits[logits < v[:, [-1]]] = float("-inf")

            # Apply temperature and softmax
            probs = F.softmax(logits / temperature, dim=-1)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)              # (B, T+1)

        return idx
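

# ---------------------------------------------------------------------------
# Quick smoke test (illustrative; not part of the original walkthrough)
# ---------------------------------------------------------------------------
# A minimal sketch of how the pieces fit together: build a model from the
# default config, run one forward pass on random token indices, and sample a
# few tokens. The batch size, sequence length, and seed below are arbitrary.

if __name__ == "__main__":
    torch.manual_seed(1337)

    config = GPTConfig()
    model = GPT(config)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params:,}")

    # Fake "training" batch: 4 sequences of 32 random token indices
    idx = torch.randint(0, config.vocab_size, (4, 32))
    targets = torch.randint(0, config.vocab_size, (4, 32))

    logits, loss = model(idx, targets)
    print("logits:", tuple(logits.shape))   # (4, 32, 65)
    print("loss:", loss.item())             # roughly ln(65) ~= 4.17 at init

    # Sample 20 tokens starting from a single (arbitrary) token index 0
    prompt = torch.zeros((1, 1), dtype=torch.long)
    out = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=10)
    print("generated indices:", out[0].tolist())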