Spaces:
Paused
Paused
| # train_babygpt_optimized.py | |
| """ | |
| Optimized training script for small GPT on GPU (Tesla T4) with: | |
| - FP16 mixed precision (autocast + GradScaler) when CUDA is available | |
| - pinned memory + non_blocking transfers | |
| - optional dataset pre-tokenization -> tokens.pt | |
| - torch.compile try/except (won't break saving/loading) | |
| - minimal GC, set_to_none=True for zero_grad | |
| - gradient-checkpointing opt-in (disabled by default for speed) | |
| - safer checkpoint saving (saved as float32 on CPU) | |
| - robust autocast context that handles older/newer PyTorch signatures | |
| """ | |
| import os | |
| import io | |
| import time | |
| import math | |
| import gc | |
| import traceback | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint as checkpoint | |
# Global backend knobs: let cuDNN autotune per input shape and allow TF32
# matmuls/convolutions (a speedup on Ampere+; harmless elsewhere).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# CPU thread settings (safe): size intra-op threads to the core count, give
# inter-op half of them. Best-effort: these setters raise if called after
# parallel work has already started, hence the broad guard.
try:
    ncpu = os.cpu_count() or 2
    torch.set_num_threads(ncpu)
    torch.set_num_interop_threads(max(1, ncpu // 2))
except Exception:
    pass
| try: | |
| import sentencepiece as spm | |
| except ImportError: | |
| raise RuntimeError("Please install sentencepiece: pip install sentencepiece") | |
| try: | |
| import matplotlib.pyplot as plt | |
| except Exception: | |
| plt = None | |
# ---------- Paths / defaults ----------
DATA_PATH = "worldsim.txt"                  # raw training corpus (plain text)
SP_MODEL_PREFIX = "tokenizer"               # SentencePiece model/vocab file prefix
SP_MODEL_FILE = f"{SP_MODEL_PREFIX}.model"  # derived model filename
SP_VOCAB_DEFAULT = 14000                    # default BPE vocabulary size
TOKENS_PT = "tokens.pt"  # pre-tokenized cached file
# Device default: set to 'cuda' to use the GPU; will auto-check availability
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| # GPU-optimized defaults (tuned for T4: speed + reasonable model capacity) | |
| DEFAULTS = dict( | |
| BLOCK_SIZE=512, # shorter sequence = more steps/sec | |
| BATCH_SIZE=8, # bigger batch to saturate GPU (tune down if OOM) | |
| EMBED_DIM=448, # embed size | |
| NUM_HEADS=7, | |
| NUM_LAYERS=10, # fewer layers => much faster | |
| DROPOUT=0.05, | |
| EPOCHS=6, | |
| LR=2e-5, | |
| PRINT_EVERY=50, | |
| ACCUM_STEPS=1, # don't accumulate unless needed | |
| SP_VOCAB=SP_VOCAB_DEFAULT, | |
| GRADIENT_CHECKPOINT=True, # default off for speed on T4 | |
| PRETOKENIZE=True, # create tokens.pt to avoid tokenization overhead during training | |
| ) | |
# ---------- stop flag helpers ----------
# Cooperative-cancellation flag, polled by the training loop between steps.
_TRAIN_STOP_REQUESTED = False

def request_stop():
    """Ask a running training loop to stop at the next safe point."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = True

def clear_stop_request():
    """Reset the stop flag before starting a new training run."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = False

def stop_requested():
    """Return True when a stop has been requested."""
    return _TRAIN_STOP_REQUESTED
# ---------- sentencepiece helpers ----------
def ensure_sp_model(data_path=DATA_PATH, model_prefix=SP_MODEL_PREFIX, vocab_size=SP_VOCAB_DEFAULT):
    """Return a loaded SentencePieceProcessor, training a BPE model first if
    `<prefix>.model` / `<prefix>.vocab` do not both exist yet."""
    model_file = f"{model_prefix}.model"
    vocab_file = f"{model_prefix}.vocab"
    if not (os.path.exists(model_file) and os.path.exists(vocab_file)):
        # One-time tokenizer training; BOS/EOS disabled so ids map purely to text pieces.
        spm.SentencePieceTrainer.train(
            input=data_path,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            model_type="bpe",
            character_coverage=1.0,
            unk_id=0, bos_id=-1, eos_id=-1,
        )
    sp = spm.SentencePieceProcessor()
    sp.load(model_file)
    return sp
# ---------- model ----------
class BabyGPT(nn.Module):
    """Minimal GPT-style decoder-only language model built from
    nn.TransformerEncoderLayer blocks.

    Fix vs. the original: an additive causal mask (-inf above the diagonal)
    is now passed to every layer. Without it the encoder layers attend
    bidirectionally, letting position t see tokens > t — which makes the
    next-token training objective trivially leaky and generation inconsistent.
    """

    def __init__(self, vocab_size, embed_dim, block_size, num_heads, num_layers, dropout, use_checkpoint=False):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
                dropout=dropout, batch_first=True
            ) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self._use_gradient_checkpointing = use_checkpoint

    def enable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = False

    def forward(self, idx, targets=None):
        """Run the LM. idx: (B, T) long token ids, T <= block_size.

        Returns (logits, loss): logits is (B, T, vocab); loss is a scalar
        cross-entropy when `targets` (B, T) is given, else None.
        """
        B, T = idx.shape
        device = idx.device
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        # Additive causal mask: -inf above the diagonal forbids attending to
        # future positions (the bug fix — original passed no mask at all).
        attn_mask = torch.triu(
            torch.full((T, T), float("-inf"), device=device), diagonal=1
        )
        for layer in self.layers:
            if self._use_gradient_checkpointing and self.training:
                # must pass tensors only; explicit use_reentrant for PyTorch >=2.5 compatibility
                def run_layer(x_local, layer_local=layer):
                    return layer_local(x_local, src_mask=attn_mask)
                # pass use_reentrant explicitly to avoid PyTorch warning in 2.5+
                x = checkpoint.checkpoint(run_layer, x, use_reentrant=False)
            else:
                x = layer(x, src_mask=attn_mask)
        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=50):
        """Autoregressively sample `max_new_tokens` ids after `idx` (B, T).

        Applies temperature scaling and top-k filtering; returns idx extended
        by the sampled tokens, on idx's original device.
        """
        self.eval()
        device = next(self.parameters()).device
        idx = idx.to(device)
        for _ in range(max_new_tokens):
            # crop the conditioning window to the positional-embedding capacity
            idx_cond = idx[:, -self.block_size:].to(device)
            logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :].to(torch.float32)
            last_logits = last_logits / (temperature if temperature > 0 else 1.0)
            if top_k > 0:
                # clamp so top_k never exceeds the vocabulary (topk would raise)
                k = min(top_k, last_logits.size(-1))
                v, _ = torch.topk(last_logits, k)
                last_logits[last_logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(last_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id.to(idx.device)), dim=1)
        return idx
