pierjoe committed
Commit 37406bb · verified · 1 Parent(s): bb83607

Upload minitransformer-stories.py with huggingface_hub

Files changed (1):
  minitransformer-stories.py  +304 -0
minitransformer-stories.py ADDED
@@ -0,0 +1,304 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import urllib.request
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm
from safetensors.torch import save_file
from datasets import load_dataset

logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ----------------- CONFIG -----------------
NUM_STORIES = 50_000
SAVE_EVERY = 5
MODEL_NAME = "mini_transformer_v4"
N_DATA_WORKERS = 8
PIN_MEMORY = N_DATA_WORKERS > 0 and torch.cuda.is_available()
BATCH_SIZE = 128
EVAL_EVERY = 5
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
USE_AMP = True
STRIDE = 64
CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DATASET = "DATA/generated_dataset_very_big.csv"

CONTEXT_LENGTH = 256
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 8
N_LAYER = 6
# ----------------- MODEL -----------------


# Pre-norm transformer block: causal self-attention + MLP, each with a residual connection
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Upper-triangular boolean mask (True = blocked) so each position can only
        # attend to itself and earlier positions, as required for next-token prediction.
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1,
            ),
            persistent=False,
        )

    def forward(self, x):
        h = self.ln1(x)
        T = x.size(1)
        attn_out, _ = self.attn(
            h, h, h, attn_mask=self.causal_mask[:T, :T], need_weights=False
        )
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)  # token + learned positional embeddings
        x = self.blocks(x)  # stack of transformer blocks
        x = self.ln_f(x)  # final layer norm
        logits = self.head(x)  # (B, T, vocab_size)
        return logits


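# --- Hypothetical usage sketch (an assumption, not part of the uploaded script):
# once trained, the model can be used autoregressively by feeding its own
# predictions back in. The function name `greedy_sample` and the greedy decoding
# choice are illustrative; it assumes the BERT tokenizer loaded further below.
@torch.no_grad()
def greedy_sample(model, tokenizer, prompt, max_new_tokens=50, device="cpu"):
    model.eval()
    ids = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
        "input_ids"
    ].to(device)
    for _ in range(max_new_tokens):
        ctx = ids[:, -model.context_length :]  # keep at most one context window
        logits = model(ctx)  # (1, T, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next token
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())

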
# ----------------- DATASET -----------------
class SlidingWindowDataset(Dataset):
    def __init__(self, texts, tokenizer, context_length=128, stride=64):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.stride = stride

        print("Tokenizing texts...")
        # Batch tokenize - MUCH faster
        batch_size = 10000
        all_token_ids = []

        for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing batches"):
            batch = texts[i : i + batch_size]
            encoded = tokenizer(
                batch,
                add_special_tokens=False,
                truncation=False,
                padding=False,
                return_attention_mask=False,
            )
            all_token_ids.extend(encoded["input_ids"])

        print(f"Tokenized {len(all_token_ids)} texts")

        # Flatten all token IDs into single stream
        print("Flattening tokens...")
        self.tokens = []
        for ids in tqdm(all_token_ids, desc="Flattening"):
            self.tokens.extend(ids)

        self.tokens = torch.tensor(self.tokens, dtype=torch.long)
        self.n_samples = (len(self.tokens) - context_length) // stride

        print(f"Total tokens: {len(self.tokens):,}")
        print(f"Total samples: {self.n_samples:,}")
        print(f"Avg tokens per text: {len(self.tokens) / len(texts):.1f}")

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.context_length + 1
        chunk = self.tokens[start:end]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y


# As long as we flatten the list of strings into one single stream of text
# and then divide it into pieces of the same length, by construction we don't need padding.
# Padding is only needed when we have multiple separate sentences in a list
# and we want to create a batch with them --> then we surely need to pad all the sequences
# to the same length --> max length or context length (truncating where needed).

# example
# we have a batch like this:
# ["ciao", "ciao io sono", "ciao io sono pippo"]
# becomes:
# [101, 2003, 102]
# [101, 2003, 2026, 2070, 102]
# [101, 2003, 2026, 2070, 5274, 102]
# we have to pad to max length
# [101, 2003, 102, 0, 0, 0]
# [101, 2003, 2026, 2070, 102, 0]
# [101, 2003, 2026, 2070, 5274, 102]


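# A minimal sketch of that batching case (illustrative, not executed here; it
# assumes the bert-base-uncased tokenizer loaded further below, whose [PAD]
# token id is 0):
#
#   enc = tokenizer(
#       ["ciao", "ciao io sono", "ciao io sono pippo"],
#       padding="longest",
#       return_tensors="pt",
#   )
#   # enc["input_ids"] is a (3, max_len) tensor with short rows filled by [PAD] (id 0),
#   # and enc["attention_mask"] marks which positions hold real tokens.
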
# ----------------- DEVICE -----------------
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
    print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")


# ----------------- LOAD DATA -----------------

print("Loading TinyStories dataset...")
ds = load_dataset("roneneldan/TinyStories")

# Use a subset - adjust this number based on training time
texts = ds["train"]["text"][:NUM_STORIES]
print(f"Using {len(texts)} stories out of {len(ds['train'])} total")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(f"train dataset length: {len(train_dataset)}")
loader_train = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)
loader_test = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)


# ----------------- TRAINING SETUP -----------------

model = MiniTransformer(
    vocab_size=vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {n_params}")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


# ----------------- CHECKPOINT RESUME -----------------
# Sort by epoch number (not lexicographically) so that e.g. epoch 10 comes after epoch 9.
checkpoint_files = sorted(
    [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(".pt")],
    key=lambda f: int(f.rsplit("_epoch_", 1)[-1].split(".")[0]),
)
if checkpoint_files:
    latest_ckpt = os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])
    ckpt = torch.load(latest_ckpt, map_location=device)
    state_dict = ckpt["model_state"]

    # Strip the prefix that torch.compile adds to parameter names.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    optimizer.load_state_dict(ckpt["optimizer_state"])
    if "scaler_state" in ckpt:
        scaler.load_state_dict(ckpt["scaler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from {latest_ckpt}")
else:
    start_epoch = 0

model = torch.compile(model)

# ----------------- TRAINING LOOP -----------------
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    total_loss = 0

    for x, y in tqdm(loader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.amp.autocast(
            "cuda", dtype=torch.float16, enabled=USE_AMP and device.type == "cuda"
        ):
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * x.size(0)

    avg_train_loss = total_loss / len(train_dataset)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # --- Evaluation ---
    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for x, y in loader_test:
                x, y = x.to(device), y.to(device)
                with torch.amp.autocast(
                    "cuda",
                    dtype=torch.bfloat16,
                    enabled=USE_AMP and device.type == "cuda",
                ):
                    logits = model(x)
                    loss = criterion(logits.view(-1, vocab_size), y.view(-1))
                total_loss += loss.item() * x.size(0)
        avg_test_loss = total_loss / len(test_dataset)
        print(f"Test Loss: {avg_test_loss:.4f}")

    # --- Save checkpoint ---
    if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scaler_state": scaler.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
        )
        save_file(
            model.state_dict(),
            os.path.join(CHECKPOINT_DIR, f"model_{epoch+1}.safetensors"),
        )


# check GPU utilization metrics here:
# nvidia-smi dmon -s u
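
# --- Hypothetical sketch (an assumption, not part of the uploaded script): one way
# the saved .safetensors weights could be loaded back for inference. The file name
# below is illustrative.
#
#   from safetensors.torch import load_file
#
#   state = load_file(os.path.join(CHECKPOINT_DIR, "model_50.safetensors"))
#   # Drop the prefix that torch.compile adds to parameter names before loading.
#   state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
#   model.load_state_dict(state)
#   model.eval()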