# train_babygpt_optimized.py
"""
Optimized training script for small GPT on GPU (Tesla T4) with:
- FP16 mixed precision (autocast + GradScaler) when CUDA is available
- pinned memory + non_blocking transfers
- optional dataset pre-tokenization -> tokens.pt
- torch.compile try/except (won't break saving/loading)
- minimal GC, set_to_none=True for zero_grad
- gradient-checkpointing opt-in
- safer checkpoint saving (saved as float32 on CPU)
- robust autocast context that handles older/newer PyTorch signatures
"""
import os
import io
import time
import math
import gc
import traceback
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

# Enable cuDNN autotuning and TF32 matmuls (safe speedups on Ampere+; no-ops elsewhere).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# CPU thread settings (safe best-effort; ignore if the runtime forbids changing them)
try:
    ncpu = os.cpu_count() or 2
    torch.set_num_threads(ncpu)
    torch.set_num_interop_threads(max(1, ncpu // 2))
except Exception:
    pass

try:
    import sentencepiece as spm
except ImportError:
    raise RuntimeError("Please install sentencepiece: pip install sentencepiece")

# matplotlib is optional: plot_loss() degrades to empty bytes without it.
try:
    import matplotlib.pyplot as plt
except Exception:
    plt = None

# ---------- Paths / defaults ----------
DATA_PATH = "worldsim.txt"
SP_MODEL_PREFIX = "tokenizer"
SP_MODEL_FILE = f"{SP_MODEL_PREFIX}.model"
SP_VOCAB_DEFAULT = 14000
TOKENS_PT = "tokens.pt"  # pre-tokenized cached file

# Device default: set to 'cuda' to use the GPU; will auto-check availability
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# GPU-optimized defaults (tuned for T4: speed + reasonable model capacity)
DEFAULTS = dict(
    BLOCK_SIZE=512,      # shorter sequence = more steps/sec
    BATCH_SIZE=8,        # bigger batch to saturate GPU (tune down if OOM)
    EMBED_DIM=448,       # embed size
    NUM_HEADS=7,
    NUM_LAYERS=10,       # fewer layers => much faster
    DROPOUT=0.05,
    EPOCHS=6,
    LR=2e-5,
    PRINT_EVERY=50,
    ACCUM_STEPS=1,       # don't accumulate unless needed
    SP_VOCAB=SP_VOCAB_DEFAULT,
    GRADIENT_CHECKPOINT=True,  # enabled by default (trade speed for memory)
    PRETOKENIZE=True,    # create tokens.pt to avoid tokenization overhead during training
)

# ---------- stop flag helpers ----------
# Module-level flag used to request cooperative cancellation of train_generator().
_TRAIN_STOP_REQUESTED = False


def request_stop():
    """Ask the running training loop to stop at the next step boundary."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = True


def clear_stop_request():
    """Reset the stop flag (called at the start of every training run)."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = False


def stop_requested():
    """Return True if a stop has been requested."""
    return _TRAIN_STOP_REQUESTED


# ---------- sentencepiece helpers ----------
def ensure_sp_model(data_path=DATA_PATH, model_prefix=SP_MODEL_PREFIX,
                    vocab_size=SP_VOCAB_DEFAULT):
    """Load the SentencePiece model if present, otherwise train one from data_path.

    Returns a loaded SentencePieceProcessor.
    """
    if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
        sp = spm.SentencePieceProcessor()
        sp.load(f"{model_prefix}.model")
        return sp
    # train sentencepiece (BPE, no BOS/EOS ids — the LM operates on raw token streams)
    spm.SentencePieceTrainer.train(
        input=data_path,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        model_type="bpe",
        character_coverage=1.0,
        unk_id=0,
        bos_id=-1,
        eos_id=-1,
    )
    sp = spm.SentencePieceProcessor()
    sp.load(f"{model_prefix}.model")
    return sp


# ---------- model ----------
class BabyGPT(nn.Module):
    """Small decoder-only LM built from TransformerEncoderLayers + a causal mask."""

    def __init__(self, vocab_size, embed_dim, block_size, num_heads, num_layers,
                 dropout, use_checkpoint=False):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self._use_gradient_checkpointing = use_checkpoint

    def enable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = False

    def forward(self, idx, targets=None):
        """Compute logits (B, T, vocab) and, if targets is given, cross-entropy loss."""
        B, T = idx.shape
        device = idx.device
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        # BUG FIX: autoregressive LM must not attend to future positions.
        # Additive mask: -inf above the diagonal blocks attention to j > i.
        causal_mask = torch.triu(
            torch.full((T, T), float("-inf"), device=device), diagonal=1
        )
        for layer in self.layers:
            if self._use_gradient_checkpointing and self.training:
                # must pass tensors only; explicit use_reentrant for PyTorch >=2.5 compatibility
                def run_layer(x_local, layer_local=layer):
                    return layer_local(x_local, src_mask=causal_mask)
                # pass use_reentrant explicitly to avoid PyTorch warning in 2.5+
                x = checkpoint.checkpoint(run_layer, x, use_reentrant=False)
            else:
                x = layer(x, src_mask=causal_mask)
        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=50):
        """Autoregressively sample max_new_tokens tokens, appending to idx.

        temperature <= 0 is treated as 1.0; top_k <= 0 disables top-k filtering.
        The model's train/eval mode is restored on exit (bug fix: the original
        left the model permanently in eval mode).
        """
        was_training = self.training
        self.eval()
        device = next(self.parameters()).device
        idx = idx.to(device)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:].to(device)
            logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :].to(torch.float32)
            last_logits = last_logits / (temperature if temperature > 0 else 1.0)
            if top_k > 0:
                v, _ = torch.topk(last_logits, top_k)
                last_logits[last_logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(last_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id.to(idx.device)), dim=1)
        if was_training:
            self.train()
        return idx


