AYYasaswini committed
Commit 1930848 · verified · 1 Parent(s): ffe1511

Update app.py

Files changed (1):
  app.py +524 -48
app.py CHANGED
@@ -1,64 +1,540 @@
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- from transformers import PreTrainedModel, AutoConfig
- from pathlib import Path
-
-
- def predict_with_custom_weights(model_name, text, weights_path):
-     """
-     Loads a model config, creates a model, loads custom weights, performs testing
-     and prediction.
-
-     Args:
-         model_name: Name of the pre-trained model architecture (e.g., "bert-base-uncased").
-         text: The text string for prediction.
-         weights_path: Path to the directory containing your custom weights.
-
-     Returns:
-         A dictionary containing predictions and logits (optional).
-     """
-     # Load tokenizer and model config
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     config = AutoConfig.from_pretrained(model_name)
-
-     # Create empty model from config
-     model = PreTrainedModel.from_config(config)
-
-     # Load weights from your directory
-     model.load_state_dict(torch.load(Path(weights_path) / "pytorch_model.bin"))
-
-     # Tokenize the text
-     inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
-
-     # Set model to evaluation mode
-     model.eval()
-
-     # Make predictions
-     with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-         predictions = torch.argmax(logits, dim=-1)
-
-     # Get label names (optional, if your model has labels)
-     label_names = config.label_names if hasattr(config, "label_names") else None
-
-     # Return results
-     return {
-         "predictions": predictions.item() if len(predictions) == 1 else predictions.tolist(),
-         "logits": logits.squeeze().tolist() if label_names else None,
-         "labels": label_names[predictions.item()] if label_names else None,
-     }
-
-
- # Example usage (replace with your actual model name, text, and weights path)
- model_name = "bert-base-uncased"
- text = "This movie was absolutely fantastic!"
- weights_path = "transformer_weights.pth"  # Replace with actual path
- results = predict_with_custom_weights(model_name, text, weights_path)
-
- print(f"Predicted label: {results['predictions']}")
-
- if results.get("labels"):
-     print(f"Corresponding label name: {results['labels']}")
-
- if results.get("logits"):
-     print(f"Logits: {results['logits']}")
+ # -*- coding: utf-8 -*-
+ """gpt_dev.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1zxxLfIi8_EDLqYODY8TyNLpr8RTxV-Ct
+
+ ## Building a GPT
+
+ Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT.
+ """
+
+ # We always start with a dataset to train on. Let's download the tiny shakespeare dataset
+ # !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+ import subprocess
+
+ # URL of the file you want to download
+ url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
+ # Local path where the file will be saved
+ local_filename = "input.txt"
+
+ def download_file(url, local_filename):
+     subprocess.run(["wget", url, "-O", local_filename], check=True)
+
+ # Download the file
+ download_file(url, local_filename)
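The download shells out to wget, which may not be installed on the host (e.g., a minimal container or a hosted Space). A standard-library-only sketch of the same step, using urllib.request instead:

import urllib.request

def download_file(url, local_filename):
    # fetch the file over HTTPS and write it to local_filename
    urllib.request.urlretrieve(url, local_filename)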
+ # from gpt_dev import BigramLanguageModel  # Import your model class
+
+ # Your other code here
+
+ # read it in to inspect it
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     text = f.read()
+
+ print("length of dataset in characters: ", len(text))
+
+ # let's look at the first 1000 characters
+ print(text[:1000])
+
+ # here are all the unique characters that occur in this text
+ chars = sorted(list(set(text)))
+ vocab_size = len(chars)
+ print(''.join(chars))
+ print(vocab_size)
+
+ # create a mapping from characters to integers
+ stoi = { ch:i for i,ch in enumerate(chars) }
+ itos = { i:ch for i,ch in enumerate(chars) }
+ encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
+ decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string
+
+ print(encode("hii there"))
+ print(decode(encode("hii there")))
+
+ # let's now encode the entire text dataset and store it into a torch.Tensor
+ import torch  # we use PyTorch: https://pytorch.org
+ data = torch.tensor(encode(text), dtype=torch.long)
+ print(data.shape, data.dtype)
+ print(data[:1000])  # the 1000 characters we looked at earlier will look like this to the GPT
+
+ # Let's now split up the data into train and validation sets
+ n = int(0.9*len(data))  # first 90% will be train, rest val
+ train_data = data[:n]
+ val_data = data[n:]
+
+ block_size = 8
+ train_data[:block_size+1]
+
+ x = train_data[:block_size]
+ y = train_data[1:block_size+1]
+ for t in range(block_size):
+     context = x[:t+1]
+     target = y[t]
+     print(f"when input is {context} the target: {target}")
+
+ torch.manual_seed(1337)
+ batch_size = 4  # how many independent sequences will we process in parallel?
+ block_size = 8  # what is the maximum context length for predictions?
+
+ def get_batch(split):
+     # generate a small batch of data of inputs x and targets y
+     data = train_data if split == 'train' else val_data
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     x = torch.stack([data[i:i+block_size] for i in ix])
+     y = torch.stack([data[i+1:i+block_size+1] for i in ix])
+     return x, y
+
+ xb, yb = get_batch('train')
+ print('inputs:')
+ print(xb.shape)
+ print(xb)
+ print('targets:')
+ print(yb.shape)
+ print(yb)
+
+ print('----')
+
+ for b in range(batch_size):      # batch dimension
+     for t in range(block_size):  # time dimension
+         context = xb[b, :t+1]
+         target = yb[b,t]
+         print(f"when input is {context.tolist()} the target: {target}")
+
+ print(xb)  # our input to the transformer
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ torch.manual_seed(1337)
+
+ class BigramLanguageModel(nn.Module):
+
+     def __init__(self, vocab_size):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
+
+     def forward(self, idx, targets=None):
+         # idx and targets are both (B,T) tensors of integers
+         logits = self.token_embedding_table(idx)  # (B,T,C)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # get the predictions
+             logits, loss = self(idx)
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
+
+ m = BigramLanguageModel(vocab_size)
+
+ logits, loss = m(xb, yb)
+ print(logits.shape)
+ print(loss)
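A quick sanity check (an aside, not in the original): a model that predicts uniformly over vocab_size classes has cross-entropy -ln(1/vocab_size), so the loss at random initialization should land near this value.

import math
print(math.log(vocab_size))  # ~4.17 for the 65-character Shakespeare vocabulary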
+
+ print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))
+
+ # create a PyTorch optimizer
+ optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
+
+ batch_size = 32
+ for steps in range(100):  # increase number of steps for good results...
+
+     # sample a batch of data
+     xb, yb = get_batch('train')
+
+     # evaluate the loss
+     logits, loss = m(xb, yb)
+     optimizer.zero_grad(set_to_none=True)
+     loss.backward()
+     optimizer.step()
+
+ print(loss.item())
+
+ print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))
+
+ """## The mathematical trick in self-attention"""
+
+ # toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
+ torch.manual_seed(42)
+ a = torch.tril(torch.ones(3, 3))
+ a = a / torch.sum(a, 1, keepdim=True)
+ b = torch.randint(0,10,(3,2)).float()
+ c = a @ b
+ print('a=')
+ print(a)
+ print('--')
+ print('b=')
+ print(b)
+ print('--')
+ print('c=')
+ print(c)
+
+ # consider the following toy example:
+ torch.manual_seed(1337)
+ B,T,C = 4,8,2  # batch, time, channels
+ x = torch.randn(B,T,C)
+ x.shape
+
+ # We want x[b,t] = mean_{i<=t} x[b,i]
+ # version 1: explicit loops over batch and time
+ xbow = torch.zeros((B,T,C))
+ for b in range(B):
+     for t in range(T):
+         xprev = x[b,:t+1]  # (t,C)
+         xbow[b,t] = torch.mean(xprev, 0)
+
+ # version 2: using matrix multiply for a weighted aggregation
+ wei = torch.tril(torch.ones(T, T))
+ wei = wei / wei.sum(1, keepdim=True)
+ xbow2 = wei @ x  # (B, T, T) @ (B, T, C) ----> (B, T, C)
+ torch.allclose(xbow, xbow2)
+
+ # version 3: use Softmax
+ tril = torch.tril(torch.ones(T, T))
+ wei = torch.zeros((T,T))
+ wei = wei.masked_fill(tril == 0, float('-inf'))
+ wei = F.softmax(wei, dim=-1)
+ xbow3 = wei @ x
+ torch.allclose(xbow, xbow3)
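Why versions 2 and 3 agree (an illustrative aside): softmax over a row of zeros with -inf in the masked slots puts uniform weight 1/(t+1) on the t+1 visible positions, reproducing the row-normalized triangular matrix exactly.

tril4 = torch.tril(torch.ones(4, 4))
wei4 = torch.zeros((4, 4)).masked_fill(tril4 == 0, float('-inf'))
print(F.softmax(wei4, dim=-1))  # row t holds 1/(t+1) on positions 0..t, e.g. row 2 -> [1/3, 1/3, 1/3, 0]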
+
+ # version 4: self-attention!
+ torch.manual_seed(1337)
+ B,T,C = 4,8,32  # batch, time, channels
+ x = torch.randn(B,T,C)
+
+ # let's see a single Head perform self-attention
+ head_size = 16
+ key = nn.Linear(C, head_size, bias=False)
+ query = nn.Linear(C, head_size, bias=False)
+ value = nn.Linear(C, head_size, bias=False)
+ k = key(x)    # (B, T, 16)
+ q = query(x)  # (B, T, 16)
+ wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
+
+ tril = torch.tril(torch.ones(T, T))
+ # wei = torch.zeros((T,T))
+ wei = wei.masked_fill(tril == 0, float('-inf'))
+ wei = F.softmax(wei, dim=-1)
+
+ v = value(x)
+ out = wei @ v
+ # out = wei @ x
+
+ out.shape
+
+ wei[0]
+
+ """Notes:
+ - Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
+ - There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
+ - Each example across the batch dimension is of course processed completely independently; the examples never "talk" to each other.
+ - In an "encoder" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and it is usually used in autoregressive settings, like language modeling.
+ - "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).
+ - "Scaled" attention additionally divides `wei` by sqrt(head_size). This makes it so that when the inputs Q,K are unit variance, wei will be unit variance too, and Softmax will stay diffuse and not saturate too much. Illustration below.
+ """
+
+ k = torch.randn(B,T,head_size)
+ q = torch.randn(B,T,head_size)
+ wei = q @ k.transpose(-2, -1) * head_size**-0.5
+
+ k.var()
+
+ q.var()
+
+ wei.var()
+
+ torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)
+
+ torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)  # gets too peaky, converges to one-hot
+
+ class LayerNorm1d:  # (used to be BatchNorm1d)
+
+     def __init__(self, dim, eps=1e-5, momentum=0.1):  # momentum is unused here, a leftover from the BatchNorm version
+         self.eps = eps
+         self.gamma = torch.ones(dim)
+         self.beta = torch.zeros(dim)
+
+     def __call__(self, x):
+         # calculate the forward pass
+         xmean = x.mean(1, keepdim=True)  # mean over the feature dimension
+         xvar = x.var(1, keepdim=True)    # variance over the feature dimension
+         xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to zero mean, unit variance
+         self.out = self.gamma * xhat + self.beta
+         return self.out
+
+     def parameters(self):
+         return [self.gamma, self.beta]
+
+ torch.manual_seed(1337)
+ module = LayerNorm1d(100)
+ x = torch.randn(32, 100)  # batch size 32 of 100-dimensional vectors
+ x = module(x)
+ x.shape
+
+ x[:,0].mean(), x[:,0].std()  # mean,std of one feature across all batch inputs
+
+ x[0,:].mean(), x[0,:].std()  # mean,std of a single input from the batch, of its features
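For comparison, PyTorch's built-in nn.LayerNorm normalizes over the last dimension in the same way; the two differ only slightly because torch.var above defaults to the unbiased estimator while nn.LayerNorm uses the biased one. A quick check (illustrative):

ln = nn.LayerNorm(100)
x2 = torch.randn(32, 100)
print((ln(x2) - LayerNorm1d(100)(x2)).abs().max())  # small but nonzero, from the differing variance estimators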
+
+ # French to English translation example:
+
+ # <--------- ENCODE ------------------><--------------- DECODE ----------------->
+ # les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>
+
+ """### Full finished code, for reference
+
+ You may want to refer directly to the git repo instead, though.
+ """
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ # hyperparameters
+ batch_size = 16  # how many independent sequences will we process in parallel?
+ block_size = 32  # what is the maximum context length for predictions?
+ max_iters = 5000
+ eval_interval = 100
+ learning_rate = 1e-3
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ eval_iters = 200
+ n_embd = 64
+ n_head = 4
+ n_layer = 4
+ dropout = 0.0
+ # ------------
+
+ torch.manual_seed(1337)
+
+ # wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     text = f.read()
+
+ # here are all the unique characters that occur in this text
+ chars = sorted(list(set(text)))
+ vocab_size = len(chars)
+ # create a mapping from characters to integers
+ stoi = { ch:i for i,ch in enumerate(chars) }
+ itos = { i:ch for i,ch in enumerate(chars) }
+ encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
+ decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string
+
+ # Train and test splits
+ data = torch.tensor(encode(text), dtype=torch.long)
+ n = int(0.9*len(data))  # first 90% will be train, rest val
+ train_data = data[:n]
+ val_data = data[n:]
+
+ # data loading
+ def get_batch(split):
+     # generate a small batch of data of inputs x and targets y
+     data = train_data if split == 'train' else val_data
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     x = torch.stack([data[i:i+block_size] for i in ix])
+     y = torch.stack([data[i+1:i+block_size+1] for i in ix])
+     x, y = x.to(device), y.to(device)
+     return x, y
+
+ @torch.no_grad()
+ def estimate_loss():
+     # average the loss over eval_iters batches for a less noisy estimate
+     out = {}
+     model.eval()
+     for split in ['train', 'val']:
+         losses = torch.zeros(eval_iters)
+         for k in range(eval_iters):
+             X, Y = get_batch(split)
+             logits, loss = model(X, Y)
+             losses[k] = loss.item()
+         out[split] = losses.mean()
+     model.train()
+     return out
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         B,T,C = x.shape
+         k = self.key(x)    # (B,T,C)
+         q = self.query(x)  # (B,T,C)
+         # compute attention scores ("affinities")
+         wei = q @ k.transpose(-2,-1) * C**-0.5  # (B, T, C) @ (B, C, T) -> (B, T, T)
+         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
+         wei = F.softmax(wei, dim=-1)  # (B, T, T)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x)  # (B,T,C)
+         out = wei @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)
+         return out
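On PyTorch 2.1 or newer, the mask/softmax/dropout/matmul sequence in Head.forward can be fused into one call to F.scaled_dot_product_attention. A sketch with illustrative tensors; the explicit scale matches this code's C**-0.5, since the built-in defaults to 1/sqrt(head_size):

qs, ks, vs = (torch.randn(2, block_size, 16) for _ in range(3))
out = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True, scale=n_embd**-0.5)
print(out.shape)  # torch.Size([2, 32, 16])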
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedForward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head):
+         # n_embd: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedForward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         # pre-norm residual connections: normalize, transform, then add back
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
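Because both sub-layers are added back residually, a Block must map (B, T, n_embd) to a tensor of the same shape; a quick check (illustrative):

blk = Block(n_embd, n_head)
print(blk(torch.randn(2, block_size, n_embd)).shape)  # torch.Size([2, 32, 64]), shape preserved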
+
+ # super simple bigram model
+ class BigramLanguageModel(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
+         self.lm_head = nn.Linear(n_embd, vocab_size)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+
+         # idx and targets are both (B,T) tensors of integers
+         tok_emb = self.token_embedding_table(idx)  # (B,T,C)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
+         x = tok_emb + pos_emb  # (B,T,C)
+         x = self.blocks(x)  # (B,T,C)
+         x = self.ln_f(x)  # (B,T,C)
+         logits = self.lm_head(x)  # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop idx to the last block_size tokens
+             idx_cond = idx[:, -block_size:]
+             # get the predictions
+             logits, loss = self(idx_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
+
+ model = BigramLanguageModel()
+ m = model.to(device)
+ # print the number of parameters in the model
+ print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
+ # torch.save(model, 'transformer_model.pth')
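The printed count (about 0.21 M) can be verified by hand from the hyperparameters above; a sketch of the arithmetic, assuming the 65-character Shakespeare vocabulary:

emb = vocab_size*n_embd + block_size*n_embd       # token + position tables: 65*64 + 32*64 = 6208
head = 3 * n_embd * (n_embd // n_head)            # k,q,v projections per head (no bias): 3072
block = (n_head*head                              # 4 heads: 12288
         + n_embd*n_embd + n_embd                 # output projection: 4160
         + n_embd*4*n_embd + 4*n_embd             # ffwd expansion: 16640
         + 4*n_embd*n_embd + n_embd               # ffwd contraction: 16448
         + 2*2*n_embd)                            # ln1 + ln2: 256
tail = 2*n_embd + n_embd*vocab_size + vocab_size  # final LayerNorm + lm_head: 4353
print(emb + n_layer*block + tail)                 # 209729, i.e. ~0.209729 M parameters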
+
+ # create a PyTorch optimizer
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+ for iter in range(max_iters):
+
+     # every once in a while evaluate the loss on train and val sets
+     if iter % eval_interval == 0 or iter == max_iters - 1:
+         losses = estimate_loss()
+         print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+
+     # sample a batch of data
+     xb, yb = get_batch('train')
+
+     # evaluate the loss
+     logits, loss = model(xb, yb)
+     optimizer.zero_grad(set_to_none=True)
+     loss.backward()
+     optimizer.step()
+
+ # Save/load the trained weights (kept commented out):
+ # torch.save(model.state_dict(), 'transformer_weights.pth')
+ # model.load_state_dict(torch.load('transformer_weights.pth'))
+ # print("Model weights saved successfully.")
+
+ # Or save/load the entire model object (kept commented out):
+ # model = torch.load('transformer_model.pth')
+ # model.eval()  # Set the model to evaluation mode
+ # print("Entire model loaded successfully.")
+
+ # generate from the model
+ context = torch.zeros((1, 1), dtype=torch.long, device=device)
+ print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))