# NOTE: Hugging Face Spaces viewer residue (status lines, file size, commit
# hashes, and column of line numbers) removed — it was never part of the code.
import torch
import mmap
import random
import os
from GPTLanguageModelClass import *
# Mirror the shared hyperparameter bundle into module-level names so the
# rest of the script can refer to them directly.
(
    block_size,
    batch_size,
    max_iters,
    learning_rate,
    eval_every,
    n_embd,
    n_head,
    n_layer,
    dropout,
    device,
) = (
    hyperparams.block_size,
    hyperparams.batch_size,
    hyperparams.max_iters,
    hyperparams.learning_rate,
    hyperparams.eval_every,
    hyperparams.n_embd,
    hyperparams.n_head,
    hyperparams.n_layer,
    hyperparams.dropout,
    hyperparams.device,
)
print(device)
# Fail fast, and say exactly which prerequisite file is missing, if the
# preprocessing step (extract.py) has not been run yet.
_required_files = (
    "./vocab.txt",
    "./openwebtext/train_split.txt",
    "./openwebtext/val_split.txt",
)
_missing = [p for p in _required_files if not os.path.exists(p)]
if _missing:
    # FileNotFoundError instead of bare Exception: more precise, and still a
    # subclass of Exception, so any existing handler keeps working.
    raise FileNotFoundError(
        f"Please run extract.py first (missing: {', '.join(_missing)})"
    )
# Build the character-level vocabulary and codec from the extracted vocab file.
# (The previous `chars = ""` pre-assignment was dead code and is removed.)
with open("./vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(set(text))  # sorted() accepts any iterable; list() was redundant
vocab_size = len(chars)
# Bidirectional lookup tables between characters and integer token ids.
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer token ids."""
    return [string_to_int[ch] for ch in s]


def decode(x):
    """Map an iterable of integer token ids back to a string."""
    return "".join(int_to_string[i] for i in x)
# memory map for using small snippets of text from a single file of any size
def get_random_chunk(split):
    """Sample a random chunk of a split file and return it as encoded tokens.

    Args:
        split: "train" selects the train split file; anything else selects val.

    Returns:
        A 1-D ``torch.long`` tensor of token ids, read from a uniformly random
        offset in the file via mmap (so arbitrarily large files work).
    """
    filename = (
        "./openwebtext/train_split.txt"
        if split == "train"
        else "./openwebtext/val_split.txt"
    )
    # One byte beyond block_size * batch_size so get_batch can always form
    # (input, target) pairs, which need block_size + 1 consecutive tokens;
    # the previous "- 1" read was one short at the chunk tail.
    span = block_size * batch_size + 1
    with open(filename, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            # Determine the file size and a random position to start reading.
            file_size = len(mm)
            # max(0, ...) guards files smaller than one chunk, where
            # random.randint(0, negative) would raise ValueError.
            start_pos = random.randint(0, max(0, file_size - span))
            # Seek to the random position and read the block of text.
            mm.seek(start_pos)
            block = mm.read(span)
            # errors="ignore" drops a multi-byte UTF-8 sequence cut at the
            # chunk boundary; stripping "\r" normalizes Windows line endings.
            decoded_block = block.decode("utf-8", errors="ignore").replace("\r", "")
    data = torch.tensor(encode(decoded_block), dtype=torch.long)
    return data
def get_batch(split):
    """Draw a random (inputs, targets) batch from a freshly sampled chunk.

    Targets are the inputs shifted one position to the right, as usual for
    next-token language modeling. Both tensors are moved to ``device``.
    """
    chunk = get_random_chunk(split)
    # Random start offsets, each leaving room for a full block plus its target.
    offsets = torch.randint(len(chunk) - block_size, (batch_size,))
    inputs = torch.stack([chunk[o : o + block_size] for o in offsets])
    targets = torch.stack([chunk[o + 1 : o + 1 + block_size] for o in offsets])
    return inputs.to(device), targets.to(device)
@torch.no_grad()
def estimate_loss():
    """Average the loss over ``eval_every`` batches for each split.

    Temporarily switches the model to eval mode, then restores train mode.
    Returns a dict mapping "train" and "val" to mean-loss tensors.
    """
    model.eval()
    results = {}
    for split in ("train", "val"):
        batch_losses = torch.zeros(eval_every)
        for idx in range(eval_every):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb, yb)
            batch_losses[idx] = batch_loss.item()
        results[split] = batch_losses.mean()
    model.train()
    return results
# Instantiate a fresh model, then overwrite it with a checkpoint if one exists.
model = GPTLanguageModel(vocab_size).to(device)
model_pickle_path = "./model.pt"
if os.path.exists(model_pickle_path):
    print("loading model parameters...")
    with open(model_pickle_path, "rb") as f:
        # NOTE(review): torch.load here unpickles the ENTIRE model object,
        # which executes arbitrary code if model.pt is untrusted — prefer
        # saving/loading a state_dict. map_location already places tensors
        # on `device`, so no extra .to(device) is needed after loading.
        model = torch.load(f, map_location=device)
    print("loaded successfully!")
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Main training loop: periodically report estimated train/val loss, then take
# one optimization step on a freshly sampled batch.
for step in range(max_iters):  # renamed from `iter`, which shadowed the builtin
    if step % eval_every == 0:
        losses = estimate_loss()
        print(
            f"step: {step}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}"
        )
    # sample a batch of data
    xb, yb = get_batch("train")
    # evaluate the loss — call the model (not .forward directly) so that
    # nn.Module hooks run, consistent with estimate_loss above
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

# Persist the trained model (full-object pickle; matches the loading code above).
with open(model_pickle_path, "wb") as f:
    torch.save(model, f)
print("model saved")