# ---------- data streaming & token caching ----------
def stream_text_chunks(path=DATA_PATH, chunk_size=128_000):
    """Yield the text file in chunk_size-character chunks (streaming, low memory)."""
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            yield chunk


def build_or_load_token_cache(sp, path=DATA_PATH, tokens_pt=TOKENS_PT, chunk_size=128_000):
    """
    If tokens.pt exists, load it. Otherwise tokenize the dataset (streaming)
    and save a 1-D tensor of token ids to tokens.pt for fast subsequent
    training runs.
    """
    if os.path.exists(tokens_pt):
        try:
            tokens = torch.load(tokens_pt, map_location="cpu")
            if isinstance(tokens, torch.Tensor):
                return tokens
        except Exception:
            # cache unreadable/corrupt — fall through and rebuild it
            pass
    print("Pre-tokenizing dataset to", tokens_pt, " — this may take a while (one-time).")
    all_ids = []
    for chunk in stream_text_chunks(path, chunk_size):
        ids = sp.encode(chunk, out_type=int)
        all_ids.extend(ids)
    tokens = torch.tensor(all_ids, dtype=torch.long)
    torch.save(tokens, tokens_pt)
    return tokens


def get_batch_from_ids(ids, block_size, batch_size, device="cpu"):
    """
    ids: 1D list or 1D torch.Tensor of token ids.
    Returns (x, y) on target device with pinned-memory & non_blocking copy
    when using CUDA. y is x shifted by one token (next-token targets).

    Raises ValueError if ids is shorter than block_size + 1.
    """
    dev = torch.device(device)
    use_cuda = (dev.type == "cuda")
    if isinstance(ids, torch.Tensor):
        L = ids.numel()
    else:
        L = len(ids)
    if L < block_size + 1:
        raise ValueError("ids too short for block size")
    # BUG FIX: valid starts are 0 .. L - block_size - 1 inclusive, so the
    # exclusive upper bound is L - block_size (the original off-by-one never
    # sampled the final window).
    ix = torch.randint(0, L - block_size, (batch_size,))
    # allocate pinned buffers if sending to CUDA (enables async H2D copies)
    if use_cuda:
        x = torch.empty((batch_size, block_size), dtype=torch.long).pin_memory()
        y = torch.empty((batch_size, block_size), dtype=torch.long).pin_memory()
    else:
        x = torch.empty((batch_size, block_size), dtype=torch.long)
        y = torch.empty((batch_size, block_size), dtype=torch.long)
    for bi, i in enumerate(ix):
        if torch.is_tensor(ids):
            seg = ids[i:i + block_size + 1].tolist()
        else:
            seg = ids[i:i + block_size + 1]
        x[bi].copy_(torch.tensor(seg[:-1], dtype=torch.long))
        y[bi].copy_(torch.tensor(seg[1:], dtype=torch.long))
    if use_cuda:
        return x.to(dev, non_blocking=True), y.to(dev, non_blocking=True)
    else:
        return x.to(dev), y.to(dev)