# ---------- data streaming & token caching ----------
def stream_text_chunks(path=DATA_PATH, chunk_size=128_000):
    """Yield successive chunks of up to `chunk_size` characters from `path`,
    decoding as UTF-8 and ignoring undecodable bytes."""
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        while chunk := fh.read(chunk_size):
            yield chunk
def build_or_load_token_cache(sp, path=DATA_PATH, tokens_pt=TOKENS_PT, chunk_size=128_000):
    """
    Return a 1-D LongTensor of token ids for the corpus at `path`.

    Loads the cached tensor from `tokens_pt` when a valid one exists;
    otherwise tokenizes the corpus with `sp` in streaming chunks and saves
    the result so subsequent runs skip this one-time cost.
    """
    if os.path.exists(tokens_pt):
        # best-effort cache read: any failure or non-tensor payload triggers a rebuild
        try:
            cached = torch.load(tokens_pt, map_location="cpu")
        except Exception:
            cached = None
        if isinstance(cached, torch.Tensor):
            return cached
    print("Pre-tokenizing dataset to", tokens_pt, " — this may take a while (one-time).")
    pieces = []
    for text in stream_text_chunks(path, chunk_size):
        pieces.extend(sp.encode(text, out_type=int))
    tokens = torch.tensor(pieces, dtype=torch.long)
    torch.save(tokens, tokens_pt)
    return tokens
