Kartheekb7 committed on
Commit
c818c76
·
verified ·
1 Parent(s): 204eaac

Update gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +1 -28
gpt.py CHANGED
@@ -18,7 +18,6 @@ dropout = 0.2
18
 
19
  torch.manual_seed(1337)
20
 
21
- # wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
22
  with open('input.txt', 'r', encoding='utf-8') as f:
23
  text = f.read()
24
 
@@ -196,30 +195,4 @@ class GPTLanguageModel(nn.Module):
196
  return idx
197
 
198
  model = GPTLanguageModel()
199
- m = model.to(device)
200
- # print the number of parameters in the model
201
- print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
202
-
203
- # create a PyTorch optimizer
204
- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
205
-
206
- for iter in range(max_iters):
207
-
208
- # every once in a while evaluate the loss on train and val sets
209
- if iter % eval_interval == 0 or iter == max_iters - 1:
210
- losses = estimate_loss()
211
- print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
212
-
213
- # sample a batch of data
214
- xb, yb = get_batch('train')
215
-
216
- # evaluate the loss
217
- logits, loss = model(xb, yb)
218
- optimizer.zero_grad(set_to_none=True)
219
- loss.backward()
220
- optimizer.step()
221
-
222
- # generate from the model
223
- context = torch.zeros((1, 1), dtype=torch.long, device=device)
224
- print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
225
- #open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))
 
18
 
19
  torch.manual_seed(1337)
20
 
 
21
  with open('input.txt', 'r', encoding='utf-8') as f:
22
  text = f.read()
23
 
 
195
  return idx
196
 
197
  model = GPTLanguageModel()
198
+ model = model.to(device)