pierjoe committed
Commit 7d1d824 · verified · 1 Parent(s): 39e1d67

Upload minitransformer.py with huggingface_hub
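
The upload named in this commit message is typically done with the huggingface_hub client; a minimal sketch of such a call (the repo id and token handling below are assumptions, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()  # assumes a token from `huggingface-cli login` or the HF_TOKEN env var
    api.upload_file(
        path_or_fileobj="minitransformer.py",
        path_in_repo="minitransformer.py",
        repo_id="pierjoe/<repo-name>",  # hypothetical repo id
        commit_message="Upload minitransformer.py with huggingface_hub",
    )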

Files changed (1)
  1. minitransformer.py +268 -0
minitransformer.py ADDED
@@ -0,0 +1,268 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import urllib.request
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm


logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ----------------- CONFIG -----------------
SAVE_EVERY = 5
MODEL_NAME = "mini_transformer_v3"
N_DATA_WORKERS = 6
PIN_MEMORY = True if N_DATA_WORKERS > 0 and torch.cuda.is_available() else False
BATCH_SIZE = 128
EVAL_EVERY = 5
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
USE_AMP = True
STRIDE = 32
CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DATASET = "DATA/generated_dataset_very_big.csv"

CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4
# ----------------- MODEL -----------------


# TransformerBlock (from your previous code)
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Causal mask: True above the diagonal marks the future positions that
        # attention must not see when training for next-token prediction.
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1,
            ),
            persistent=False,
        )

    def forward(self, x):
        T = x.size(1)
        h = self.ln1(x)  # pre-norm, computed once and reused as query/key/value
        attn_out, _ = self.attn(
            h, h, h, attn_mask=self.causal_mask[:T, :T], need_weights=False
        )
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
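# Shape check (illustrative only, not executed by this script): token ids of
# shape (batch, seq_len) map to logits of shape (batch, seq_len, vocab_size), e.g.
#   m = MiniTransformer(vocab_size=100, emb_dim=32, context_length=16,
#                       num_heads=2, num_layers=1)
#   m(torch.randint(0, 100, (2, 16))).shape  # -> torch.Size([2, 16, 100])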


# ----------------- DATASET -----------------
class SlidingWindowDataset(Dataset):
    def __init__(self, texts, tokenizer, context_length=128, stride=64):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.stride = stride

        # Flatten all text into a single long stream of token IDs
        self.tokens = []
        for text in texts:
            ids = tokenizer.encode(text, add_special_tokens=False)
            self.tokens.extend(ids)
        self.tokens = torch.tensor(self.tokens, dtype=torch.long)

        self.n_samples = (len(self.tokens) - context_length) // stride

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.context_length + 1
        chunk = self.tokens[start:end]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y
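    # Illustration of the (x, y) pair built above, with made-up token ids and
    # context_length = 4:
    #   tokens = [10, 11, 12, 13, 14]
    #   x = [10, 11, 12, 13]
    #   y = [11, 12, 13, 14]  # y[t] is the token that follows x[t]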


# As long as we flatten the list of strings into one single stream of text
# and then split it into pieces of the same length, by definition we don't need padding.
# We need padding when we have multiple separate sentences in a list
# and want to build a batch from them --> then we surely need to pad all the sequences
# to the same length --> max length or context length (with due truncation if needed).

# Example
# we have a batch like this:
# ["ciao", "ciao io sono", "ciao io sono pippo"]
# becomes:
# [101, 2003, 102]
# [101, 2003, 2026, 2070, 102]
# [101, 2003, 2026, 2070, 5274, 102]
# we have to pad to max length:
# [101, 2003, 102, 0, 0, 0]
# [101, 2003, 2026, 2070, 102, 0]
# [101, 2003, 2026, 2070, 5274, 102]
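
# For comparison, padding a batch of separate sentences with this tokenizer
# would look roughly like this (illustrative sketch, not used by this script):
#   batch = tokenizer(
#       ["ciao", "ciao io sono", "ciao io sono pippo"],
#       padding=True, truncation=True, max_length=CONTEXT_LENGTH,
#       return_tensors="pt",
#   )
#   batch["input_ids"]       # padded to the longest sequence in the batch
#   batch["attention_mask"]  # 1 for real tokens, 0 for padding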


# ----------------- DEVICE -----------------
# Prefer CUDA, then Apple MPS, and fall back to CPU if neither is available.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
    print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")


# ----------------- LOAD DATA -----------------
df = pd.read_csv(DATASET)
texts = [
    f"{row['system_prompt']} {row['question']} {row['answer']}"
    for _, row in df.iterrows()
]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(f"dataset train length: {len(train_dataset)}")
loader_train = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)
loader_test = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)


# ----------------- TRAINING SETUP -----------------

model = MiniTransformer(
    vocab_size=vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {n_params}")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
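# Note: the sliding-window batches contain no padding tokens, so ignore_index
# never actually triggers here; it would only matter if padded batches were
# introduced later. Likewise, GradScaler and autocast below are enabled only
# when USE_AMP is True and the device is CUDA.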


# ----------------- CHECKPOINT RESUME -----------------
# Sort numerically by epoch so that e.g. epoch_10 comes after epoch_9
# (a plain lexicographic sort would place it before epoch_2).
checkpoint_files = sorted(
    [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(".pt")],
    key=lambda f: int(f.rsplit("_", 1)[-1].split(".")[0]),
)
if checkpoint_files:
    latest_ckpt = os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])
    ckpt = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    if ckpt.get("scaler_state"):
        scaler.load_state_dict(ckpt["scaler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from {latest_ckpt}")
else:
    start_epoch = 0

model = torch.compile(model)

# ----------------- TRAINING LOOP -----------------
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    total_loss = 0

    for x, y in tqdm(loader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.amp.autocast(
            "cuda", dtype=torch.float16, enabled=USE_AMP and device.type == "cuda"
        ):
            logits = model(x)
            # Flatten (B, T, vocab) logits and (B, T) targets for cross-entropy.
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * x.size(0)

    avg_train_loss = total_loss / len(train_dataset)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # --- Evaluation ---
    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for x, y in loader_test:
                x, y = x.to(device), y.to(device)
                with torch.amp.autocast(
                    "cuda",
                    dtype=torch.float16,
                    enabled=USE_AMP and device.type == "cuda",
                ):
                    logits = model(x)
                    loss = criterion(logits.view(-1, vocab_size), y.view(-1))
                total_loss += loss.item() * x.size(0)
        avg_test_loss = total_loss / len(test_dataset)
        print(f"Test Loss: {avg_test_loss:.4f}")

    # --- Save checkpoint ---
    if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:
        # torch.compile wraps the model in an OptimizedModule whose state_dict
        # keys are prefixed with "_orig_mod."; saving the underlying module keeps
        # the checkpoint loadable before compilation (as done in the resume above).
        model_to_save = getattr(model, "_orig_mod", model)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model_to_save.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scaler_state": scaler.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
        )

# check GPU utilization metrics here:
# nvidia-smi dmon -s u