def get_batch_from_ids(ids, block_size, batch_size, device="cpu"):
    """
    Sample a random language-modeling batch from a 1-D token-id sequence.

    ids: 1D list or 1D torch.Tensor of token ids.
    Returns (x, y), each (batch_size, block_size) long tensors where y is x
    shifted one position right; placed on `device` with pinned-memory
    non_blocking transfers when targeting CUDA.

    Raises ValueError when ids has fewer than block_size + 1 tokens.

    Fixes vs. the original: randint's exclusive high bound was off by one
    (the last valid window could never be sampled, and an input of exactly
    block_size + 1 tokens crashed), and the per-row Python copy loop is
    replaced by a single vectorized gather.
    """
    dev = torch.device(device)
    use_cuda = (dev.type == "cuda")
    # Normalize to one CPU LongTensor so slicing below is uniform and cheap.
    if isinstance(ids, torch.Tensor):
        ids_t = ids.detach().to(dtype=torch.long, device="cpu")
    else:
        ids_t = torch.tensor(ids, dtype=torch.long)
    L = ids_t.numel()
    if L < block_size + 1:
        raise ValueError("ids too short for block size")
    # Valid window starts are [0, L - block_size - 1]; high is exclusive.
    ix = torch.randint(0, L - block_size, (batch_size,))
    # Gather all (block_size + 1)-length windows at once, then split x/y.
    offsets = torch.arange(block_size + 1)
    windows = ids_t[ix.unsqueeze(1) + offsets]  # (batch_size, block_size + 1)
    x = windows[:, :-1].contiguous()
    y = windows[:, 1:].contiguous()
    if use_cuda:
        # pinned staging buffers let the host-to-device copies overlap compute
        return x.pin_memory().to(dev, non_blocking=True), y.pin_memory().to(dev, non_blocking=True)
    return x.to(dev), y.to(dev)
# ---------- plotting ----------
def plot_loss(history):
    """Render `history` as a small PNG line chart and return its bytes;
    returns b"" when matplotlib was not importable (plt is None)."""
    if plt is None:
        return b""
    plt.switch_backend("Agg")  # headless rendering
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(history)
    ax.set_title("Loss")
    png = io.BytesIO()
    fig.savefig(png, format="png")
    plt.close(fig)  # free the figure to avoid accumulating state
    return png.getvalue()
| # ---------- checkpoint utils ---------- | |
| def _unwrap_model_for_saving(model: nn.Module) -> nn.Module: | |
| """Return underlying module if a compiled wrapper added attributes like _orig_mod.""" | |
| if hasattr(model, "_orig_mod"): | |
| return model._orig_mod | |
| # some wrappers use .module | |
| if hasattr(model, "module"): | |
| return model.module | |
| return model | |
def save_checkpoint(model, cfg, path):
    """Save {'model_state_dict', 'config'} to `path` in a portable form.

    Weights are detached to CPU and *floating-point* tensors are cast to
    float32 so checkpoints load on any device regardless of training dtype.
    Fix vs. the original: non-float entries (e.g. integer buffers) keep
    their dtype — blindly casting everything to float32 corrupted them.
    Falls back to a plain state_dict save on failure; re-raises only when
    both attempts fail.
    """
    try:
        real_model = _unwrap_model_for_saving(model)
        cpu_state = {}
        for k, v in real_model.state_dict().items():
            t = v.detach().cpu()
            # cast only floating tensors; preserve int/bool buffer dtypes
            cpu_state[k] = t.to(torch.float32) if t.is_floating_point() else t
        torch.save({'model_state_dict': cpu_state, 'config': cfg}, path)
    except Exception as e:
        # fallback: try normal save (best effort); report why we got here
        print("Portable checkpoint save failed, retrying plain save:", e)
        try:
            torch.save({'model_state_dict': model.state_dict(), 'config': cfg}, path)
        except Exception as e2:
            print("Failed to save checkpoint:", e2)
            raise
