Spaces:
Running
Running
Commit
·
462d56c
1
Parent(s):
15ba535
update train_gpt_openwebtext script
Browse files- train_gpt_openwebtext.py +0 -31
train_gpt_openwebtext.py
CHANGED
|
@@ -1,16 +1,3 @@
|
|
| 1 |
-
# ---
|
| 2 |
-
# jupyter:
|
| 3 |
-
# jupytext:
|
| 4 |
-
# text_representation:
|
| 5 |
-
# extension: .py
|
| 6 |
-
# format_name: percent
|
| 7 |
-
# format_version: '1.3'
|
| 8 |
-
# jupytext_version: 1.3.4
|
| 9 |
-
# kernelspec:
|
| 10 |
-
# display_name: Python 3
|
| 11 |
-
# language: python
|
| 12 |
-
# name: python3
|
| 13 |
-
# ---
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
| 16 |
from torch.nn import functional as F
|
|
@@ -20,7 +7,6 @@ import pickle
|
|
| 20 |
import os
|
| 21 |
|
| 22 |
|
| 23 |
-
# %%
|
| 24 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
print(device)
|
| 26 |
block_size = 128
|
|
@@ -33,10 +19,8 @@ n_head = 8
|
|
| 33 |
n_layer = 8
|
| 34 |
dropout = 0.2
|
| 35 |
|
| 36 |
-
# %%
|
| 37 |
if not os.path.exists("./openwebtext/vocab.txt") or not os.path.exists("./openwebtext/train_split.txt") or not os.path.exists("./openwebtext/val_split.txt"):
|
| 38 |
raise Exception("Please run extract.py first")
|
| 39 |
-
# %%
|
| 40 |
chars = ""
|
| 41 |
with open("./openwebtext/vocab.txt", 'r', encoding='utf-8') as f:
|
| 42 |
text = f.read()
|
|
@@ -44,17 +28,11 @@ with open("./openwebtext/vocab.txt", 'r', encoding='utf-8') as f:
|
|
| 44 |
|
| 45 |
vocab_size = len(chars)
|
| 46 |
|
| 47 |
-
# %%
|
| 48 |
-
print(f"Vocab size: {vocab_size}")
|
| 49 |
-
print(f"Text length: {len(text)}")
|
| 50 |
-
|
| 51 |
-
# %%
|
| 52 |
string_to_int = {ch: i for i, ch in enumerate(chars)}
|
| 53 |
int_to_string = {i: ch for i, ch in enumerate(chars)}
|
| 54 |
|
| 55 |
encode = lambda s: [string_to_int[ch] for ch in s]
|
| 56 |
decode = lambda x: ''.join([int_to_string[i] for i in x])
|
| 57 |
-
# %%
|
| 58 |
# memory map for using small snippets of text from a single file of any size
|
| 59 |
def get_random_chunk(split):
|
| 60 |
filename = "./openwebtext/train_split.txt" if split == 'train' else "./openwebtext/val_split.txt"
|
|
@@ -85,7 +63,6 @@ def get_batch(split):
|
|
| 85 |
x, y = x.to(device), y.to(device)
|
| 86 |
return x, y
|
| 87 |
|
| 88 |
-
# %%
|
| 89 |
@torch.no_grad()
|
| 90 |
def estimate_loss():
|
| 91 |
out = {}
|
|
@@ -100,7 +77,6 @@ def estimate_loss():
|
|
| 100 |
model.train()
|
| 101 |
return out
|
| 102 |
|
| 103 |
-
# %%
|
| 104 |
|
| 105 |
class Head(nn.Module):
|
| 106 |
""" one head of self-attention """
|
|
@@ -248,7 +224,6 @@ if os.path.exists(model_pickle_path):
|
|
| 248 |
with open(model_pickle_path, 'rb') as f:
|
| 249 |
model = pickle.load(f)
|
| 250 |
print('loaded successfully!')
|
| 251 |
-
# %%
|
| 252 |
# create a PyTorch optimizer
|
| 253 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 254 |
|
|
@@ -270,9 +245,3 @@ print(loss.item())
|
|
| 270 |
with open(model_pickle_path, 'wb') as f:
|
| 271 |
pickle.dump(model, f)
|
| 272 |
print('model saved')
|
| 273 |
-
|
| 274 |
-
# %%
|
| 275 |
-
prompt = 'Hello! Can you see me?'
|
| 276 |
-
context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
|
| 277 |
-
generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())
|
| 278 |
-
print(generated_chars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from torch.nn import functional as F
|
|
|
|
| 7 |
import os
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
print(device)
|
| 12 |
block_size = 128
|
|
|
|
| 19 |
n_layer = 8
|
| 20 |
dropout = 0.2
|
| 21 |
|
|
|
|
| 22 |
if not os.path.exists("./openwebtext/vocab.txt") or not os.path.exists("./openwebtext/train_split.txt") or not os.path.exists("./openwebtext/val_split.txt"):
|
| 23 |
raise Exception("Please run extract.py first")
|
|
|
|
| 24 |
chars = ""
|
| 25 |
with open("./openwebtext/vocab.txt", 'r', encoding='utf-8') as f:
|
| 26 |
text = f.read()
|
|
|
|
| 28 |
|
| 29 |
vocab_size = len(chars)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
string_to_int = {ch: i for i, ch in enumerate(chars)}
|
| 32 |
int_to_string = {i: ch for i, ch in enumerate(chars)}
|
| 33 |
|
| 34 |
encode = lambda s: [string_to_int[ch] for ch in s]
|
| 35 |
decode = lambda x: ''.join([int_to_string[i] for i in x])
|
|
|
|
| 36 |
# memory map for using small snippets of text from a single file of any size
|
| 37 |
def get_random_chunk(split):
|
| 38 |
filename = "./openwebtext/train_split.txt" if split == 'train' else "./openwebtext/val_split.txt"
|
|
|
|
| 63 |
x, y = x.to(device), y.to(device)
|
| 64 |
return x, y
|
| 65 |
|
|
|
|
| 66 |
@torch.no_grad()
|
| 67 |
def estimate_loss():
|
| 68 |
out = {}
|
|
|
|
| 77 |
model.train()
|
| 78 |
return out
|
| 79 |
|
|
|
|
| 80 |
|
| 81 |
class Head(nn.Module):
|
| 82 |
""" one head of self-attention """
|
|
|
|
| 224 |
with open(model_pickle_path, 'rb') as f:
|
| 225 |
model = pickle.load(f)
|
| 226 |
print('loaded successfully!')
|
|
|
|
| 227 |
# create a PyTorch optimizer
|
| 228 |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 229 |
|
|
|
|
| 245 |
with open(model_pickle_path, 'wb') as f:
|
| 246 |
pickle.dump(model, f)
|
| 247 |
print('model saved')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|