dev-das committed
Commit cac4140 · verified · 1 Parent(s): f938628

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "architectures": [
+     "MyGPTForCausalLM"
+   ],
+   "context_length": 256,
+   "drop_rate": 0.1,
+   "dtype": "float32",
+   "emb_dim": 256,
+   "model_type": "my_gpt",
+   "n_heads": 4,
+   "n_layers": 12,
+   "qkv_bias": false,
+   "transformers_version": "5.1.0",
+   "vocab_size": 50257,
+
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "pad_token_id": 50256
+ }
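
The fields above mirror the constructor arguments of MyGPTConfig in configuration_my_gpt.py below, plus the usual bos/eos/pad token ids (50256 is GPT-2's <|endoftext|>). A minimal sketch of rebuilding the config object from this file, assuming config.json and configuration_my_gpt.py sit in the current working directory:

import json

from configuration_my_gpt import MyGPTConfig  # from this repo

with open("config.json") as f:
    cfg = json.load(f)

# Rebuild the config object from the model-specific fields.
config = MyGPTConfig(
    vocab_size=cfg["vocab_size"],
    context_length=cfg["context_length"],
    emb_dim=cfg["emb_dim"],
    n_heads=cfg["n_heads"],
    n_layers=cfg["n_layers"],
    drop_rate=cfg["drop_rate"],
    qkv_bias=cfg["qkv_bias"],
)
print(config.model_type, config.emb_dim, config.n_layers)
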
configuration_my_gpt.py ADDED
@@ -0,0 +1,25 @@
+ from transformers import PretrainedConfig
+
+ class MyGPTConfig(PretrainedConfig):
+     model_type = "my_gpt"
+
+     def __init__(
+         self,
+         vocab_size=50257,
+         context_length=256,
+         emb_dim=256,
+         n_heads=4,
+         n_layers=12,
+         drop_rate=0.1,
+         qkv_bias=False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         self.vocab_size = vocab_size
+         self.context_length = context_length
+         self.emb_dim = emb_dim
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.drop_rate = drop_rate
+         self.qkv_bias = qkv_bias
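
To let the transformers Auto classes resolve the custom "my_gpt" model type in a local session, the config and model classes can be registered explicitly. A minimal sketch, assuming modeling_my_gpt.py (below) is importable from the same directory:

from transformers import AutoConfig, AutoModelForCausalLM

from configuration_my_gpt import MyGPTConfig
from modeling_my_gpt import MyGPTForCausalLM

# Map the custom model_type to the custom classes for this session.
AutoConfig.register("my_gpt", MyGPTConfig)
AutoModelForCausalLM.register(MyGPTConfig, MyGPTForCausalLM)

config = MyGPTConfig()  # defaults match config.json above
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()), "parameters")
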
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2315b9bb1ea7a701548cf7b78919b30120aa8f3be7eb3f84e917be93bc35d67
+ size 144226200
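
The weights themselves live in Git LFS; only this pointer file appears in the diff. At 144,226,200 bytes the checkpoint corresponds to roughly 36M float32 parameters, consistent with the 256-dim, 12-layer config above. A minimal sketch of inspecting the downloaded file with the safetensors library:

from safetensors.torch import load_file

# Load the raw state dict from the LFS-tracked weights file.
state_dict = load_file("model.safetensors")
print(len(state_dict), "tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), str(tensor.dtype))
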
modeling_my_gpt.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ import sys
+
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+
+ # Make the parent directory importable before the local imports below
+ curr_dir = os.getcwd()
+ parent_dir = os.path.dirname(curr_dir)
+ sys.path.insert(0, parent_dir)
+
+ from configuration_my_gpt import MyGPTConfig
+ from untrained_model import GPTModel
+
+
+ class MyGPTForCausalLM(PreTrainedModel):
+     config_class = MyGPTConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         # Build the original GPTModel from the config values
+         self.model = GPTModel({
+             "vocab_size": config.vocab_size,
+             "context_length": config.context_length,
+             "emb_dim": config.emb_dim,
+             "n_heads": config.n_heads,
+             "n_layers": config.n_layers,
+             "drop_rate": config.drop_rate,
+             "qkv_bias": config.qkv_bias
+         })
+
+         self.post_init()
+
+     def forward(self, input_ids, labels=None):
+         logits = self.model(input_ids)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict token n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1)
+             )
+
+         return {
+             "loss": loss,
+             "logits": logits,
+         }
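
A minimal end-to-end sketch of using this wrapper locally; the dummy batch and the checkpoint directory name are placeholders, but save_pretrained / from_pretrained is presumably how the config.json and model.safetensors in this commit were produced:

import torch

from configuration_my_gpt import MyGPTConfig
from modeling_my_gpt import MyGPTForCausalLM

config = MyGPTConfig()            # defaults match config.json above
model = MyGPTForCausalLM(config)

# A forward pass with labels returns both logits and the shifted LM loss.
dummy_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=dummy_ids, labels=dummy_ids)
print(out["logits"].shape, out["loss"].item())

# Serialising the wrapper writes config.json + model.safetensors.
model.save_pretrained("my_gpt_checkpoint")
reloaded = MyGPTForCausalLM.from_pretrained("my_gpt_checkpoint")
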
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "add_prefix_space": false,
+   "backend": "tokenizers",
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "is_local": false,
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
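
These are standard GPT-2 tokenizer settings (model_max_length of 1024 is the GPT-2 default and is separate from the model's context_length of 256). A minimal loading sketch, assuming the repo files have been downloaded into a local directory named my_gpt (a placeholder):

from transformers import AutoTokenizer

# tokenizer.json, tokenizer_config.json and special_tokens_map.json
# are read from the local directory.
tokenizer = AutoTokenizer.from_pretrained("my_gpt")
enc = tokenizer("Hi, there", return_tensors="pt")
print(enc.input_ids)
print(tokenizer.decode(enc.input_ids[0]))
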
untrained_model.py ADDED
@@ -0,0 +1,435 @@
+ import tiktoken
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import MaxNLocator
+ import numpy as np
+
+
+
+
+ class GPTDatasetV1(Dataset):
+     def __init__(self, txt, tokenizer, max_length, stride):
+         self.input_ids = []
+         self.target_ids = []
+
+         # Tokenize the entire text
+         token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
+
+         # Use a sliding window to chunk the book into overlapping sequences of max_length
+         for i in range(0, len(token_ids) - max_length, stride):
+             input_chunk = token_ids[i:i + max_length]
+             target_chunk = token_ids[i + 1: i + max_length + 1]
+             self.input_ids.append(torch.tensor(input_chunk))
+             self.target_ids.append(torch.tensor(target_chunk))
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, idx):
+         return self.input_ids[idx], self.target_ids[idx]
+
+
+ def create_dataloader_v1(txt, batch_size=4, max_length=256,
+                          stride=128, shuffle=True, drop_last=True, num_workers=0):
+     # Initialize the tokenizer
+     tokenizer = tiktoken.get_encoding("gpt2")
+
+     # Create dataset
+     dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
+
+     # Create dataloader
+     dataloader = DataLoader(
+         dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
+
+     return dataloader
+
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+         super().__init__()
+         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+         self.d_out = d_out
+         self.num_heads = num_heads
+         self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+         self.dropout = nn.Dropout(dropout)
+         self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+     def forward(self, x):
+         b, num_tokens, d_in = x.shape
+
+         keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+         queries = self.W_query(x)
+         values = self.W_value(x)
+
+         # We implicitly split the matrix by adding a `num_heads` dimension
+         # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+         values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+         # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+         keys = keys.transpose(1, 2)
+         queries = queries.transpose(1, 2)
+         values = values.transpose(1, 2)
+
+         # Compute scaled dot-product attention (aka self-attention) with a causal mask
+         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+         # Original mask truncated to the number of tokens and converted to boolean
+         mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+         # Use the mask to fill attention scores
+         attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Shape: (b, num_tokens, num_heads, head_dim)
+         context_vec = (attn_weights @ values).transpose(1, 2)
+
+         # Combine heads, where self.d_out = self.num_heads * self.head_dim
+         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+         context_vec = self.out_proj(context_vec)  # optional projection
+
+         return context_vec
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, emb_dim):
+         super().__init__()
+         self.eps = 1e-5
+         self.scale = nn.Parameter(torch.ones(emb_dim))
+         self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+     def forward(self, x):
+         mean = x.mean(dim=-1, keepdim=True)
+         var = x.var(dim=-1, keepdim=True, unbiased=False)
+         norm_x = (x - mean) / torch.sqrt(var + self.eps)
+         return self.scale * norm_x + self.shift
+
+
+ class GELU(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return 0.5 * x * (1 + torch.tanh(
+             torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+             (x + 0.044715 * torch.pow(x, 3))
+         ))
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+             GELU(),
+             nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.att = MultiHeadAttention(
+             d_in=cfg["emb_dim"],
+             d_out=cfg["emb_dim"],
+             context_length=cfg["context_length"],
+             num_heads=cfg["n_heads"],
+             dropout=cfg["drop_rate"],
+             qkv_bias=cfg["qkv_bias"])
+         self.ff = FeedForward(cfg)
+         self.norm1 = LayerNorm(cfg["emb_dim"])
+         self.norm2 = LayerNorm(cfg["emb_dim"])
+         self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+     def forward(self, x):
+         # Shortcut connection for attention block
+         shortcut = x
+         x = self.norm1(x)
+         x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
+         x = self.drop_shortcut(x)
+         x = x + shortcut  # Add the original input back
+
+         # Shortcut connection for feed-forward block
+         shortcut = x
+         x = self.norm2(x)
+         x = self.ff(x)
+         x = self.drop_shortcut(x)
+         x = x + shortcut  # Add the original input back
+
+         return x
+
+
+ class GPTModel(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+         self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+         self.trf_blocks = nn.Sequential(
+             *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+         self.final_norm = LayerNorm(cfg["emb_dim"])
+         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+     def forward(self, in_idx):
+         batch_size, seq_len = in_idx.shape
+         tok_embeds = self.tok_emb(in_idx)
+         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+         x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+         x = self.drop_emb(x)
+         x = self.trf_blocks(x)
+         x = self.final_norm(x)
+         logits = self.out_head(x)
+         return logits
+
+
+ def generate_text_simple(model, idx, max_new_tokens, context_size):
+     # idx is (B, T) array of indices in the current context
+     for _ in range(max_new_tokens):
+
+         # Crop current context if it exceeds the supported context size
+         # E.g., if LLM supports only 5 tokens, and the context size is 10
+         # then only the last 5 tokens are used as context
+         idx_cond = idx[:, -context_size:]
+
+         # Get the predictions
+         with torch.no_grad():
+             logits = model(idx_cond)
+
+         # Focus only on the last time step
+         # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+         logits = logits[:, -1, :]
+
+         # Get the idx of the vocab entry with the highest logits value
+         idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
+
+         # Append sampled index to the running sequence
+         idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
+
+     return idx
+
+
+ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
+
+     # For-loop is the same as before: Get logits, and only focus on last time step
+     for _ in range(max_new_tokens):
+         idx_cond = idx[:, -context_size:]
+         with torch.no_grad():
+             logits = model(idx_cond)
+         logits = logits[:, -1, :]
+
+         # New: Filter logits with top_k sampling
+         if top_k is not None:
+             # Keep only top_k values
+             top_logits, _ = torch.topk(logits, top_k)
+             min_val = top_logits[:, -1]
+             logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
+
+         # New: Apply temperature scaling
+         if temperature > 0.0:
+             logits = logits / temperature
+
+             # New (not in book): numerical stability tip to get equivalent results on mps device
+             # subtract rowwise max before softmax
+             logits = logits - logits.max(dim=-1, keepdim=True).values
+
+             # Apply softmax to get probabilities
+             probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)
+
+             # Sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
+
+         # Otherwise same as before: get idx of the vocab entry with the highest logits value
+         else:
+             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
+
+         if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+             break
+
+         # Same as before: append sampled index to the running sequence
+         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
+
+     return idx
+
+
+ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
+                        eval_freq, eval_iter, start_context, tokenizer):
+     # Initialize lists to track losses and tokens seen
+     train_losses, val_losses, track_tokens_seen = [], [], []
+     tokens_seen, global_step = 0, -1
+
+     # Main training loop
+     for epoch in range(num_epochs):
+         model.train()  # Set model to training mode
+
+         for input_batch, target_batch in train_loader:
+             optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
+             loss = calc_loss_batch(input_batch, target_batch, model, device)
+             loss.backward()  # Calculate loss gradients
+             optimizer.step()  # Update model weights using loss gradients
+             tokens_seen += input_batch.numel()
+             global_step += 1
+
+             # Optional evaluation step
+             if global_step % eval_freq == 0:
+                 train_loss, val_loss = evaluate_model(
+                     model, train_loader, val_loader, device, eval_iter)
+                 train_losses.append(train_loss)
+                 val_losses.append(val_loss)
+                 track_tokens_seen.append(tokens_seen)
+                 print(f"Ep {epoch+1} (Step {global_step:06d}): "
+                       f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
+
+         # Print a sample text after each epoch
+         generate_and_print_sample(
+             model, tokenizer, device, start_context
+         )
+
+     return train_losses, val_losses, track_tokens_seen
+
+
+ def evaluate_model(model, train_loader, val_loader, device, eval_iter):
+     model.eval()
+     with torch.no_grad():
+         train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
+         val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
+     model.train()
+     return train_loss, val_loss
+
+
+ def generate_and_print_sample(model, tokenizer, device, start_context):
+     model.eval()
+     context_size = model.pos_emb.weight.shape[0]
+     encoded = text_to_token_ids(start_context, tokenizer).to(device)
+     with torch.no_grad():
+         token_ids = generate_text_simple(
+             model=model, idx=encoded,
+             max_new_tokens=50, context_size=context_size
+         )
+     decoded_text = token_ids_to_text(token_ids, tokenizer)
+     print(decoded_text.replace("\n", " "))  # Compact print format
+     model.train()
+
+
+ def assign(left, right):
+     if left.shape != right.shape:
+         raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
+     return torch.nn.Parameter(torch.tensor(right))
+
+
+ def text_to_token_ids(text, tokenizer):
+     encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
+     encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
+     return encoded_tensor
+
+
+ def token_ids_to_text(token_ids, tokenizer):
+     flat = token_ids.squeeze(0)  # remove batch dimension
+     return tokenizer.decode(flat.tolist())
+
+
+ def calc_loss_batch(input_batch, target_batch, model, device):
+     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+     logits = model(input_batch)
+     loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
+     return loss
+
+
+ def calc_loss_loader(data_loader, model, device, num_batches=None):
+     total_loss = 0.
+     if len(data_loader) == 0:
+         return float("nan")
+     elif num_batches is None:
+         num_batches = len(data_loader)
+     else:
+         # Reduce the number of batches to match the total number of batches in the data loader
+         # if num_batches exceeds the number of batches in the data loader
+         num_batches = min(num_batches, len(data_loader))
+     for i, (input_batch, target_batch) in enumerate(data_loader):
+         if i < num_batches:
+             loss = calc_loss_batch(input_batch, target_batch, model, device)
+             total_loss += loss.item()
+         else:
+             break
+     return total_loss / num_batches
+
+
+ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
+     fig, ax1 = plt.subplots(figsize=(5, 3))
+
+     # Plot training and validation loss against epochs
+     ax1.plot(epochs_seen, train_losses, label="Training loss")
+     ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
+     ax1.set_xlabel("Epochs")
+     ax1.set_ylabel("Loss")
+     ax1.legend(loc="upper right")
+     ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis
+
+     # Create a second x-axis for tokens seen
+     ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
+     ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
+     ax2.set_xlabel("Tokens seen")
+
+     fig.tight_layout()  # Adjust layout to make room
+     plt.savefig("loss-plot.pdf")
+     plt.show()
+
+ def main():
+     GPT_CONFIG_124M = {
+         "vocab_size": 50257,     # Vocabulary size
+         "context_length": 1024,  # Context length
+         "emb_dim": 768,          # Embedding dimension
+         "n_heads": 12,           # Number of attention heads
+         "n_layers": 12,          # Number of layers
+         "drop_rate": 0.1,        # Dropout rate
+         "qkv_bias": False        # Query-Key-Value bias
+     }
+
+     torch.manual_seed(123)
+     model = GPTModel(GPT_CONFIG_124M)
+     model.eval()  # disable dropout
+
+     start_context = "Hi, there"
+
+     tokenizer = tiktoken.get_encoding("gpt2")
+     encoded = tokenizer.encode(start_context)
+     encoded_tensor = torch.tensor(encoded).unsqueeze(0)
+
+     print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+     print("\nInput text:", start_context)
+     print("Encoded input text:", encoded)
+     print("encoded_tensor.shape:", encoded_tensor.shape)
+
+     out = generate_text_simple(
+         model=model,
+         idx=encoded_tensor,
+         max_new_tokens=10,
+         context_size=GPT_CONFIG_124M["context_length"]
+     )
+     decoded_text = tokenizer.decode(out.squeeze(0).tolist())
+
+     print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+     print("\nOutput:", out)
+     print("Output length:", len(out[0]))
+     print("Output text:", decoded_text)
+
+
+ if __name__ == "__main__":
+     main()
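
The module already exposes everything needed for a small training run. A minimal driver sketch using only the helpers above; the text file name, batch size, learning rate, and epoch count are placeholders, not values taken from this repo:

import tiktoken
import torch

from untrained_model import GPTModel, create_dataloader_v1, train_model_simple

with open("data.txt", "r", encoding="utf-8") as f:   # placeholder corpus
    text = f.read()

split = int(0.9 * len(text))
train_loader = create_dataloader_v1(text[:split], batch_size=2, max_length=256, stride=256)
val_loader = create_dataloader_v1(text[split:], batch_size=2, max_length=256, stride=256,
                                  shuffle=False, drop_last=False)

# Same hyperparameters as config.json above (256-dim, 12 layers, 4 heads).
cfg = {"vocab_size": 50257, "context_length": 256, "emb_dim": 256,
       "n_heads": 4, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.1)
tokenizer = tiktoken.get_encoding("gpt2")

train_losses, val_losses, tokens_seen = train_model_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=1, eval_freq=5, eval_iter=1,
    start_context="Hi, there", tokenizer=tokenizer,
)
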