| """ |
| Trains a GPT to add n-digit numbers. |
| """ |
|
|
| import os |
| import sys |
| import json |
|
|
| import torch |
| from torch.utils.data import Dataset |
| from torch.utils.data.dataloader import DataLoader |
|
|
| from mingpt.model import GPT |
| from mingpt.trainer import Trainer |
| from mingpt.utils import set_seed, setup_logging, CfgNode as CN |
|
|
| |
|
|
def get_config():
    """
    Return the default configuration for the addition experiment.

    The returned CfgNode has four sections, in this order:
      - system:  random seed and output directory
      - data:    AdditionDataset defaults
      - model:   GPT defaults, switched to the 'gpt-nano' preset
      - trainer: Trainer defaults with a raised learning rate
    """
    cfg = CN()

    system = CN()
    system.seed = 3407
    system.work_dir = './out/adder'

    data = AdditionDataset.get_default_config()

    model = GPT.get_default_config()
    model.model_type = 'gpt-nano'

    trainer = Trainer.get_default_config()
    # the nano model is tiny, so a somewhat higher learning rate is used
    trainer.learning_rate = 5e-4

    # attach sections in a fixed order so printing the config is stable
    cfg.system = system
    cfg.data = data
    cfg.model = model
    cfg.trainer = trainer
    return cfg
|
|
| |
|
|
class AdditionDataset(Dataset):
    """
    Creates n-digit addition problems. For example, if n=2, then an example
    addition problem would be to add 85 + 50 = 135. This problem would be
    represented as the following string for the GPT:

    "8550531"

    This is because:
    - we are discarding the + and =, which are not necessary. We just encode the digits
      of the input numbers concatenated together.
    - the result 135 is encoded backwards to make the addition easier to learn for the
      GPT model, because of how the addition algorithm works.

    As one more example, the problem 6 + 39 = 45 would be encoded as:

    "0639540"

    (the operands are padded to "06" and "39", and the sum is padded to "045" and
    then reversed to "540"). We pad with zeros to make sure that we always
    produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
    At test time, we will feed in an addition problem by giving the first 2n digits,
    and hoping that the GPT model completes the sequence with the next (n+1) digits
    correctly.
    """

    @staticmethod
    def get_default_config():
        """Return a CfgNode with the dataset defaults (ndigit=2)."""
        C = CN()
        C.ndigit = 2
        return C

    def __init__(self, config, split):
        """
        Args:
            config: node exposing ``ndigit`` (number of digits per operand, at most 3).
            split: 'test' selects the held-out problems; any other value selects
                the training problems.
        """
        self.config = config
        self.split = split

        # Enumerate every (a, b) pair as a single index in [0, (10**ndigit)**2)
        # and partition those indices into train/test.
        ndigit = self.config.ndigit
        assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support"
        num = (10**ndigit)**2  # number of distinct (a, b) problems
        # A fixed seed makes the permutation, and therefore the train/test
        # partition, identical across the 'train' and 'test' instances.
        rng = torch.Generator()
        rng.manual_seed(1337)
        perm = torch.randperm(num, generator=rng)
        num_test = min(int(num*0.2), 500)  # 20% of all problems, capped at 500
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def get_vocab_size(self):
        """Tokens are the decimal digits 0-9."""
        return 10

    def get_block_size(self):
        """
        Length of each input sequence: the rendered problem has
        ndigit + ndigit + (ndigit + 1) digits, and the input drops the last one
        (it is only ever a target, never fed back in).
        """
        return 3*self.config.ndigit + 1 - 1

    def __len__(self):
        return self.ixes.nelement()

    def __getitem__(self, idx):
        """
        Return (x, y) long tensors of length 3*ndigit for the idx-th problem.

        x is the rendered digit string without its last digit. y is the same
        string shifted left by one, with the first 2*ndigit-1 entries set to -1
        because they would only predict digits of the given operands.
        """
        ndigit = self.config.ndigit
        # decode the flat problem index back into the two operands
        problem = self.ixes[idx].item()
        a, b = divmod(problem, 10**ndigit)
        c = a + b
        # zero-pad both operands to ndigit digits and the sum to ndigit+1
        # digits; the sum is written least-significant digit first
        render = f'{a:0{ndigit}d}{b:0{ndigit}d}' + f'{c:0{ndigit + 1}d}'[::-1]
        dix = [int(s) for s in render]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        # only the answer positions are trained on; presumably the model's loss
        # treats -1 as an ignore index (confirm in mingpt.model)
        y[:ndigit*2-1] = -1
        return x, y
