Spaces:
Running
Running
Commit
·
5e3f56c
1
Parent(s):
ca18dfd
use pytorch serialization instead of pickle
Browse files- app.py +2 -3
- model.pkl → model.pt +2 -2
- train_gpt_openwebtext.py +3 -4
app.py
CHANGED
|
@@ -2,7 +2,6 @@ import streamlit as st
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.nn import functional as F
|
| 5 |
-
import pickle
|
| 6 |
import os
|
| 7 |
|
| 8 |
st.title('LLM from scratch Demo')
|
|
@@ -169,11 +168,11 @@ encode = lambda s: [string_to_int[ch] for ch in s]
|
|
| 169 |
decode = lambda x: ''.join([int_to_string[i] for i in x])
|
| 170 |
|
| 171 |
|
| 172 |
-
model_pickle_path = './model.pkl'
|
| 173 |
|
| 174 |
st.write('loading model parameters...')
|
| 175 |
with open(model_pickle_path, 'rb') as f:
|
| 176 |
-
model = pickle.load(f)
|
| 177 |
st.write('model loaded successfully!')
|
| 178 |
|
| 179 |
prompt = ''
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.nn import functional as F
|
|
|
|
| 5 |
import os
|
| 6 |
|
| 7 |
st.title('LLM from scratch Demo')
|
|
|
|
| 168 |
decode = lambda x: ''.join([int_to_string[i] for i in x])
|
| 169 |
|
| 170 |
|
| 171 |
+
model_pickle_path = './model.pt'
|
| 172 |
|
| 173 |
st.write('loading model parameters...')
|
| 174 |
with open(model_pickle_path, 'rb') as f:
|
| 175 |
+
model = torch.load(f, map_location=device)
|
| 176 |
st.write('model loaded successfully!')
|
| 177 |
|
| 178 |
prompt = ''
|
model.pkl → model.pt
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04e95f8e46dd7b7b894d288f3c2b75bb0a535fb266960803587a9f552e6b5a73
|
| 3 |
+
size 160274578
|
train_gpt_openwebtext.py
CHANGED
|
@@ -3,7 +3,6 @@ import torch.nn as nn
|
|
| 3 |
from torch.nn import functional as F
|
| 4 |
import mmap
|
| 5 |
import random
|
| 6 |
-
import pickle
|
| 7 |
import os
|
| 8 |
|
| 9 |
|
|
@@ -218,11 +217,11 @@ class GPTLanguageModel(nn.Module):
|
|
| 218 |
|
| 219 |
model = GPTLanguageModel(vocab_size).to(device)
|
| 220 |
|
| 221 |
-
model_pickle_path = './model.pkl'
|
| 222 |
if os.path.exists(model_pickle_path):
|
| 223 |
print('loading model parameters...')
|
| 224 |
with open(model_pickle_path, 'rb') as f:
|
| 225 |
-
model = pickle.load(f)
|
| 226 |
print('loaded successfully!')
|
| 227 |
# create a PyTorch optimizer
|
| 228 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
|
@@ -243,5 +242,5 @@ for iter in range(max_iters):
|
|
| 243 |
print(loss.item())
|
| 244 |
|
| 245 |
with open(model_pickle_path, 'wb') as f:
|
| 246 |
-
pickle.dump(model, f)
|
| 247 |
print('model saved')
|
|
|
|
| 3 |
from torch.nn import functional as F
|
| 4 |
import mmap
|
| 5 |
import random
|
|
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
|
|
|
|
| 217 |
|
| 218 |
model = GPTLanguageModel(vocab_size).to(device)
|
| 219 |
|
| 220 |
+
model_pickle_path = './model.pt'
|
| 221 |
if os.path.exists(model_pickle_path):
|
| 222 |
print('loading model parameters...')
|
| 223 |
with open(model_pickle_path, 'rb') as f:
|
| 224 |
+
model = torch.load(f, map_location=device)
|
| 225 |
print('loaded successfully!')
|
| 226 |
# create a PyTorch optimizer
|
| 227 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
|
|
|
| 242 |
print(loss.item())
|
| 243 |
|
| 244 |
with open(model_pickle_path, 'wb') as f:
|
| 245 |
+
torch.save(model, f)
|
| 246 |
print('model saved')
|