import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import time
import os
import argparse
import signal
import copy
import urllib.request
from datetime import datetime
from contextlib import nullcontext
def ensure_data(data_path='archive/train.csv'):
    """Download Tiny Shakespeare if not present."""
    if not os.path.exists(data_path):
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print(f'Downloading Tiny Shakespeare from {url}...')
        urllib.request.urlretrieve(url, data_path)
        print(f'Saved to {data_path}')
    return data_path
# ========== Hyperparameters
batch_size = 64
block_size = 256  # predict the 257th token from the 256 tokens before it
max_iters = 5000
eval_interval = 1000
learning_rate = 3e-4  # lower peak LR for the larger model
device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
eval_iters = 50
n_embd = 384  # head size = 384 / 6 = 64
n_head = 6
n_layer = 6
dropout = 0.1
label_smoothing = 0.05
ema_decay = 0.999
use_ema_for_eval = True
use_sdpa = True
use_compile = True
# LR schedule (cosine with warmup)
warmup_iters = 200
lr_decay_iters = max_iters
min_lr = 1e-4
# ==========================
torch.manual_seed(1337)
# Dataset
data_path = ensure_data('archive/train.csv')
with open(data_path, 'r', encoding='UTF-8') as f:
    text = f.read()
# Tokenizer
chars = sorted(list(set(text)))
vocab_size = len(chars)
lookup_table_in = { ch:i for i,ch in enumerate(chars) }
lookup_table_out = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [lookup_table_in[c] for c in s]  # string -> list of ints
decode = lambda l: ''.join([lookup_table_out[i] for i in l])  # list of ints -> string
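# Example round-trip: decode(encode('hii there')) == 'hii there'; each character
# maps bijectively to its index in the sorted character vocabulary.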
data = torch.tensor(encode(text), dtype=torch.long)
# Train and Test Split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
# Data Loading
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
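# Both x and y have shape (batch_size, block_size); y is x shifted left by one,
# so the model at position t is trained to predict token t+1.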
# Loss estimation
@torch.no_grad()
def estimate_loss():
    out = {}
    eval_model = (ema_model if (use_ema_for_eval and 'ema_model' in globals() and ema_model is not None) else model)
    eval_model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = eval_model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
# =========== Transformer Components:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Kept for reference; SDPA handles the causal mask internally
        # self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)
        if use_sdpa:
            # Use PyTorch SDPA; add a head dimension of size 1
            qh = q.unsqueeze(1)  # (B, 1, T, hs)
            kh = k.unsqueeze(1)  # (B, 1, T, hs)
            vh = v.unsqueeze(1)  # (B, 1, T, hs)
            out = F.scaled_dot_product_attention(
                qh, kh, vh,
                attn_mask=None,
                dropout_p=dropout if self.training else 0.0,
                is_causal=True,
            )  # (B, 1, T, hs)
            out = out.squeeze(1)  # (B, T, hs)
        else:
            wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
            # Causal mask
            mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            wei = wei.masked_fill(~mask, float('-inf'))
            wei = F.softmax(wei, dim=-1)
            wei = self.dropout(wei)
            out = wei @ v  # (B, T, hs)
        return out
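# The two branches above compute the same causal attention; the SDPA path just
# lets PyTorch dispatch to fused kernels (FlashAttention / memory-efficient
# attention) when the hardware and dtype allow it.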
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """
    def __init__(self, num_heads, head_size):
        super().__init__()
        # Build num_heads independent heads and concatenate their outputs
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)  # <----- More Dropout!

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))  # <----- More Dropout!
        return out
class FeedForward(nn.Module):
    """ a two-layer MLP with a GELU non-linearity in between """
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(approximate='tanh'),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),  # <----- More Dropout!
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm residual connections
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
# The bigram baseline is gone; this is now a full GPT
class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # Stack n_layer Transformer blocks
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Weight tying: improves perplexity and slightly reduces parameter count
        self.lm_head.weight = self.token_embedding_table.weight
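        # Shape check for the tying above: Embedding(vocab_size, n_embd).weight and
        # Linear(n_embd, vocab_size).weight are both (vocab_size, n_embd), so a
        # single tensor can serve as input embedding and output projection.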
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # Compute CE loss in float32, with label smoothing, for stability
            loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # crop context to the last block_size tokens
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # focus on the last time step
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
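# generate() samples from the full softmax at temperature 1. Common extensions,
# not implemented here, are dividing the logits by a temperature and applying
# top-k filtering before the softmax.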
# Train =============================
model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
print(device)
# Optionally compile the model for speed (requires PyTorch 2.x)
if use_compile:
    try:
        model = torch.compile(model)
        m = model  # keep reference consistent
        print('torch.compile: enabled')
    except Exception as e:
        print(f'warning: torch.compile failed: {e}')
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1, betas=(0.9, 0.95))
# autocast context for mixed precision (CUDA or MPS)
if device == 'cuda':
    ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)
elif device == 'mps':
    ctx = torch.amp.autocast(device_type='mps', dtype=torch.float16)
else:
    ctx = nullcontext()
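# Note: float16 autocast on CUDA is usually paired with torch.amp.GradScaler to
# guard against gradient underflow. This script omits the scaler (the loss is
# computed in float32, which helps); adding one would be the more conservative
# default for fp16 training.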
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / max(1, warmup_iters)
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / max(1, lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)
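# Sanity check of the schedule: get_lr(0) == 0, get_lr(warmup_iters) == 3e-4
# (the peak), and get_lr(lr_decay_iters) == min_lr == 1e-4 at the end of decay.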
log_window = 100
t_last = time.time()
# ========== Checkpointing, resume, and interrupt handling
def _ensure_dir(d):
    os.makedirs(d, exist_ok=True)
    return d

def _checkpoint_dir(out_dir):
    return _ensure_dir(out_dir if out_dir else os.path.join('assets', 'checkpoints'))

def save_ckpt(path, step):
    ckpt = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'iter': step,
        'meta': {
            'chars': chars,
            'vocab_size': vocab_size,
            'n_embd': n_embd,
            'n_head': n_head,
            'n_layer': n_layer,
            'block_size': block_size,
            'dropout': dropout,
            'label_smoothing': label_smoothing,
            'ema_decay': ema_decay,
        }
    }
    # Include EMA weights if available
    if 'ema_model' in globals() and ema_model is not None:
        try:
            ckpt['ema_state_dict'] = ema_model.state_dict()
        except Exception:
            pass
    torch.save(ckpt, path)

def auto_latest_path(out_dir):
    return os.path.join(_checkpoint_dir(out_dir), 'latest.pt')

def timed_step_path(out_dir, step):
    ts = datetime.now().strftime('%Y%m%d-%H%M%S')
    return os.path.join(_checkpoint_dir(out_dir), f'gpt-{ts}-step{step}.pt')
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--resume', action='store_true', help='Resume training from latest checkpoint if available, or from --ckpt')
parser.add_argument('--ckpt', type=str, default=None, help='Specific checkpoint path to resume from')
parser.add_argument('--save_interval', type=int, default=0, help='Steps between periodic checkpoints (0 to disable)')
parser.add_argument('--save_twice', dest='save_twice', action='store_true', default=True, help='Save exactly twice, at 1/3 and 2/3 progress')
parser.add_argument('--no_save_twice', dest='save_twice', action='store_false', help='Disable the two-milestone saves')
parser.add_argument('--out_dir', type=str, default=os.path.join('assets', 'checkpoints'), help='Directory to write checkpoints')
try:
    args, _unknown = parser.parse_known_args()
except SystemExit:
    # Fallback defaults (e.g. when a notebook runner owns sys.argv)
    class _A: pass
    args = _A()
    args.resume = False
    args.ckpt = None
    args.save_interval = 0
    args.save_twice = True  # was missing; referenced later by the milestone logic
    args.out_dir = os.path.join('assets', 'checkpoints')
start_iter = 0
if args.resume:
    resume_path = args.ckpt if args.ckpt else (auto_latest_path(args.out_dir) if os.path.exists(auto_latest_path(args.out_dir)) else None)
    if resume_path and os.path.exists(resume_path):
        print(f"Resuming from checkpoint: {resume_path}")
        # weights_only=False: the checkpoint also carries optimizer state and meta
        state = torch.load(resume_path, map_location=device, weights_only=False)
        model.load_state_dict(state['model_state_dict'])
        try:
            optimizer.load_state_dict(state['optimizer_state_dict'])
        except Exception as e:
            print(f"warning: could not load optimizer state: {e}")
        start_iter = int(state.get('iter', -1)) + 1
        if start_iter < 0:
            start_iter = 0
        print(f"Resumed at step {start_iter}")
    else:
        print("--resume requested but no checkpoint found; starting fresh.")
# Initialize EMA model after a potential resume has loaded the weights
ema_model = None
if ema_decay and ema_decay > 0.0:
    ema_model = copy.deepcopy(model).to(device)
    for p in ema_model.parameters():
        p.requires_grad_(False)
    # If we resumed and the checkpoint carried EMA weights, restore them
    if args.resume and 'state' in globals() and isinstance(state, dict) and 'ema_state_dict' in state:
        try:
            ema_model.load_state_dict(state['ema_state_dict'])
        except Exception as e:
            print(f"warning: could not load EMA state: {e}")
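# Evaluating with the EMA weights (use_ema_for_eval) typically yields a slightly
# smoother and lower validation loss than the raw training weights.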
# Milestone saves at ~1/3 and ~2/3 of max_iters
milestones = sorted({max(1, round(max_iters/3)), max(1, round(2*max_iters/3))})
print(f"Milestone checkpoints planned at steps: {milestones}")
interrupt_flag = {'hit': False}

def _handle_sigint(signum, frame):
    interrupt_flag['hit'] = True
    print("\nCtrl+C detected; will save checkpoint at next safe point...")

signal.signal(signal.SIGINT, _handle_sigint)
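# The handler only sets a flag; the actual checkpoint is written at a safe point
# at the end of a training step, so we never serialize state mid-backward or
# mid-optimizer-update.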
for iter in range(start_iter, max_iters):
    # every once in a while evaluate the loss on train and val sets (skip step 0)
    if iter > 0 and (iter % eval_interval == 0 or iter == max_iters - 1):
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # update learning rate via schedule
    lr = get_lr(iter)
    for g in optimizer.param_groups:
        g['lr'] = lr
    # sample a batch of data
    xb, yb = get_batch('train')
    # evaluate the loss
    with ctx:
        logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    # EMA update
    if ema_model is not None:
        with torch.no_grad():
            msd = model.state_dict()
            for (k, v_ema) in ema_model.state_dict().items():
                v = msd[k]
                if v_ema.dtype.is_floating_point:
                    v_ema.mul_(ema_decay).add_(v, alpha=(1.0 - ema_decay))
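    # EMA math: theta_ema <- decay * theta_ema + (1 - decay) * theta, an
    # exponential moving average with an effective window of roughly
    # 1 / (1 - decay) = 1000 optimizer steps at decay = 0.999.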
    # progress logging: avg ms/iter over the last window, plus ETA
    if (iter + 1) % log_window == 0:
        t_now = time.time()
        ms_per_iter = (t_now - t_last) * 1000.0 / log_window
        t_last = t_now
        remaining = max_iters - (iter + 1)
        eta_min = (remaining * ms_per_iter) / 1000.0 / 60.0
        print(f"~{ms_per_iter:.1f} ms/iter, ETA {eta_min:.1f} min, lr {lr:.2e}")
    # milestone and/or periodic checkpoint save
    do_milestone = args.save_twice and (iter in milestones)
    do_periodic = (args.save_interval and args.save_interval > 0 and (iter % args.save_interval == 0))
    if iter > 0 and (do_milestone or do_periodic):
        latest = auto_latest_path(args.out_dir)
        step_path = timed_step_path(args.out_dir, iter)
        try:
            save_ckpt(latest, iter)
            save_ckpt(step_path, iter)
            which = 'milestone' if do_milestone and not do_periodic else ('periodic' if do_periodic and not do_milestone else 'periodic+milestone')
            print(f"Saved {which} checkpoint at step {iter} -> {latest} and {step_path}")
        except Exception as e:
            print(f"warning: failed to save checkpoint at step {iter}: {e}")
    # handle Ctrl+C gracefully: save and exit
    if interrupt_flag['hit']:
        latest = auto_latest_path(args.out_dir)
        try:
            save_ckpt(latest, iter)
            print(f"Checkpoint saved on interrupt at step {iter} -> {latest}")
        except Exception as e:
            print(f"warning: failed to save interrupt checkpoint: {e}")
        break
# ========== Save final checkpoint and quick sample (if not interrupted)
if not interrupt_flag['hit']:
    # Evaluate final losses for reference
    losses = estimate_loss()
    print(f"final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # Save model checkpoint with meta and optimizer state
    latest_path = auto_latest_path(args.out_dir)
    step_path = timed_step_path(args.out_dir, max_iters - 1)
    try:
        save_ckpt(latest_path, max_iters - 1)
        save_ckpt(step_path, max_iters - 1)
        print(f"Saved checkpoint to {latest_path}\nSnapshot at {step_path}")
    except Exception as e:
        print(f"warning: failed to save final checkpoint: {e}")
    # Emit a short sample to verify end-to-end
    model.eval()
    with torch.no_grad():
        # start from a single-token context (token index 0)
        start_idx = torch.zeros((1, 1), dtype=torch.long, device=device)
        out_idx = model.generate(start_idx, max_new_tokens=200)[0].tolist()
        sample_text = decode(out_idx)
    print("\n=== Sample (200 chars) ===")
    print(sample_text[:200])
    print("==========================\n")