File size: 1,545 Bytes
79078fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import random

import torch

from .data import build_training_text
from .model import SmallGPTModel
from .tokenizer import WordTokenizer


def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's global RNG with ``seed``.

    ``random`` is seeded first, then ``torch.manual_seed``; nothing is returned.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)


def create_model_and_tokenizer(config, extra_text=""):
    """Build the training corpus, fit a tokenizer on it, and create a fresh model.

    Args:
        config: object providing ``block_size``, ``d_model``, ``n_heads``,
            ``n_layers`` and ``dropout``, which are forwarded to the model.
        extra_text: additional text passed through to ``build_training_text``.

    Returns:
        ``(model, tokenizer, encoded)``: the ``SmallGPTModel`` sized to the
        fitted tokenizer's vocabulary, the fitted ``WordTokenizer``, and the
        whole corpus encoded (with ``add_bos``/``add_eos`` enabled) as a
        ``torch.long`` tensor. Presumably ``encode`` returns a flat list of
        ids, making the tensor 1-D; confirm against the tokenizer.
    """
    corpus = build_training_text(extra_text)
    tokenizer = WordTokenizer().fit(corpus)
    token_ids = torch.tensor(
        tokenizer.encode(corpus, add_bos=True, add_eos=True), dtype=torch.long
    )

    # The vocabulary size comes from the fitted tokenizer; every other
    # architecture hyperparameter is copied straight from the config.
    model_kwargs = {"vocab_size": tokenizer.vocab_size}
    for name in ("block_size", "d_model", "n_heads", "n_layers", "dropout"):
        model_kwargs[name] = getattr(config, name)
    model = SmallGPTModel(**model_kwargs)

    return model, tokenizer, token_ids


def build_batch(encoded, block_size, batch_size):
    """Sample a random batch of (input, next-token target) windows from ``encoded``.

    Each example is a contiguous window of ``block_size`` tokens. Its target is
    the same window shifted one position to the right, so ``y[i, t]`` is the
    token that follows ``x[i, t]`` in ``encoded``. Window starts are drawn
    uniformly from every position that still leaves room for a full target.

    Args:
        encoded: sequence of token ids (assumed to be a 1-D tensor).
        block_size: number of tokens per example.
        batch_size: number of examples to sample.

    Returns:
        ``(x, y)``: two stacked tensors, each of shape
        ``(batch_size, block_size)`` for a 1-D ``encoded``.

    Raises:
        ValueError: if ``encoded`` has fewer than ``block_size + 1`` tokens,
            because then no complete input/target window exists.
    """
    # A window starting at ``start`` reads its target up to index
    # ``start + block_size`` inclusive, so valid starts are
    # 0 .. len - block_size - 1. torch.randint's upper bound is exclusive,
    # hence ``len(encoded) - block_size`` (not ``- 1``), so the last
    # valid window is also sampled.
    num_starts = len(encoded) - block_size
    if num_starts < 1:
        raise ValueError(
            f"need at least block_size + 1 = {block_size + 1} tokens to build "
            f"a batch, got {len(encoded)}"
        )
    starts = torch.randint(0, num_starts, (batch_size,))
    x = torch.stack([encoded[start : start + block_size] for start in starts])
    y = torch.stack([encoded[start + 1 : start + block_size + 1] for start in starts])
    return x, y


def train_model(model, encoded, config, steps):
    """Train ``model`` in place with AdamW for ``steps`` random mini-batches.

    Puts the model in training mode. Each step then samples a batch with
    ``build_batch``, runs a forward pass and takes one optimizer step.

    Args:
        model: module called as ``model(x, targets=y)``. It must return a
            2-tuple whose second element is the scalar loss; the first element
            is ignored here.
        encoded: token-id sequence forwarded to ``build_batch``.
        config: object providing ``learning_rate``, ``block_size`` and
            ``batch_size``.
        steps: number of optimization steps to run.

    Returns:
        list of per-step training losses as Python floats, in step order.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    model.train()

    def _run_step():
        # One optimization step on a freshly sampled batch; returns its loss.
        inputs, targets = build_batch(encoded, config.block_size, config.batch_size)
        _, batch_loss = model(inputs, targets=targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        return float(batch_loss.item())

    return [_run_step() for _ in range(steps)]