def latest_checkpoint():
    """Return the most recently modified epoch checkpoint in the cwd,
    else 'baby_gpt_final.pth' if present, else None."""
    epoch_ckpts = [
        name for name in os.listdir(".")
        if name.startswith("baby_gpt_epoch") and name.endswith(".pth")
    ]
    if epoch_ckpts:
        return max(epoch_ckpts, key=os.path.getmtime)
    if os.path.exists("baby_gpt_final.pth"):
        return "baby_gpt_final.pth"
    return None
| def _strip_orig_mod_prefix(state_dict: dict) -> dict: | |
| """Strip leading _orig_mod. prefix from keys if present.""" | |
| keys = list(state_dict.keys()) | |
| # detect if keys have the prefix | |
| if any(k.startswith("_orig_mod.") for k in keys): | |
| return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} | |
| # also support accidental 'module.' prefix | |
| if any(k.startswith("module.") for k in keys) and not any(k.startswith("tok_emb") for k in keys): | |
| return {k.replace("module.", ""): v for k, v in state_dict.items()} | |
| return state_dict | |
def load_model_for_inference(checkpoint=None, device=DEVICE):
    """Rebuild a BabyGPT from a checkpoint and return (model, sp, cfg).

    checkpoint: path to a .pth file; defaults to the newest one in the cwd.
    Raises FileNotFoundError when no checkpoint can be found, RuntimeError
    when the state_dict cannot be loaded at all.
    """
    ck = checkpoint or latest_checkpoint()
    if not ck:
        raise FileNotFoundError("No checkpoint.")
    data = torch.load(ck, map_location="cpu")
    # tokenizer must exist (or be trained) so vocab size matches the head
    sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, DEFAULTS["SP_VOCAB"])
    # fall back to module defaults when the checkpoint predates config saving
    cfg = data.get("config", DEFAULTS)
    model = BabyGPT(sp.get_piece_size(), cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
    cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
    use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False)).to(device)
    state = data["model_state_dict"]
    # drop torch.compile / DataParallel key prefixes if present
    state = _strip_orig_mod_prefix(state)
    # try load, try forgiving missing/unexpected keys by strict=False first, then stricter load
    # NOTE(review): strict=False rarely raises, so the strict fallback below is
    # mostly dead code — kept as-is to preserve behavior.
    try:
        model.load_state_dict(state, strict=False)
    except Exception as e:
        # final fallback: try to match only exact keys
        try:
            model.load_state_dict(state)
        except Exception as e2:
            raise RuntimeError(f"Failed to load state_dict: {e2}")
    model.eval()
    return model, sp, cfg
# ---------- autocast helper ----------
def autocast_ctx():
    """
    Return a context manager for autocast. Robust to PyTorch versions which may
    accept different signatures for torch.cuda.amp.autocast.

    Fix vs. the original: the two duplicated hand-rolled no-op context classes
    are replaced by the stdlib contextlib.nullcontext.
    """
    from contextlib import nullcontext  # local import keeps the module surface unchanged

    if not (torch.cuda.is_available() and hasattr(torch.cuda, "amp")):
        # no CUDA/AMP: run the forward pass unmodified
        return nullcontext()
    # try modern signature with dtype
    try:
        return torch.cuda.amp.autocast(dtype=torch.float16)
    except TypeError:
        # older signatures may accept no args
        try:
            return torch.cuda.amp.autocast()
        except Exception:
            return nullcontext()
