"""Helpers for seeding, building, batching and training the small GPT model."""

import random

import torch

from .data import build_training_text
from .model import SmallGPTModel
from .tokenizer import WordTokenizer


def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch with the same value."""
    random.seed(seed)
    torch.manual_seed(seed)


def create_model_and_tokenizer(config, extra_text=""):
    """Build the training corpus, fit a tokenizer on it and create a model.

    Args:
        config: Object exposing ``block_size``, ``d_model``, ``n_heads``,
            ``n_layers`` and ``dropout``; these are forwarded to
            ``SmallGPTModel``.
        extra_text: Additional text passed to ``build_training_text``.

    Returns:
        A ``(model, tokenizer, encoded)`` tuple, where ``encoded`` is the
        corpus encoded with BOS/EOS markers as a ``torch.long`` tensor.
    """
    text = build_training_text(extra_text)
    tokenizer = WordTokenizer().fit(text)

    token_ids = tokenizer.encode(text, add_bos=True, add_eos=True)
    encoded = torch.tensor(token_ids, dtype=torch.long)

    model = SmallGPTModel(
        vocab_size=tokenizer.vocab_size,
        block_size=config.block_size,
        d_model=config.d_model,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        dropout=config.dropout,
    )
    return model, tokenizer, encoded


def build_batch(encoded, block_size, batch_size):
    """Sample a random batch of input/target windows from ``encoded``.

    Each target window is the matching input window shifted one position to
    the right.

    Args:
        encoded: Tensor of token ids, indexed along its first dimension
            (presumably 1-D, as produced by ``create_model_and_tokenizer``).
        block_size: Length of every sampled window.
        batch_size: Number of windows to sample.

    Returns:
        A ``(x, y)`` tuple of stacked windows; ``y[i]`` is ``x[i]`` shifted
        by one token.

    Raises:
        ValueError: If ``encoded`` holds fewer than ``block_size + 1``
            tokens, so no complete input/target pair exists.
    """
    # A window starting at s needs tokens s .. s + block_size (inclusive) for
    # the shifted target, so valid starts are 0 .. len(encoded) - block_size - 1.
    # torch.randint's upper bound is exclusive, hence no extra "- 1" here.
    num_starts = len(encoded) - block_size
    if num_starts < 1:
        raise ValueError(
            f"need at least block_size + 1 = {block_size + 1} tokens to build "
            f"a batch, got {len(encoded)}"
        )

    # Plain ints keep the slicing below independent of tensor-index semantics.
    starts = torch.randint(0, num_starts, (batch_size,)).tolist()
    x = torch.stack([encoded[s : s + block_size] for s in starts])
    y = torch.stack([encoded[s + 1 : s + block_size + 1] for s in starts])
    return x, y


def train_model(model, encoded, config, steps):
    """Run ``steps`` AdamW updates on ``model`` and return the loss history.

    Args:
        model: Module called as ``model(xb, targets=yb)``; the second element
            of its return value is used as the loss.
        encoded: Token-id tensor that batches are sampled from.
        config: Object exposing ``learning_rate``, ``block_size`` and
            ``batch_size``.
        steps: Number of optimisation steps to perform.

    Returns:
        A list with one Python float loss value per step.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    model.train()

    losses = []
    for _ in range(steps):
        xb, yb = build_batch(encoded, config.block_size, config.batch_size)
        _, loss = model(xb, targets=yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss.item()))
    return losses