import os
import time
import math
import pickle
import inspect
import json
import argparse
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from q_learning_agent import QLearningAgent


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
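
    # F.layer_norm normalizes over the trailing self.weight.shape dims:
    #   y = (x - mean(x)) / sqrt(var(x) + eps) * weight (+ bias)
    # with eps=1e-5 matching nn.LayerNorm's default.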


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention is only available in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
        # causal mask to ensure that attention is only applied to the left in
        # the input sequence (used only by the slow path)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
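        # For block_size=4 the mask looks like (1 = attend, 0 = masked):
        #   [[1, 0, 0, 0],
        #    [1, 1, 0, 0],
        #    [1, 1, 1, 0],
        #    [1, 1, 1, 1]]
        # so position t can attend only to positions <= t.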

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)

        # calculate query, key, values for all heads in batch and move head
        # forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
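
    # Both branches compute softmax(q @ k^T / sqrt(head_dim)) @ v under a
    # causal mask; the flash path fuses it into one kernel, while the manual
    # path materializes the full (B, nh, T, T) attention matrix.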


class MLP(nn.Module):
    """Standard GPT-2 feed-forward block: 4x expansion, GELU, projection back."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm precedes attention and MLP,
    each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024
    # GPT-2 vocab_size of 50257, padded up to the nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    bias: bool = True


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the token embedding and the final projection share one
        # weight matrix
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per the GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        else:
            # inference-time mini-optimization: only forward the lm_head on the
            # very last position (the list index [-1] preserves the time dim)
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary, e.g. when
        # running with a smaller context than the configured maximum
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:block_size]
        )
        for block in self.transformer.h:
            if hasattr(block.attn, "bias"):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all candidate parameters that require grad
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # any parameter that is 2D or higher will be weight decayed, otherwise not;
        # i.e. weight tensors in matmuls + embeddings decay, biases and layernorms don't
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )
        # use the fused AdamW implementation when available (CUDA only)
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas, **extra_args
        )
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx
                if idx.size(1) <= self.config.block_size
                else idx[:, -self.config.block_size :]
            )
            # forward the model to get the logits for the last position
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
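
    # Example usage (hypothetical prompt; assumes `model` and `device` exist):
    #   model.eval()
    #   idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    #   out = model.generate(idx, max_new_tokens=100, temperature=0.8, top_k=50)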


def train(dataset="shakespeare_char", out_dir="run_0", seed_offset=0):
    # -----------------------------------------------------------------------------
    # data
    gradient_accumulation_steps = 1  # used to simulate larger batch sizes
    batch_size = 64 if dataset == "shakespeare_char" else 32
    block_size = 256  # context of up to 256 previous tokens
    # evaluation
    eval_interval = 250 if dataset == "shakespeare_char" else 1000
    log_interval = 10 if dataset == "shakespeare_char" else 100
    eval_iters = 200
    eval_only = False  # if True, script exits right after the first eval
    always_save_checkpoint = False  # if True, always save a checkpoint after each eval
    never_save_checkpoint = True  # overrides checkpoint saving entirely
    # model
    n_layer = 6
    n_head = 6
    n_embd = 384
    dropout = 0.2
    bias = False  # do we use bias inside LayerNorm and Linear layers?
    # adamw optimizer
    learning_rate = 2e-3 if dataset == "shakespeare_char" else 1e-3
    max_iters = 5000 if dataset == "shakespeare_char" else 100000
    weight_decay = 1e-1
    beta1 = 0.9
    beta2 = 0.99
    grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
    # learning rate decay settings
    decay_lr = True  # whether to decay the learning rate
    warmup_iters = 100 if dataset == "shakespeare_char" else 200
    lr_decay_iters = max_iters  # usually set == max_iters
    min_lr = 1e-4 if dataset == "shakespeare_char" else 5e-5
    # DDP settings (unused in this single-process script)
    backend = "nccl"
    # system
    device = "cuda"
    dtype = (
        "bfloat16"
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else "float16"
    )
    compile = True  # shadows the builtin, as in nanoGPT; use PyTorch 2.0 to compile the model
    # -----------------------------------------------------------------------------

    master_process = True  # single-process run, no DDP
    tokens_per_iter = gradient_accumulation_steps * batch_size * block_size
    print(f"tokens per iteration will be: {tokens_per_iter:,}")

    if master_process:
        os.makedirs(out_dir, exist_ok=True)
    torch.manual_seed(1337 + seed_offset)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = "cuda" if "cuda" in device else "cpu"
    # note: float16 will additionally use a GradScaler below
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    )

    # poor man's data loader
    data_dir = os.path.join("../../../data", dataset)

    def get_batch(split):
        # we recreate np.memmap every batch to avoid a memory leak from holding
        # the mapping open across iterations
        if split == "train":
            data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
        else:
            data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack(
            [torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix]
        )
        y = torch.stack(
            [
                torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64))
                for i in ix
            ]
        )
        if device_type == "cuda":
            # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
            x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(
                device, non_blocking=True
            )
        else:
            x, y = x.to(device), y.to(device)
        return x, y
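
    # Each call samples batch_size random windows of block_size tokens; y is x
    # shifted right by one, so position i is trained to predict token i+1.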

    iter_num = 0
    best_val_loss = 1e9

    # attempt to derive vocab_size from the dataset
    meta_path = os.path.join(data_dir, "meta.pkl")
    meta_vocab_size = None
    if os.path.exists(meta_path):
        with open(meta_path, "rb") as f:
            meta = pickle.load(f)
        meta_vocab_size = meta["vocab_size"]
        print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

    # Q-learning agent that will nudge the learning rate based on validation loss
    q_agent = QLearningAgent(lr=0.1, gamma=0.9, epsilon=0.1)
    model_args = dict(
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        block_size=block_size,
        bias=bias,
        vocab_size=None,
        dropout=dropout,
    )

    print("Initializing a new model from scratch")
    # determine the vocab size to use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
    model_args["vocab_size"] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    # crop down the model block size if desired, using model surgery
    if block_size < model.config.block_size:
        model.crop_block_size(block_size)
        model_args["block_size"] = block_size  # so any checkpoint records the right value
    model.to(device)

    # initialize a GradScaler. If enabled=False scaler is a no-op
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

    # optimizer
    optimizer = model.configure_optimizers(
        weight_decay, learning_rate, (beta1, beta2), device_type
    )
    checkpoint = None

    # compile the model (requires PyTorch 2.0)
    if compile:
        print("compiling the model... (takes a ~minute)")
        unoptimized_model = model
        model = torch.compile(model)

    # helps estimate an arbitrarily accurate loss over either split using many batches
    @torch.no_grad()
    def estimate_loss():
        out = {}
        model.eval()
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out
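
    # estimate_loss() averages over eval_iters random batches per split for a
    # lower-variance estimate than a single batch; dropout is disabled via
    # model.eval() and re-enabled with model.train() afterwards.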

    # learning rate decay scheduler (cosine with warmup)
    def get_lr(it):
        # 1) linear warmup for warmup_iters steps
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        # 2) if it > lr_decay_iters, return min learning rate
        if it > lr_decay_iters:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # decays from 1 to 0
        return min_lr + coeff * (learning_rate - min_lr)
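
    # Schedule shape for shakespeare_char (warmup_iters=100, lr_decay_iters=5000,
    # learning_rate=2e-3, min_lr=1e-4):
    #   get_lr(0)    = 0.0       (start of linear warmup)
    #   get_lr(100)  = 2e-3      (peak)
    #   get_lr(2550) = 1.05e-3   (cosine midpoint: min_lr + 0.5 * (lr - min_lr))
    #   get_lr(5000) = 1e-4      (floor)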

    # logging
    val_log_info = []
    train_log_info = []

    # training loop
    X, Y = get_batch("train")  # fetch the very first batch
    og_t0 = time.time()
    t0 = time.time()
    local_iter_num = 0  # number of iterations in the lifetime of this process
    raw_model = model  # no DDP wrapper in this single-process script
    while True:

        # determine and set the learning rate for this iteration
        lr = get_lr(iter_num) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # evaluate the loss on train/val sets and write checkpoints
        if iter_num % eval_interval == 0 and master_process:
            losses = estimate_loss()
            print(
                f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )

            # let the Q-learning agent adjust the learning rate based on val loss
            state = q_agent.get_state(losses["val"], lr)
            action = q_agent.choose_action(state)
            lr = max(min_lr, lr * (1 + action * 0.1))
            next_state = q_agent.get_state(losses["val"], lr)
            reward = -losses["val"]
            q_agent.update_q_values(state, action, reward, next_state)
            # apply the adjusted lr so it affects the upcoming optimizer step;
            # note that with decay_lr=True the cosine schedule overwrites it
            # again at the top of the next iteration
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
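            # The block above performs one tabular Q-learning step (assuming
            # QLearningAgent implements the standard update):
            #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
            # with reward r = -val_loss, so actions that lower validation loss
            # are reinforced.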

            val_log_info.append(
                {
                    "iter": iter_num,
                    "train/loss": losses["train"].item(),
                    "val/loss": losses["val"].item(),
                    "lr": lr,
                }
            )
            if losses["val"] < best_val_loss or always_save_checkpoint:
                best_val_loss = losses["val"]
                if iter_num > 0 and not never_save_checkpoint:
                    checkpoint = {
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "model_args": model_args,
                        "iter_num": iter_num,
                        "best_val_loss": best_val_loss,
                    }
                    print(f"saving checkpoint to {out_dir}")
                    torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
        if iter_num == 0 and eval_only:
            break

        # forward backward update, with optional gradient accumulation to simulate
        # larger batch size, using the GradScaler if data type is float16
        for micro_step in range(gradient_accumulation_steps):
            with ctx:
                logits, loss = model(X, Y)
                # scale the loss to account for gradient accumulation
                loss = loss / gradient_accumulation_steps
            # immediately async prefetch next batch while model is doing the forward pass on the GPU
            X, Y = get_batch("train")
            # backward pass, with gradient scaling if training in fp16
            scaler.scale(loss).backward()
        # clip the gradient (unscale first so the threshold applies to true gradients)
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        # step the optimizer and scaler if training in fp16
        scaler.step(optimizer)
        scaler.update()
        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)

        # timing and logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % log_interval == 0 and master_process:
            # get loss as float (note: this is a CPU-GPU sync point); scale up
            # to undo the division above, approximating the true total loss
            lossf = loss.item() * gradient_accumulation_steps
            print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
            train_log_info.append(
                {
                    "iter": iter_num,
                    "loss": lossf,
                    "time": dt * 1000,
                }
            )
        iter_num += 1
        local_iter_num += 1

        # termination condition
        if iter_num > max_iters:
            break

    print("training done")
    print(f"Best validation loss: {best_val_loss}")
    print(f"Total train time: {(time.time() - og_t0) / 60:.2f} mins")

    final_info = {
        "final_train_loss": lossf,
        "best_val_loss": best_val_loss.item(),
        "total_train_time": time.time() - og_t0,
    }

    # sampling / inference after training
    start = " "  # prompt; can also specify a file, used as: "FILE:prompt.txt"
    num_samples = 10  # number of samples to draw
    max_new_tokens = 500  # number of tokens generated in each sample
    temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random
    top_k = 200  # retain only the top_k most likely tokens, clamp others to 0 probability

    # character-level tokenizer from the dataset's meta.pkl
    assert os.path.exists(meta_path), "meta.pkl not found, please run training script first"
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: "".join([itos[i] for i in l])

    # encode the beginning of the prompt
    if start.startswith("FILE:"):
        with open(start[5:], "r", encoding="utf-8") as f:
            start = f.read()
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # run generation and measure throughput
    model.eval()
    results = []
    with torch.no_grad():
        with ctx:
            for k in range(num_samples):
                start_time = time.time()
                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
                end_time = time.time()

                generated_text = decode(y[0].tolist())
                inference_time = end_time - start_time
                tokens_per_second = max_new_tokens / inference_time

                print(f"Sample {k+1}:")
                print(generated_text)
                print(f"Inference time: {inference_time:.2f} seconds")
                print(f"Tokens per second: {tokens_per_second:.2f}")
                print("---------------")

                results.append(
                    {
                        "sample_id": k + 1,
                        "generated_text": generated_text,
                        "inference_time": inference_time,
                        "tokens_per_second": tokens_per_second,
                    }
                )

    # aggregate inference speed
    avg_tokens_per_second = sum(r["tokens_per_second"] for r in results) / len(results)
    print(f"Average tokens per second: {avg_tokens_per_second:.2f}")

    final_info["avg_inference_tokens_per_second"] = avg_tokens_per_second

    with open(os.path.join(out_dir, f"final_info_{dataset}_{seed_offset}.json"), "w") as f:
        json.dump(final_info, f)
    return final_info, train_log_info, val_log_info


parser = argparse.ArgumentParser(description="Run experiment")
parser.add_argument("--out_dir", type=str, default="run_0", help="Output directory")
args = parser.parse_args()

if __name__ == "__main__":
    num_seeds = {
        "shakespeare_char": 3,
        "enwik8": 1,
        "text8": 1,
    }

    out_dir = args.out_dir
    all_results = {}
    final_infos = {}
    for dataset in ["shakespeare_char", "enwik8", "text8"]:
        final_info_list = []
        for seed_offset in range(num_seeds[dataset]):
            final_info, train_info, val_info = train(dataset, out_dir, seed_offset)
            all_results[f"{dataset}_{seed_offset}_final_info"] = final_info
            all_results[f"{dataset}_{seed_offset}_train_info"] = train_info
            all_results[f"{dataset}_{seed_offset}_val_info"] = val_info
            final_info_list.append(final_info)
        final_info_dict = {
            k: [d[k] for d in final_info_list] for k in final_info_list[0].keys()
        }
        means = {f"{k}_mean": np.mean(v) for k, v in final_info_dict.items()}
        # standard error of the mean is std / sqrt(n), not std / n
        stderrs = {
            f"{k}_stderr": np.std(v) / np.sqrt(len(v))
            for k, v in final_info_dict.items()
        }
        final_infos[dataset] = {
            "means": means,
            "stderrs": stderrs,
            "final_info_dict": final_info_dict,
        }

    with open(os.path.join(out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)

    with open(os.path.join(out_dir, "all_results.npy"), "wb") as f:
        np.save(f, all_results)