#!/usr/bin/env python3
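# Train a small GPT-style language model on a plain-text corpus and sample from it.
#
# Example invocations (assuming this file is saved as train.py; global options
# must come before the subcommand):
#   python train.py --tokenizer bpe --bpe-vocab-size 2000 train --steps 5000 --lr 1e-3
#   python train.py eval --load model.pth --prompt "To be, or not to be"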
import argparse

import torch
import numpy as np

from util import CharacterTokenizer, BPETokenizer, Dataset
from gpt import GPTLanguageModel
from loss import estimate_loss
from metrics import Metrics


def train(data, model, tokenizer, steps, report_frequency, lr, save_path, checkpoint_freq, warmup_steps=0):
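  """Run the training loop for `steps` iterations.

  Uses AdamW, an optional linear-warmup + cosine-decay LR schedule, periodic
  train/val loss and metrics reporting, and optional step-suffixed checkpoints.
  """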
  device = next(model.parameters()).device
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
  metrics = Metrics()

  if warmup_steps > 0:
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
    warmup    = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
    decay     = CosineAnnealingLR(optimizer, T_max=max(1, steps - warmup_steps), eta_min=lr * 0.01)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])
  else:
    scheduler = None

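  # Main loop: sample a training batch, compute the loss, backprop, and step
  # the optimizer (and the LR scheduler, when warmup is enabled).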
  for step in range(steps):
    xb, yb = data.get_batch('train', device)

    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if scheduler:
      scheduler.step()
    if step % report_frequency == 0 or step == steps - 1:
      losses = estimate_loss(data, model)
      print(f"Step {step}, train loss: {losses['train']:.4f} val loss: {losses['val']:.4f}")

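      # Secondary evaluation metrics (see metrics.Metrics), reported alongside the loss.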
      metrics_dict = metrics(data, model, tokenizer)
      print("Metrics:", metrics_dict)

      print()

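    # Periodically snapshot the weights to a step-suffixed copy of the save path.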
    if checkpoint_freq and step > 0 and step % checkpoint_freq == 0:
      ckpt_path = save_path.replace('.pth', f'_step{step}.pth')
      torch.save(model.state_dict(), ckpt_path)
      print(f"Checkpoint saved: {ckpt_path}")

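# Shared model/data options live on the top-level parser; the `train` and
# `eval` subcommands add their own options below.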
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default="input.txt")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--tokenizer", type=str, default="char", choices=["char", "bpe"],
                    help="Tokenizer type: char (default) or bpe (subword sentencepiece)")
parser.add_argument("--bpe-vocab-size", type=int, default=2000,
                    help="Vocab size when training a BPE tokenizer (ignored for char)")
parser.add_argument("--bpe-model", type=str, default="bpe.model",
                    help="Path to save/load the sentencepiece .model file")

parser.add_argument("--context-size", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--n-embd", type=int, default=384)
parser.add_argument("--n-head", type=int, default=6)
parser.add_argument("--n-layer", type=int, default=6)
parser.add_argument("--dropout", type=float, default=0.2)

subparsers = parser.add_subparsers(dest="command", required=True)

train_parser = subparsers.add_parser("train")
train_parser.add_argument("--save", type=str, default="model.pth")
train_parser.add_argument("--steps", type=int, default=5000)
train_parser.add_argument("--report", type=int, default=500)
train_parser.add_argument("--lr", type=float, default=1e-3)
train_parser.add_argument("--checkpoint-freq", type=int, default=0, help="Save a checkpoint every N steps (0 = disabled)")
train_parser.add_argument("--warmup-steps", type=int, default=0, help="LR warmup steps before cosine decay (0 = constant LR)")

eval_parser = subparsers.add_parser("eval")
eval_parser.add_argument("--load", type=str, default="model.pth")
eval_parser.add_argument("--prompt", type=str)
eval_parser.add_argument("--token-count", type=int, default=300)

args = parser.parse_args()

if args.seed is not None:
  torch.manual_seed(args.seed)
batch_size = args.batch_size
context_size = args.context_size
n_embd = args.n_embd
n_head = args.n_head
n_layer = args.n_layer
dropout = args.dropout

# replace this with your hardware backend if needed
device = 'cuda' if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
if device == "cpu":
  print("WARNING: Running on cpu!")

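# Read the training corpus and build the tokenizer (sentencepiece BPE or character-level).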
with open(args.input, "r", encoding="utf-8") as f:
  content = f.read()

if args.tokenizer == "bpe":
  tokenizer = BPETokenizer(args.input, model_path=args.bpe_model, vocab_size=args.bpe_vocab_size)
else:
  tokenizer = CharacterTokenizer(content)

data = torch.tensor(tokenizer.encode(content), dtype=torch.long)
dataset = Dataset(data, context_size, batch_size)

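# Build the model and move it to the selected device.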
model = GPTLanguageModel(tokenizer.vocab_size, n_embd, context_size, n_head, n_layer)
model = model.to(device)

print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
print(f"Using device: {device}")
print()

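# Dispatch: `eval` loads an existing checkpoint; `train` runs the loop and then saves the final weights.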
if args.command == "eval":
  print("=" * 20, "INFERENCE", "=" * 20)
  model.load_state_dict(torch.load(args.load, map_location=device))
  model.eval()
elif args.command == "train":
  print("=" * 20, "TRAINING", "=" * 20)
  train(dataset, model, tokenizer, args.steps, args.report, args.lr, args.save, args.checkpoint_freq, args.warmup_steps)
  torch.save(model.state_dict(), args.save)
  print("=" * 50)

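# Sample from the model: from a single zero token by default, or from the eval prompt when one is given.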
model.eval()  # make sure dropout is off before sampling (no-op for the eval path)
context = torch.zeros((1, 1), dtype=torch.long, device=device)
max_tokens = 300
if args.command == "eval":
  if args.prompt is not None:
    context = torch.tensor([tokenizer.encode(args.prompt)], dtype=torch.long, device=device)
  max_tokens = args.token_count
print(
  tokenizer.decode(
    model.generate(start_idx=context, number_of_tokens=max_tokens, use_cache=True)[0].tolist()
  )
)