|
|
| |
|
|
if __name__ == '__main__':

    # configuration: defaults, overridden by command-line arguments
    config = get_config()
    config.merge_from_args(sys.argv[1:])
    print(config)
    setup_logging(config)
    set_seed(config.system.seed)

    # datasets (the two splits are disjoint partitions of all problems)
    train_dataset, test_dataset = (
        AdditionDataset(config.data, split=name) for name in ('train', 'test')
    )

    # model sized to the dataset's vocabulary and sequence length
    config.model.vocab_size = train_dataset.get_vocab_size()
    config.model.block_size = train_dataset.get_block_size()
    model = GPT(config.model)

    trainer = Trainer(config.trainer, model, train_dataset)

    def eval_split(trainer, split, max_batches=None):
        """
        Greedily complete every problem in ``split`` and print the accuracy.

        Processes batches of 100 problems, stopping after ``max_batches``
        batches when it is given. Prints up to 5 wrong answers. Returns the
        number of correct answers as a 0-d float tensor.
        """
        dataset = dict(train=train_dataset, test=test_dataset)[split]
        n = config.data.ndigit
        device = trainer.device
        outcomes = []
        shown = 0
        # decimal place values, most significant first: [10**n, ..., 10, 1]
        place = torch.tensor([[10**p for p in reversed(range(n + 1))]]).to(device)
        loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
        for batch_no, (inputs, _targets) in enumerate(loader, start=1):
            inputs = inputs.to(device)
            # the prompt is the 2n operand digits; the model must produce n+1 more
            prompt = inputs[:, :n * 2]
            completed = model.generate(prompt, n + 1, do_sample=False)
            # the answer is generated least-significant digit first, so flip it
            answer_digits = completed[:, -(n + 1):].flip(1)
            lhs = (prompt[:, :n] * place[:, 1:]).sum(1)
            rhs = (prompt[:, n:n * 2] * place[:, 1:]).sum(1)
            predicted = (answer_digits * place).sum(1)
            expected = lhs + rhs
            hits = (predicted == expected).cpu()
            for i in range(inputs.size(0)):
                ok = hits[i]
                outcomes.append(int(ok))
                if not ok and shown < 5:
                    shown += 1
                    print("GPT claims that %d + %d = %d but gt is %d" % (lhs[i], rhs[i], predicted[i], expected[i]))
            if max_batches is not None and batch_no >= max_batches:
                break
        scores = torch.tensor(outcomes, dtype=torch.float)
        total = scores.sum()
        print("%s final score: %d/%d = %.2f%% correct" % (split, total, len(outcomes), 100*scores.mean()))
        return total

    top_score = 0

    def batch_end_callback(trainer):
        """
        Log progress every 10 iterations. Every 500 iterations, evaluate both
        splits and save a checkpoint when the combined number of correct answers
        beats the best seen so far.
        """
        global top_score

        if trainer.iter_num % 10 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 != 0:
            return

        # with 3-digit operands, only score the first 5 training batches
        train_cap = {1: None, 2: None, 3: 5}[config.data.ndigit]
        model.eval()
        with torch.no_grad():
            train_score = eval_split(trainer, 'train', max_batches=train_cap)
            test_score = eval_split(trainer, 'test', max_batches=None)
        score = train_score + test_score

        if score > top_score:
            top_score = score
            print(f"saving model with new top score of {score}")
            torch.save(model.state_dict(), os.path.join(config.system.work_dir, "model.pt"))

        # restore training mode before handing control back to the trainer
        model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    trainer.run()
|
|