# ---------- plotting ----------
def plot_loss(history):
    """Render the loss history as PNG bytes; empty bytes when matplotlib is absent."""
    if not plt:
        return b""
    plt.switch_backend("Agg")
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(history)
    ax.set_title("Loss")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    return buf.read()


# ---------- checkpoint utils ----------
def _unwrap_model_for_saving(model: nn.Module) -> nn.Module:
    """Return underlying module if a compiled wrapper added attributes like _orig_mod."""
    if hasattr(model, "_orig_mod"):
        return model._orig_mod
    # some wrappers use .module
    if hasattr(model, "module"):
        return model.module
    return model


def save_checkpoint(model, cfg, path):
    """Save {'model_state_dict', 'config'} to path.

    Weights are detached to CPU float32 so the checkpoint loads on any device
    regardless of the training dtype. Falls back to a plain save on failure;
    re-raises if that also fails.
    """
    try:
        real_model = _unwrap_model_for_saving(model)
        state = real_model.state_dict()
        cpu_state = {k: v.detach().cpu().to(torch.float32) for k, v in state.items()}
        data = {'model_state_dict': cpu_state, 'config': cfg}
        torch.save(data, path)
    except Exception as e:
        # fallback: try normal save (best effort)
        try:
            torch.save({'model_state_dict': model.state_dict(), 'config': cfg}, path)
        except Exception as e2:
            print("Failed to save checkpoint:", e2)
            raise


def latest_checkpoint():
    """Most recently modified per-epoch checkpoint, else the final one, else None."""
    ckpts = [f for f in os.listdir(".")
             if f.startswith("baby_gpt_epoch") and f.endswith(".pth")]
    if ckpts:
        return sorted(ckpts, key=os.path.getmtime)[-1]
    return "baby_gpt_final.pth" if os.path.exists("baby_gpt_final.pth") else None


def _strip_orig_mod_prefix(state_dict: dict) -> dict:
    """Strip leading _orig_mod. prefix from keys if present."""
    keys = list(state_dict.keys())
    # detect if keys have the prefix
    if any(k.startswith("_orig_mod.") for k in keys):
        return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    # also support accidental 'module.' prefix
    if any(k.startswith("module.") for k in keys) and not any(k.startswith("tok_emb") for k in keys):
        return {k.replace("module.", ""): v for k, v in state_dict.items()}
    return state_dict


def load_model_for_inference(checkpoint=None, device=DEVICE):
    """Load a checkpoint (latest by default) and return (model, sp, cfg).

    Raises FileNotFoundError when no checkpoint exists and RuntimeError when
    the state_dict cannot be loaded at all.
    """
    ck = checkpoint or latest_checkpoint()
    if not ck:
        raise FileNotFoundError("No checkpoint.")
    data = torch.load(ck, map_location="cpu")
    sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, DEFAULTS["SP_VOCAB"])
    cfg = data.get("config", DEFAULTS)
    model = BabyGPT(sp.get_piece_size(), cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
                    cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
                    use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False)).to(device)
    state = data["model_state_dict"]
    state = _strip_orig_mod_prefix(state)
    # try load, try forgiving missing/unexpected keys by strict=False first, then stricter load
    try:
        model.load_state_dict(state, strict=False)
    except Exception as e:
        # final fallback: try to match only exact keys
        try:
            model.load_state_dict(state)
        except Exception as e2:
            raise RuntimeError(f"Failed to load state_dict: {e2}")
    model.eval()
    return model, sp, cfg


