yat343 committed on
Commit
3229f14
·
verified ·
1 Parent(s): 82cb4ef

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +247 -0
train.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step-by-step training script for nano GPT.
3
+
4
+ What this script does:
5
+ 1. Load the preprocessed data (train / val tokens)
6
+ 2. Build the GPT model with our config
7
+ 3. Define a batching function that grabs random chunks of text
8
+ 4. Set up an AdamW optimizer with cosine learning-rate schedule
9
+ 5. Train loop: sample batch -> forward -> loss -> backward -> step
10
+ 6. Periodically evaluate on validation set and print metrics
11
+ 7. Save the best model checkpoint
12
+ 8. Generate a sample from the model after training
13
+ """
14
+
15
+ import os
16
+ import math
17
+ import time
18
+ import torch
19
+
20
+ # Import our model
21
+ from model import GPT, GPTConfig
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # 1. Hyperparameters & Config
25
+ # ---------------------------------------------------------------------------
26
+ # Feel free to tweak these! For a tutorial we keep things small and fast.
27
+
28
# --- Batch geometry ---
BATCH_SIZE = 64           # sequences sampled per training/eval batch
BLOCK_SIZE = 256          # tokens per sequence; also handed to GPTConfig as block_size

# --- Schedule ---
MAX_ITERS = 5000          # number of optimizer steps in the training loop
LEARNING_RATE = 1e-3      # peak learning rate, reached on the last warmup step
WARMUP_ITERS = 200        # steps of linear warmup before cosine decay begins
LR_DECAY_ITERS = MAX_ITERS  # step at which the cosine schedule bottoms out at MIN_LR
MIN_LR = 1e-4             # floor of the cosine schedule

# --- Evaluation / stability ---
EVAL_INTERVAL = 500       # run estimate_loss() every this many steps
EVAL_ITERS = 200          # batches averaged per split inside estimate_loss()
GRAD_CLIP = 1.0           # max gradient norm passed to clip_grad_norm_

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # 2. Load Data
45
+ # ---------------------------------------------------------------------------
46
+ # We load the dictionary saved by prepare.py
47
# The dataset dictionary is expected next to this script as "data.pt".
data_path = os.path.join(os.path.dirname(__file__), "data.pt")
# NOTE(security): weights_only=False enables full unpickling, which can execute
# arbitrary code embedded in the file. Only load a data.pt you produced yourself.
data = torch.load(data_path, weights_only=False)

train_data = data["train"]
val_data = data["val"]
vocab_size = data["vocab_size"]
chars = data["chars"]
stoi = data["stoi"]
itos = data["itos"]

print(f"Vocab size  : {vocab_size}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens  : {len(val_data):,}")

# get_batch() draws start offsets from range(len(split) - BLOCK_SIZE); a split
# that is not strictly longer than BLOCK_SIZE would only fail later with an
# opaque torch.randint error, so reject it here with a clear message.
for _split_name, _split_tokens in (("train", train_data), ("val", val_data)):
    if len(_split_tokens) <= BLOCK_SIZE:
        raise ValueError(
            f"{_split_name} split has {len(_split_tokens)} tokens but needs more "
            f"than BLOCK_SIZE={BLOCK_SIZE} to sample a batch"
        )
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # 3. Batch sampling
63
+ # ---------------------------------------------------------------------------
64
+ # For language modeling, each training example is a random contiguous chunk
65
+ # of text. The input is tokens[0:T-1], the target is tokens[1:T].
66
+
67
def get_batch(split: str):
    """Draw one random batch of (input, target) token windows.

    Args:
        split: "train" selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked windows moved to ``device``. Each holds
        BATCH_SIZE rows of BLOCK_SIZE tokens, and ``y`` is the same window
        as ``x`` shifted one position to the right.
    """
    if split == "train":
        source = train_data
    else:
        source = val_data

    # Start offsets are drawn so that the shifted target window still fits.
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))

    inputs = []
    targets = []
    for start in starts:
        stop = start + BLOCK_SIZE
        inputs.append(source[start:stop])
        targets.append(source[start + 1 : stop + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # 4. Helper: Learning-rate schedule (cosine with linear warmup)
78
+ # ---------------------------------------------------------------------------
79
+ # Warmup is crucial for transformers — it prevents early spikes in loss
80
+ # caused by large gradients when the model is still random.
81
+
82
def get_lr(iteration: int) -> float:
    """Return the learning rate for a given training step.

    The schedule has three phases:
      * ``iteration < WARMUP_ITERS``: linear ramp that reaches LEARNING_RATE
        on the last warmup step.
      * ``WARMUP_ITERS <= iteration < LR_DECAY_ITERS``: cosine decay from
        LEARNING_RATE down towards MIN_LR.
      * ``iteration >= LR_DECAY_ITERS``: constant MIN_LR.

    Args:
        iteration: Zero-based training step.

    Returns:
        The learning rate to use for that step.
    """
    if iteration < WARMUP_ITERS:
        # Linear warmup; the +1 keeps step 0 from getting a zero learning rate.
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    # ">=" rather than ">": at iteration == LR_DECAY_ITERS the cosine term is
    # already exactly MIN_LR, and returning early avoids a 0/0 division when
    # WARMUP_ITERS == LR_DECAY_ITERS.
    if iteration >= LR_DECAY_ITERS:
        return MIN_LR
    # Cosine decay: decay_ratio runs from 0 (end of warmup) towards 1.
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # 5. Model Setup
95
+ # ---------------------------------------------------------------------------
96
+ # We match block_size to our training hyperparameter above.
97
+ # For tiny Shakespeare, even a 4-layer model can learn structure.
98
+
99
# Architecture settings; block_size and vocab_size come from the training
# hyperparameters and the loaded dataset respectively.
config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.0,
)

model = GPT(config)
model.to(device)

# Tally the number of scalar values across all model parameters.
param_count = 0
for _param in model.parameters():
    param_count += _param.numel()

print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # 6. Optimizer
118
+ # ---------------------------------------------------------------------------
119
+ # We separate parameters that should get weight decay (2D weights)
120
+ # from those that should not (1D biases, LayerNorm scales).
121
+ # This is standard practice and slightly improves training.
122
+
123
# Split parameters by rank: tensors with two or more dimensions receive weight
# decay, everything else (rank 0/1) is left undecayed.
_all_params = list(model.parameters())
decay_params = [p for p in _all_params if p.dim() >= 2]
no_decay_params = [p for p in _all_params if p.dim() < 2]

optim_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

# The lr given here is only the initial value; the training loop overwrites it
# on every step from get_lr().
optimizer = torch.optim.AdamW(
    optim_groups,
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    eps=1e-8,
)
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # 7. Evaluation helper
140
+ # ---------------------------------------------------------------------------
141
+ # We average the loss over multiple validation batches for a stable estimate.
142
+ # torch.no_grad() disables gradient computation -> faster and less memory.
143
+
144
@torch.no_grad()
def estimate_loss():
    """Average the model loss over EVAL_ITERS random batches per split.

    The model is put into eval mode for the measurement and switched back to
    train mode before returning. Gradients are disabled by the decorator.

    Returns:
        A dict with keys "train" and "val", each mapping to a scalar tensor
        holding the mean loss over the sampled batches.
    """
    model.eval()
    averages = {}
    for split_name in ("train", "val"):
        per_batch = torch.zeros(EVAL_ITERS)
        for idx in range(EVAL_ITERS):
            inputs, targets = get_batch(split_name)
            _, batch_loss = model(inputs, targets)
            per_batch[idx] = batch_loss.item()
        averages[split_name] = per_batch.mean()
    model.train()
    return averages
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # 8. Training Loop
160
+ # ---------------------------------------------------------------------------
161
banner = "=" * 60
print("\n" + banner)
print("Starting training...")
print(banner)

best_val_loss = float("inf")
start_time = time.time()
# The best checkpoint always goes to the same file next to this script.
checkpoint_path = os.path.join(os.path.dirname(__file__), "best.pt")

for iter_num in range(MAX_ITERS):
    # Apply the scheduled learning rate to every parameter group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Evaluate on a fixed cadence and once more on the very last step.
    is_last_step = iter_num == MAX_ITERS - 1
    should_eval = iter_num % EVAL_INTERVAL == 0 or is_last_step
    if should_eval:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )

        # Keep only the checkpoint with the lowest validation loss seen so far.
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"  -> Saved new best model (val_loss={best_val_loss:.4f})")

    # One optimization step: forward, clear stale grads, backward, clip, update.
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
    optimizer.step()
215
+
216
+ # ---------------------------------------------------------------------------
217
+ # 9. Final evaluation
218
+ # ---------------------------------------------------------------------------
219
# Measure both splits once more with the weights left by the last step.
losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")

# ---------------------------------------------------------------------------
# 10. Generate text from the trained model
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Generating sample text...")
print("=" * 60)

model.eval()

# Seed generation with a single newline token as a (1, 1) batch.
start_token = stoi["\n"]
context = torch.tensor([[start_token]], dtype=torch.long, device=device)

with torch.no_grad():
    generated = model.generate(
        context,
        max_new_tokens=500,
        temperature=1.0,
        top_k=40,
    )


def decode(l):
    """Turn a sequence of token ids into text using the ``itos`` mapping."""
    pieces = [itos[i] for i in l]
    return "".join(pieces)


print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")

print("\nTraining complete! Best checkpoint saved to: best.pt")