ibrahimmkhalid commited on
Commit
5e3f56c
·
1 Parent(s): ca18dfd

use pytorch serialization instead of pickle

Browse files
Files changed (3) hide show
  1. app.py +2 -3
  2. model.pkl → model.pt +2 -2
  3. train_gpt_openwebtext.py +3 -4
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
5
- import pickle
6
  import os
7
 
8
  st.title('LLM from scratch Demo')
@@ -169,11 +168,11 @@ encode = lambda s: [string_to_int[ch] for ch in s]
169
  decode = lambda x: ''.join([int_to_string[i] for i in x])
170
 
171
 
172
- model_pickle_path = './model.pkl'
173
 
174
  st.write('loading model parameters...')
175
  with open(model_pickle_path, 'rb') as f:
176
- model = pickle.load(f)
177
  st.write('model loaded successfully!')
178
 
179
  prompt = ''
 
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
 
5
  import os
6
 
7
  st.title('LLM from scratch Demo')
 
168
  decode = lambda x: ''.join([int_to_string[i] for i in x])
169
 
170
 
171
+ model_pickle_path = './model.pt'
172
 
173
  st.write('loading model parameters...')
174
  with open(model_pickle_path, 'rb') as f:
175
+ model = torch.load(f, map_location=device)
176
  st.write('model loaded successfully!')
177
 
178
  prompt = ''
model.pkl → model.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:59989e9551bb95c5c24630505acca58e99a9608218081b2fbea732f536090517
3
- size 160269240
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04e95f8e46dd7b7b894d288f3c2b75bb0a535fb266960803587a9f552e6b5a73
3
+ size 160274578
train_gpt_openwebtext.py CHANGED
@@ -3,7 +3,6 @@ import torch.nn as nn
3
  from torch.nn import functional as F
4
  import mmap
5
  import random
6
- import pickle
7
  import os
8
 
9
 
@@ -218,11 +217,11 @@ class GPTLanguageModel(nn.Module):
218
 
219
  model = GPTLanguageModel(vocab_size).to(device)
220
 
221
- model_pickle_path = './model.pkl'
222
  if os.path.exists(model_pickle_path):
223
  print('loading model parameters...')
224
  with open(model_pickle_path, 'rb') as f:
225
- model = pickle.load(f)
226
  print('loaded successfully!')
227
  # create a PyTorch optimizer
228
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
@@ -243,5 +242,5 @@ for iter in range(max_iters):
243
  print(loss.item())
244
 
245
  with open(model_pickle_path, 'wb') as f:
246
- pickle.dump(model, f)
247
  print('model saved')
 
3
  from torch.nn import functional as F
4
  import mmap
5
  import random
 
6
  import os
7
 
8
 
 
217
 
218
  model = GPTLanguageModel(vocab_size).to(device)
219
 
220
+ model_pickle_path = './model.pt'
221
  if os.path.exists(model_pickle_path):
222
  print('loading model parameters...')
223
  with open(model_pickle_path, 'rb') as f:
224
+ model = torch.load(f, map_location=device)
225
  print('loaded successfully!')
226
  # create a PyTorch optimizer
227
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
242
  print(loss.item())
243
 
244
  with open(model_pickle_path, 'wb') as f:
245
+ torch.save(model, f)
246
  print('model saved')