"""
Step-by-step training script for nano GPT.

What this script does:
  1. Load the preprocessed data (train / val tokens)
  2. Build the GPT model with our config
  3. Define a batching function that grabs random chunks of text
  4. Set up an AdamW optimizer with cosine learning-rate schedule
  5. Train loop: sample batch -> forward -> loss -> backward -> step
  6. Periodically evaluate on validation set and print metrics
  7. Save the best model checkpoint
  8. Generate a sample from the model after training
"""

import os
import math
import time
import torch

# Import our model
from model import GPT, GPTConfig

# ---------------------------------------------------------------------------
# 1. Hyperparameters & Config
# ---------------------------------------------------------------------------
# Feel free to tweak these! For a tutorial we keep things small and fast.

BATCH_SIZE = 64          # how many sequences to process in parallel
BLOCK_SIZE = 256         # max context length for each sequence (must match model!)
MAX_ITERS = 5000         # total training steps
LEARNING_RATE = 1e-3     # starting learning rate
WARMUP_ITERS = 200       # linear warmup steps (gradually increase LR)
LR_DECAY_ITERS = 5000    # when to reach min LR (usually = MAX_ITERS)
MIN_LR = 1e-4            # minimum learning rate at end of cosine schedule
EVAL_INTERVAL = 500      # how often to run validation
EVAL_ITERS = 200         # how many val batches to average for a stable loss estimate
GRAD_CLIP = 1.0          # max gradient norm (prevents exploding gradients)
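# Back-of-envelope: each step processes BATCH_SIZE * BLOCK_SIZE =
# 64 * 256 = 16,384 tokens, so 5,000 steps see ~82M tokens. On tiny
# Shakespeare (~1.1M characters) that is roughly 75 passes over the data.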

# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ---------------------------------------------------------------------------
# 2. Load Data
# ---------------------------------------------------------------------------
# We load the dictionary saved by prepare.py. Note that weights_only=False
# unpickles arbitrary Python objects, so only use it on files you trust
# (here: a file this project wrote itself).
data_path = os.path.join(os.path.dirname(__file__), "data.pt")
data = torch.load(data_path, weights_only=False)

train_data = data["train"]
val_data   = data["val"]
vocab_size = data["vocab_size"]
chars      = data["chars"]
stoi       = data["stoi"]
itos       = data["itos"]

print(f"Vocab size : {vocab_size}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens  : {len(val_data):,}")

# ---------------------------------------------------------------------------
# 3. Batch sampling
# ---------------------------------------------------------------------------
# For language modeling, each training example is a random contiguous chunk
# of T+1 tokens (T = BLOCK_SIZE): the input is chunk[0:T] and the target is
# chunk[1:T+1], i.e. the input shifted forward by one token.

def get_batch(split: str):
    """Sample a single batch from train or val data."""
    data_split = train_data if split == "train" else val_data
    ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_split[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([data_split[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
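
# Worked example of the shift with BLOCK_SIZE = 4 and a (made-up) chunk
# [5, 8, 2, 0, 7]:  x = [5, 8, 2, 0],  y = [8, 2, 0, 7].
# At each position t the model learns to predict y[t], the token that
# follows x[0..t], so one row yields BLOCK_SIZE training predictions.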

# ---------------------------------------------------------------------------
# 4. Helper: Learning-rate schedule (cosine with linear warmup)
# ---------------------------------------------------------------------------
# Warmup is crucial for transformers — it prevents early spikes in loss
# caused by large gradients when the model is still random.

def get_lr(iteration: int) -> float:
    if iteration < WARMUP_ITERS:
        # Linear warmup
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    if iteration > LR_DECAY_ITERS:
        return MIN_LR
    # Cosine decay after warmup
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)
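
# A few hand-checkable values of this schedule:
#   iter 0    -> 1e-3 * 1/200 = 5.0e-06           (warmup start)
#   iter 199  -> 1e-3 * 200/200 = 1.0e-03         (peak)
#   iter 2600 -> decay_ratio = 0.5, coeff = 0.5 -> 5.5e-04
#   iter 4999 -> ~MIN_LR = 1.0e-04                (end of cosine)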

# ---------------------------------------------------------------------------
# 5. Model Setup
# ---------------------------------------------------------------------------
# We match block_size to our training hyperparameter above.
# For tiny Shakespeare, even a 4-layer model can learn structure.

config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,       # deeper = more capacity to learn patterns
    n_head=6,
    n_embd=384,
    dropout=0.0,
)

model = GPT(config)
model.to(device)

# Count parameters
param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")

# ---------------------------------------------------------------------------
# 6. Optimizer
# ---------------------------------------------------------------------------
# We separate parameters that should get weight decay (2D weight matrices)
# from those that should not (1D biases, LayerNorm scales). Decaying the 1D
# parameters toward zero tends to hurt rather than help, so this split is
# standard practice in GPT training.

decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.dim() >= 2:
        decay_params.append(param)
    else:
        no_decay_params.append(param)

optim_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
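
# Quick sanity check (optional): the two groups should cover every parameter.
num_decay = sum(p.numel() for p in decay_params)
num_no_decay = sum(p.numel() for p in no_decay_params)
assert num_decay + num_no_decay == param_count
print(f"Decayed params: {num_decay:,} | non-decayed: {num_no_decay:,}")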

# ---------------------------------------------------------------------------
# 7. Evaluation helper
# ---------------------------------------------------------------------------
# We average the loss over multiple validation batches for a stable estimate.
# torch.no_grad() disables gradient computation -> faster and less memory.

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()  # set model to evaluation mode
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # set model back to training mode
    return out
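
# For intuition: a cross-entropy loss L corresponds to perplexity e^L.
# E.g. math.exp(1.5) ~ 4.48, i.e. the model is about as uncertain as a
# uniform choice among ~4.5 characters at each position.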

# ---------------------------------------------------------------------------
# 8. Training Loop
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)

best_val_loss = float("inf")
start_time = time.time()

for iter_num in range(MAX_ITERS):
    # --- Learning rate scheduling ---
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # --- Periodic evaluation ---
    if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )

        # Save the best checkpoint
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint_path = os.path.join(os.path.dirname(__file__), "best.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }, checkpoint_path)
            print(f"  -> Saved new best model (val_loss={best_val_loss:.4f})")

    # --- Training step ---
    xb, yb = get_batch("train")

    # Forward
    logits, loss = model(xb, yb)

    # Backward
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Gradient clipping (prevents exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
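    # (clip_grad_norm_ returns the pre-clip gradient norm; logging that
    # value is a cheap way to spot instability if the loss ever spikes.)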

    # Optimizer step
    optimizer.step()

# ---------------------------------------------------------------------------
# 9. Final evaluation
# ---------------------------------------------------------------------------
losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")

# ---------------------------------------------------------------------------
# 10. Generate text from the trained model
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Generating sample text...")
print("=" * 60)

model.eval()

# Start generation from a newline character (its index in our vocab)
start_token = stoi["\n"]
context = torch.full((1, 1), start_token, dtype=torch.long, device=device)

with torch.no_grad():
    generated = model.generate(context, max_new_tokens=500, temperature=1.0, top_k=40)

# Rebuild decode function from saved mappings
decode = lambda l: "".join([itos[i] for i in l])

# Decode to text
print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")

print("\nTraining complete! Best checkpoint saved to: best.pt")