# ---------- autocast helper ----------
def autocast_ctx():
    """
    Return a context manager for autocast. Robust to PyTorch versions which
    may accept different signatures for torch.cuda.amp.autocast.
    Returns a no-op context when CUDA/AMP is unavailable.
    """
    class _NopCtx:
        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc, tb):
            return False

    if not (torch.cuda.is_available() and hasattr(torch.cuda, "amp")):
        return _NopCtx()
    # try modern signature with dtype
    try:
        return torch.cuda.amp.autocast(dtype=torch.float16)
    except TypeError:
        # older signatures may accept no args or device_type
        try:
            return torch.cuda.amp.autocast()
        except Exception:
            return _NopCtx()


# ---------- train generator ----------
def train_generator(**params):
    """
    Yields (log_text, plot_png_bytes) pairs (same behavior as your Gradio app expects).
    Accepts overrides for any DEFAULTS keys by passing them into train_generator(...).
    """
    clear_stop_request()
    cfg = {**DEFAULTS, **{k: v for k, v in params.items() if v is not None}}

    # SentencePiece
    sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, cfg.get("SP_VOCAB", SP_VOCAB_DEFAULT))

    # Optional pre-tokenize
    tokens_cache = None
    if cfg.get("PRETOKENIZE", True):
        tokens_cache = build_or_load_token_cache(sp, DATA_PATH, TOKENS_PT, chunk_size=128_000)

    # Build model (vocab from sp)
    vocab_size = sp.get_piece_size()
    model = BabyGPT(vocab_size, cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
                    cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
                    use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False))

    # gradient checkpointing opt-in
    if cfg.get("GRADIENT_CHECKPOINT", False):
        try:
            model.enable_gradient_checkpointing()
            print("Gradient checkpointing enabled.")
        except Exception as e:
            print("Could not enable gradient checkpointing:", e)

    # Choose dtype/device strategy
    # On CUDA: keep params float32 and use FP16 autocast + GradScaler (T4 prefers FP16).
    chosen_dtype = torch.float32
    if DEVICE.startswith("cuda") and torch.cuda.is_available():
        model = model.to(torch.float32)  # keep params float32
        chosen_dtype = torch.float32
        print("Using float32 params on CUDA. Mixed FP16 autocast will be used during forward.")
    else:
        # On CPU, try bfloat16 if available (rare)
        try:
            _ = torch.empty(1, dtype=torch.bfloat16)
            model = model.to(dtype=torch.bfloat16)
            chosen_dtype = torch.bfloat16
            print("Using bfloat16 on CPU.")
        except Exception:
            model = model.to(dtype=torch.float32)
            chosen_dtype = torch.float32
            print("Using float32 on CPU.")

    # Move model to device
    model = model.to(DEVICE)
    try:
        dev = next(model.parameters()).device
        device_info = str(dev)
        if "cuda" in device_info:
            dev_idx = torch.cuda.current_device()
            dev_name = torch.cuda.get_device_name(dev_idx)
            mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
            mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
            device_info = f"{device_info} ({dev_name}) alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
        print("Model moved to device:", device_info)
    except Exception:
        pass

    # Build optimizer
    opt = torch.optim.AdamW(model.parameters(), lr=cfg["LR"])

    # CUDA-specific optimizations
    use_amp = (DEVICE.startswith("cuda") and torch.cuda.is_available())
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    if use_amp:
        torch.backends.cudnn.benchmark = True
        # optional TF32 speedup
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
        except Exception:
            pass

    # Try torch.compile for speedups (best effort).
    # BUG FIX: the original block also ran torch.cuda.synchronize() and timer
    # setup here; any failure of those lines (e.g. on CPU) landed in the
    # except clause and falsely reported "torch.compile unavailable".
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("torch.compile applied.")
    except Exception as e:
        print("torch.compile unavailable:", e)

    # training prep
    loss_hist = []
    total_params_m = sum(p.numel() for p in model.parameters()) / 1e6
    log = f"šŸš€ Training started | {total_params_m:.2f}M params | dtype={chosen_dtype}\n"
    yield (log, plot_loss(loss_hist))

    model.train()
    global_step = 0
    # timing
    t_start = time.time()
    last_print_time = t_start
    steps_since = 0

    # read ids source: prefer pre-tokenized cache
    ids_source = tokens_cache if tokens_cache is not None else None

    for epoch in range(cfg["EPOCHS"]):
        log += f"\n=== Epoch {epoch+1}/{cfg['EPOCHS']} ===\n"
        # If we have tokens cache: iterate over chunks of the token stream
        if ids_source is not None:
            L = ids_source.numel()
            chunk_length = cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]) * 2
            num_chunks = max(1, L // chunk_length)
            for ci in range(num_chunks):
                if stop_requested():
                    return
                start = (ci * chunk_length) % max(1, L - chunk_length)
                chunk_ids = ids_source[start:start + chunk_length].tolist()
                steps = max(1, len(chunk_ids) // (cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"])))
                for step in range(steps):
                    if stop_requested():
                        return
                    xb, yb = get_batch_from_ids(chunk_ids, cfg["BLOCK_SIZE"],
                                                cfg["BATCH_SIZE"], device=DEVICE)
                    # forward/backward with AMP if available
                    try:
                        if use_amp:
                            with autocast_ctx():
                                logits, loss = model(xb, yb)
                            scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
                        else:
                            logits, loss = model(xb, yb)
                            (loss / cfg["ACCUM_STEPS"]).backward()
                    except RuntimeError as e:
                        # handle occasional OOM gracefully by reducing batch
                        if "out of memory" in str(e).lower():
                            torch.cuda.empty_cache()
                            print("OOM during training step - skipping step (reduce BATCH_SIZE).")
                            continue
                        else:
                            raise
                    if (step + 1) % cfg["ACCUM_STEPS"] == 0:
                        if use_amp:
                            try:
                                scaler.step(opt)
                                scaler.update()
                            except Exception as e:
                                print("Scaler step failed:", e)
                                opt.step()
                        else:
                            opt.step()
                        opt.zero_grad(set_to_none=True)
                    loss_hist.append(loss.item())
                    global_step += 1
                    steps_since += 1
                    if step % cfg["PRINT_EVERY"] == 0:
                        # BUG FIX: synchronize only when CUDA is in use (this
                        # call raises on CPU-only runs).
                        if use_amp:
                            torch.cuda.synchronize()
                        now = time.time()
                        elapsed = now - last_print_time
                        overall_elapsed = now - t_start
                        sps = steps_since / elapsed if elapsed > 0 else 0.0
                        avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
                        # device stats
                        dev_stats = ""
                        try:
                            dev = next(model.parameters()).device
                            if "cuda" in str(dev):
                                dev_idx = torch.cuda.current_device()
                                mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
                                mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
                                dev_name = torch.cuda.get_device_name(dev_idx)
                                dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
                            else:
                                dev_stats = f" | device={dev}"
                        except Exception:
                            dev_stats = ""
                        log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
                                f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
                        last_print_time = now
                        steps_since = 0
                        yield (log, plot_loss(loss_hist))
                    # free references but avoid overusing gc.collect()
                    del xb, yb, logits, loss
                # end steps loop
            # end chunk loop
        else:
            # tokenization-on-the-fly (slower) - fallback to streaming chunks
            for ci, chunk in enumerate(stream_text_chunks(DATA_PATH, chunk_size=128_000)):
                ids = sp.encode(chunk, out_type=int)
                if len(ids) < cfg["BLOCK_SIZE"]:
                    continue
                steps = min(1000, max(1, len(ids) // (cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]))))
                for step in range(steps):
                    if stop_requested():
                        return
                    xb, yb = get_batch_from_ids(ids, cfg["BLOCK_SIZE"],
                                                cfg["BATCH_SIZE"], device=DEVICE)
                    try:
                        if use_amp:
                            with autocast_ctx():
                                logits, loss = model(xb, yb)
                            scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
                        else:
                            logits, loss = model(xb, yb)
                            (loss / cfg["ACCUM_STEPS"]).backward()
                    except RuntimeError as e:
                        if "out of memory" in str(e).lower():
                            torch.cuda.empty_cache()
                            print("OOM during training step - skipping step (reduce BATCH_SIZE).")
                            continue
                        else:
                            raise
                    if (step + 1) % cfg["ACCUM_STEPS"] == 0:
                        if use_amp:
                            try:
                                scaler.step(opt)
                                scaler.update()
                            except Exception as e:
                                print("Scaler step failed:", e)
                                opt.step()
                        else:
                            opt.step()
                        opt.zero_grad(set_to_none=True)
                    loss_hist.append(loss.item())
                    global_step += 1
                    steps_since += 1
                    if step % cfg["PRINT_EVERY"] == 0:
                        now = time.time()
                        elapsed = now - last_print_time
                        overall_elapsed = now - t_start
                        sps = steps_since / elapsed if elapsed > 0 else 0.0
                        avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
                        dev_stats = ""
                        try:
                            dev = next(model.parameters()).device
                            if "cuda" in str(dev):
                                dev_idx = torch.cuda.current_device()
                                mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
                                mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
                                dev_name = torch.cuda.get_device_name(dev_idx)
                                dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
                            else:
                                dev_stats = f" | device={dev}"
                        except Exception:
                            dev_stats = ""
                        log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
                                f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
                        last_print_time = now
                        steps_since = 0
                        yield (log, plot_loss(loss_hist))
                    del xb, yb, logits, loss

        # epoch end: save checkpoint
        ck = f"baby_gpt_epoch{epoch+1}.pth"
        try:
            save_checkpoint(model, {**cfg}, ck)
            log += f"šŸ’¾ Saved checkpoint {ck}\n"
        except Exception as e:
            log += f"āŒ Failed saving checkpoint {ck}: {e}\n"
        yield (log, plot_loss(loss_hist))

    # final save
    try:
        save_checkpoint(model, {**cfg}, "baby_gpt_final.pth")
        log += "\nšŸŽ‰ Training complete! Saved baby_gpt_final.pth\n"
    except Exception as e:
        log += f"\nāŒ Failed final save: {e}\n"
    yield (log, plot_loss(loss_hist))


# ---------- minimal CLI ----------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--vocab", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--bs", type=int, default=None)
    parser.add_argument("--layers", type=int, default=None)
    parser.add_argument("--pretok", action="store_true", help="Force pretokenization")
    args = parser.parse_args()

    extra = {}
    if args.epochs:
        extra["EPOCHS"] = args.epochs
    if args.vocab:
        extra["SP_VOCAB"] = args.vocab
    if args.batch:
        extra["BATCH_SIZE"] = args.batch
    if args.bs:
        extra["BLOCK_SIZE"] = args.bs
    if args.layers:
        extra["NUM_LAYERS"] = args.layers
    if args.pretok:
        extra["PRETOKENIZE"] = True

    tg = train_generator(**extra)
    try:
        while True:
            out = next(tg)
            text, img = out
            # NOTE: text is the cumulative log, re-emitted in full on each yield
            os.write(1, text.encode("utf-8"))
            if img:
                with open("latest_loss.png", "wb") as f:
                    f.write(img)
    except StopIteration:
        print("\nDone.")
    except KeyboardInterrupt:
        request_stop()
        print("\nInterrupted and requested stop.")