yat343 committed
Commit 4d9ca5d · verified · 1 Parent(s): 06697e8

Upload train_standalone.py

Files changed (1):
  1. train_standalone.py (+317, -0)
train_standalone.py ADDED
@@ -0,0 +1,317 @@
"""
Step-by-step training script for a nano GPT, fully self-contained:
it bundles the model architecture and the training loop in a single
file so it can run on its own in an HF Job.
"""

import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass

# =============================================================================
# PART 1: MODEL
# =============================================================================

@dataclass
class GPTConfig:
    block_size: int = 256  # maximum context length
    vocab_size: int = 65   # tiny Shakespeare's character count
    n_layer: int = 4       # defaults; the training section below uses 6/6/384
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.0

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        head_size = C // self.n_head
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y

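# Note: the forward pass above is the textbook attention,
#   softmax(q @ k^T / sqrt(head_size)) @ v, with a causal mask.
# On PyTorch >= 2.0 the same result can be computed in one fused call,
#   y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# which would also make the precomputed "bias" mask buffer unnecessary.
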
class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

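# Pre-LayerNorm residual wiring, x + f(ln(x)), as in GPT-2: each sub-layer
# reads a normalized input while the residual stream itself stays untouched,
# which tends to train more stably than post-norm at these depths.
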
class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "wpe": nn.Embedding(config.block_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: token embedding and output projection share one matrix
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.wpe(pos)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

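# Shape note: logits come out as (B, T, vocab_size); cross_entropy sees them
# flattened to (B*T, vocab_size) against (B*T,) targets. The ignore_index=-1
# never fires here, since every position has a valid next-character target.
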
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if top_k is not None:
                v, _ = torch.topk(logits, top_k, dim=-1)
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

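# The context is cropped to the last block_size tokens before each forward
# pass: the learned positional table (wpe) only covers positions
# 0..block_size-1, so longer prompts would index out of range.
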
# =============================================================================
# PART 2: TRAINING
# =============================================================================

BATCH_SIZE = 64
BLOCK_SIZE = 256
MAX_ITERS = 5000
LEARNING_RATE = 1e-3
WARMUP_ITERS = 200
LR_DECAY_ITERS = 5000
MIN_LR = 1e-4
EVAL_INTERVAL = 500
EVAL_ITERS = 200
GRAD_CLIP = 1.0

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

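# Scale check: each step consumes BATCH_SIZE * BLOCK_SIZE = 64 * 256 = 16,384
# tokens, so MAX_ITERS = 5000 steps see ~82M tokens, roughly 80 passes over
# the ~1M-character training split prepared below.
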
# Download data if needed
data_path = "data.pt"
if not os.path.exists(data_path):
    import urllib.request
    print("Downloading tiny Shakespeare...")
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    urllib.request.urlretrieve(url, "input.txt")

    with open("input.txt", "r", encoding="utf-8") as f:
        text = f.read()

    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
    torch.save({
        "train": train_data,
        "val": val_data,
        "vocab_size": vocab_size,
        "chars": chars,
        "stoi": stoi,
        "itos": itos,
    }, data_path)
    print("Data saved.")

data = torch.load(data_path, weights_only=False)
train_data = data["train"]
val_data = data["val"]
vocab_size = data["vocab_size"]
chars = data["chars"]
stoi = data["stoi"]
itos = data["itos"]

print(f"Vocab size  : {vocab_size}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens  : {len(val_data):,}")

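# For tiny Shakespeare this character-level tokenization yields 65 unique
# characters (hence the GPTConfig default vocab_size=65) out of ~1.1M total,
# split 90/10 into train and val.
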
def get_batch(split: str):
    data_split = train_data if split == "train" else val_data
    ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_split[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([data_split[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

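# y is x shifted one position left: token t is trained to predict token t+1,
# so one batch supplies 64 * 256 = 16,384 prediction targets. Sampling random
# offsets each call stands in for an explicit shuffle of the corpus.
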
def get_lr(iteration: int) -> float:
    if iteration < WARMUP_ITERS:
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    if iteration > LR_DECAY_ITERS:
        return MIN_LR
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)

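# A few values this schedule produces (LEARNING_RATE=1e-3, MIN_LR=1e-4):
#   iter 0    -> 5.0e-06  (start of linear warmup)
#   iter 199  -> 1.0e-03  (warmup complete)
#   iter 2600 -> 5.5e-04  (cosine midpoint, coeff = 0.5)
#   iter 5000 -> 1.0e-04  (decay complete; floor thereafter)
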
config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.0,
)

model = GPT(config)
model.to(device)

param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")

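# Rough parameter accounting for this config (assuming vocab_size=65, with
# the tied wte/lm_head matrix counted once by model.parameters()): six
# blocks of ~1.77M weights plus token/positional embeddings and the final
# LayerNorm come to roughly 10.77M, so expect the printout to read ~10.77 M.
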
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.dim() >= 2:
        decay_params.append(param)
    else:
        no_decay_params.append(param)

optim_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)

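# Only tensors with dim >= 2 (linear and embedding weight matrices) receive
# weight decay; biases and LayerNorm parameters are exempt, following the
# GPT-2/nanoGPT convention. betas=(0.9, 0.95) matches that setup as well.
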
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

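# Averaging over EVAL_ITERS = 200 random batches per split gives a far
# lower-variance loss estimate than any single batch; @torch.no_grad() skips
# autograd bookkeeping and model.eval() disables dropout during the estimate.
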
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)

best_val_loss = float("inf")
start_time = time.time()

for iter_num in range(MAX_ITERS):
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )

        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }, "best.pt")
            print(f"  -> Saved new best model (val_loss={best_val_loss:.4f})")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
    optimizer.step()

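# clip_grad_norm_ rescales the global gradient norm to at most GRAD_CLIP
# before each optimizer step, guarding against occasional loss spikes;
# zero_grad(set_to_none=True) frees gradient memory rather than zero-filling.
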
losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")

model.eval()
start_token = stoi["\n"]
context = torch.zeros((1, 1), dtype=torch.long, device=device)
context[0, 0] = start_token

with torch.no_grad():
    generated = model.generate(context, max_new_tokens=500, temperature=1.0, top_k=40)

decode = lambda l: "".join([itos[i] for i in l])

print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")

print("\nTraining complete! Best checkpoint saved to: best.pt")