import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
batch_size = 64
max_len = 256
d_model = 384
n_layer = 6
n_head = 6
d_q = d_model // n_head  # per-head dimension (384 // 6 = 64)
dropout = 0.2
vocab_size = 65  # placeholder; recomputed from the actual text below
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
eval_iters = 200
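# max_len is the context window: the model never attends further back than 256
# characters. eval_interval and eval_iters control how often and how thoroughly
# the loss is estimated during training.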
"""
---- Device ----
"""
if torch.cuda.is_available():
device = torch.device('cuda')
print("Using CUDA (GPU)")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device('mps')
print("Using MPS (Apple Silicon GPU)")
else:
device = torch.device('cpu')
print("Using device's CPU")
"""
--- WandB Integration ---
"""
wandb.init(
    project="nano-model-shakesphere-training",
    config={
        "learning_rate": learning_rate,
        "architecture": "decoder-only-model",
        "dataset": "tinyshakesphere",
        "d_model": d_model,
        "n_layer": n_layer,
        "n_head": n_head,
        "max_iters": max_iters,
        "dropout": dropout
    }
)
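# Note: wandb.init assumes you are already authenticated, e.g. via `wandb login`
# on the CLI or the WANDB_API_KEY environment variable.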
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))  # all unique characters in the text
vocab_size = len(chars)  # 65 distinct characters in this corpus
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character
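# For this corpus the newline character sorts first, so stoi['\n'] == 0 and
# itos[0] == '\n'; the two dicts are exact inverses of each other.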
# Take a string and return the list of its characters' integer indices
def encoder(s):
    return [stoi[c] for c in s]

# Take a list of indices and return the corresponding string
def decoder(l):
    return ''.join(itos[i] for i in l)
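# Sanity check (a minimal sketch; the round trip only holds for characters
# that actually occur in input.txt):
#   assert decoder(encoder("To be, or not to be")) == "To be, or not to be"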
data = torch.tensor(encoder(text), dtype=torch.long)  # 1-D tensor with one id per character, length == len(text)
n = int(0.9 * len(data))
train_data = data[:n] # 90% of text
val_data = data[n:] # 10% of text
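# For the standard Tiny Shakespeare input.txt (roughly 1.1M characters), this
# leaves about 1M characters for training and 0.1M for validation.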
def get_batch(split):
    data = train_data if split.lower() == 'train' else val_data
    # batch_size random starting offsets in [0, len(data) - max_len)
    ix = torch.randint(len(data) - max_len, (batch_size,))
    # Slice max_len ids from each offset and stack -> shape [batch_size, max_len]
    x = torch.stack([data[i:i+max_len] for i in ix])
    # Targets are the same windows shifted one position right, so each position's
    # prediction can be scored against the true next character with cross-entropy
    y = torch.stack([data[i+1:i+max_len+1] for i in ix])
    return x.to(device), y.to(device)
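# With the settings above, x and y are both LongTensors of shape [64, 256],
# and y[b, t] is the character that follows x[b, t] in the original text.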
"""
--- Model Training ---
"""
if __name__ == "__main__":
from model import Model
model = Model()
m = model.to(device)
optimizer = optim.AdamW(
model.parameters(),
lr=learning_rate
)
    @torch.no_grad()
    def estimate_loss():
        out = {}
        model.eval()  # disable dropout for evaluation
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()  # back to training mode
        return out
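    # estimate_loss averages over eval_iters random batches per split, which
    # gives a far lower-variance estimate than the loss of any single batch.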
    for iter in range(max_iters):
        # Periodically report smoothed train/val loss
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            wandb.log({
                "iter": iter,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": learning_rate
            })
        iter_start = time.time()
        xb, yb = get_batch("train")
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)  # clear stale gradients from the previous batch; set_to_none=True also frees their memory
        loss.backward()   # backward pass: compute the gradient of the loss w.r.t. every parameter
        optimizer.step()  # nudge the parameters along those gradients (the learning rate scales the step size)
        iter_time = time.time() - iter_start
        print(f"Iteration {iter} completed in {iter_time:.2f} seconds")
        wandb.log({"iter_time": iter_time})
    wandb.finish()
    print("Training finished. Saving model state...")
    torch.save(model.state_dict(), 'nanogpt_model.pth')
    print("Model saved to nanogpt_model.pth")