# ---------- train generator ----------
def train_generator(**params):
    """
    Yields (log_text, plot_png_bytes) pairs (same behavior as your Gradio app expects).
    Accepts overrides for any DEFAULTS keys by passing them into train_generator(...).

    Flow: tokenizer -> optional token cache -> model build -> dtype/device
    setup -> optimizer/AMP -> optional torch.compile -> epoch loop (cached-ids
    path or tokenize-on-the-fly path) -> per-epoch and final checkpoints.
    Stops early (plain return) whenever stop_requested() turns True.
    """
    clear_stop_request()
    # None-valued overrides are ignored so CLI/UI can pass everything through
    cfg = {**DEFAULTS, **{k: v for k, v in params.items() if v is not None}}
    # SentencePiece
    sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, cfg.get("SP_VOCAB", SP_VOCAB_DEFAULT))
    # Optional pre-tokenize
    tokens_cache = None
    if cfg.get("PRETOKENIZE", True):
        tokens_cache = build_or_load_token_cache(sp, DATA_PATH, TOKENS_PT, chunk_size=128_000)
    # Build model (vocab from sp)
    vocab_size = sp.get_piece_size()
    model = BabyGPT(vocab_size, cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
                    cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
                    use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False))
    # gradient checkpointing opt-in
    if cfg.get("GRADIENT_CHECKPOINT", False):
        try:
            model.enable_gradient_checkpointing()
            print("Gradient checkpointing enabled.")
        except Exception as e:
            print("Could not enable gradient checkpointing:", e)
    # Choose dtype/device strategy
    # On CUDA: keep params float32 and use FP16 autocast + GradScaler (T4 prefers FP16).
    chosen_dtype = torch.float32
    if DEVICE.startswith("cuda") and torch.cuda.is_available():
        model = model.to(torch.float32)  # keep params float32
        chosen_dtype = torch.float32
        print("Using float32 params on CUDA. Mixed FP16 autocast will be used during forward.")
    else:
        # On CPU, try bfloat16 if available (rare)
        try:
            _ = torch.empty(1, dtype=torch.bfloat16)
            model = model.to(dtype=torch.bfloat16)
            chosen_dtype = torch.bfloat16
            print("Using bfloat16 on CPU.")
        except Exception:
            model = model.to(dtype=torch.float32)
            chosen_dtype = torch.float32
            print("Using float32 on CPU.")
    # Move model to device
    model = model.to(DEVICE)
    # best-effort device/memory report; never fatal
    try:
        dev = next(model.parameters()).device
        device_info = str(dev)
        if "cuda" in device_info:
            dev_idx = torch.cuda.current_device()
            dev_name = torch.cuda.get_device_name(dev_idx)
            mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
            mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
            device_info = f"{device_info} ({dev_name}) alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
        print("Model moved to device:", device_info)
    except Exception:
        pass
    # Build optimizer
    opt = torch.optim.AdamW(model.parameters(), lr=cfg["LR"])
    # CUDA-specific optimizations
    use_amp = (DEVICE.startswith("cuda") and torch.cuda.is_available())
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    if use_amp:
        torch.backends.cudnn.benchmark = True
        # optional TF32 speedup
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
        except Exception:
            pass
    # Try torch.compile for speedups (best effort)
    # NOTE(review): on CPU the synchronize() below raises, so the except branch
    # also fires when compile itself succeeded — timing vars are re-initialized
    # later either way.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("torch.compile applied.")
        torch.cuda.synchronize()
        t_start = time.time()
        last_print_time = t_start
        steps_since = 0
    except Exception as e:
        print("torch.compile unavailable:", e)
    # training prep
    loss_hist = []
    total_params_m = sum(p.numel() for p in model.parameters())/1e6
    log = f"🚀 Training started | {total_params_m:.2f}M params | dtype={chosen_dtype}\n"
    yield (log, plot_loss(loss_hist))
    model.train()
    global_step = 0
    # timing
    t_start = time.time()
    last_print_time = t_start
    steps_since = 0
    # read ids source: prefer pre-tokenized cache
    ids_source = tokens_cache if tokens_cache is not None else None
    for epoch in range(cfg["EPOCHS"]):
        log += f"\n=== Epoch {epoch+1}/{cfg['EPOCHS']} ===\n"
        # If we have tokens cache: iterate over chunks of the token stream
        if ids_source is not None:
            L = ids_source.numel()
            # each chunk holds ~2 batches' worth of tokens
            chunk_length = cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]) * 2
            num_chunks = max(1, L // chunk_length)
            for ci in range(num_chunks):
                if stop_requested(): return
                # wrap the start offset so chunks never run off the end
                start = (ci * chunk_length) % max(1, L - chunk_length)
                chunk_ids = ids_source[start:start + chunk_length].tolist()
                steps = max(1, len(chunk_ids) // (cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"])))
                for step in range(steps):
                    if stop_requested(): return
                    xb, yb = get_batch_from_ids(chunk_ids, cfg["BLOCK_SIZE"], cfg["BATCH_SIZE"], device=DEVICE)
                    # forward/backward with AMP if available
                    try:
                        if use_amp:
                            with autocast_ctx():
                                logits, loss = model(xb, yb)
                            # scale the (accumulation-normalized) loss before backward
                            scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
                        else:
                            logits, loss = model(xb, yb)
                            (loss / cfg["ACCUM_STEPS"]).backward()
                    except RuntimeError as e:
                        # handle occasional OOM gracefully by reducing batch
                        if "out of memory" in str(e).lower():
                            torch.cuda.empty_cache()
                            print("OOM during training step - skipping step (reduce BATCH_SIZE).")
                            continue
                        else:
                            raise
                    # optimizer step every ACCUM_STEPS micro-batches
                    if (step + 1) % cfg["ACCUM_STEPS"] == 0:
                        if use_amp:
                            try:
                                scaler.step(opt)
                                scaler.update()
                            except Exception as e:
                                print("Scaler step failed:", e)
                                opt.step()
                        else:
                            opt.step()
                        opt.zero_grad(set_to_none=True)
                    loss_hist.append(loss.item())
                    global_step += 1
                    steps_since += 1
                    if step % cfg["PRINT_EVERY"] == 0:
                        # NOTE(review): unguarded synchronize — raises on a
                        # CPU-only run of this (cached-ids) branch; confirm.
                        torch.cuda.synchronize()
                        now = time.time()
                        elapsed = now - last_print_time
                        overall_elapsed = now - t_start
                        sps = steps_since / elapsed if elapsed > 0 else 0.0
                        avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
                        # device stats
                        dev_stats = ""
                        try:
                            dev = next(model.parameters()).device
                            if "cuda" in str(dev):
                                dev_idx = torch.cuda.current_device()
                                mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
                                mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
                                dev_name = torch.cuda.get_device_name(dev_idx)
                                dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
                            else:
                                dev_stats = f" | device={dev}"
                        except Exception:
                            dev_stats = ""
                        log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
                                f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
                        last_print_time = now
                        steps_since = 0
                        yield (log, plot_loss(loss_hist))
                    # free references but avoid overusing gc.collect()
                    del xb, yb, logits, loss
                # end steps loop
            # end chunk loop
        else:
            # tokenization-on-the-fly (slower) - fallback to streaming chunks
            for ci, chunk in enumerate(stream_text_chunks(DATA_PATH, chunk_size=128_000)):
                ids = sp.encode(chunk, out_type=int)
                if len(ids) < cfg["BLOCK_SIZE"]:
                    continue
                # cap steps per chunk so one chunk cannot monopolize the epoch
                steps = min(1000, max(1, len(ids)//(cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]))))
                for step in range(steps):
                    if stop_requested(): return
                    xb, yb = get_batch_from_ids(ids, cfg["BLOCK_SIZE"], cfg["BATCH_SIZE"], device=DEVICE)
                    try:
                        if use_amp:
                            with autocast_ctx():
                                logits, loss = model(xb, yb)
                            scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
                        else:
                            logits, loss = model(xb, yb)
                            (loss / cfg["ACCUM_STEPS"]).backward()
                    except RuntimeError as e:
                        if "out of memory" in str(e).lower():
                            torch.cuda.empty_cache()
                            print("OOM during training step - skipping step (reduce BATCH_SIZE).")
                            continue
                        else:
                            raise
                    if (step + 1) % cfg["ACCUM_STEPS"] == 0:
                        if use_amp:
                            try:
                                scaler.step(opt)
                                scaler.update()
                            except Exception as e:
                                print("Scaler step failed:", e)
                                opt.step()
                        else:
                            opt.step()
                        opt.zero_grad(set_to_none=True)
                    loss_hist.append(loss.item())
                    global_step += 1
                    steps_since += 1
                    if step % cfg["PRINT_EVERY"] == 0:
                        now = time.time()
                        elapsed = now - last_print_time
                        overall_elapsed = now - t_start
                        sps = steps_since / elapsed if elapsed > 0 else 0.0
                        avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
                        dev_stats = ""
                        try:
                            dev = next(model.parameters()).device
                            if "cuda" in str(dev):
                                dev_idx = torch.cuda.current_device()
                                mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
                                mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
                                dev_name = torch.cuda.get_device_name(dev_idx)
                                dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
                            else:
                                dev_stats = f" | device={dev}"
                        except Exception:
                            dev_stats = ""
                        log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
                                f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
                        last_print_time = now
                        steps_since = 0
                        yield (log, plot_loss(loss_hist))
                    del xb, yb, logits, loss
        # epoch end: save checkpoint
        ck = f"baby_gpt_epoch{epoch+1}.pth"
        try:
            save_checkpoint(model, {**cfg}, ck)
            log += f"💾 Saved checkpoint {ck}\n"
        except Exception as e:
            log += f"❌ Failed saving checkpoint {ck}: {e}\n"
        yield (log, plot_loss(loss_hist))
    # final save
    try:
        save_checkpoint(model, {**cfg}, "baby_gpt_final.pth")
        log += "\n🎉 Training complete! Saved baby_gpt_final.pth\n"
    except Exception as e:
        log += f"\n❌ Failed final save: {e}\n"
    yield (log, plot_loss(loss_hist))
# ---------- minimal CLI ----------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--vocab", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)  # maps to BATCH_SIZE
    parser.add_argument("--bs", type=int, default=None)     # NOTE: maps to BLOCK_SIZE, not batch size
    parser.add_argument("--layers", type=int, default=None)
    parser.add_argument("--pretok", action="store_true", help="Force pretokenization")
    args = parser.parse_args()
    extra = {}
    # truthiness checks: a value of 0 is treated the same as "not provided"
    if args.epochs: extra["EPOCHS"] = args.epochs
    if args.vocab: extra["SP_VOCAB"] = args.vocab
    if args.batch: extra["BATCH_SIZE"] = args.batch
    if args.bs: extra["BLOCK_SIZE"] = args.bs
    if args.layers: extra["NUM_LAYERS"] = args.layers
    if args.pretok: extra["PRETOKENIZE"] = True
    tg = train_generator(**extra)
    try:
        # drain the generator: stream log text to stdout and keep the newest
        # loss plot on disk
        while True:
            out = next(tg)
            text, img = out
            os.write(1, text.encode("utf-8"))
            if img:
                with open("latest_loss.png", "wb") as f:
                    f.write(img)
    except StopIteration:
        print("\nDone.")
    except KeyboardInterrupt:
        # Ctrl-C: flag the (already-exited) loop and report
        request_stop()
        print("\nInterrupted and requested stop.")