# NOTE: removed non-Python artifacts left over from the file-hosting page
# header ("pt / v2", "Roman190928's picture", "Create v2", "d187d2d verified").
# train_babygpt_optimized.py
"""
Optimized training script for small GPT on GPU (Tesla T4) with:
- FP16 mixed precision (autocast + GradScaler) when CUDA is available
- pinned memory + non_blocking transfers
- optional dataset pre-tokenization -> tokens.pt
- torch.compile try/except (won't break saving/loading)
- minimal GC, set_to_none=True for zero_grad
- gradient-checkpointing opt-in (disabled by default for speed)
- safer checkpoint saving (saved as float32 on CPU)
- robust autocast context that handles older/newer PyTorch signatures
"""
import os
import io
import time
import math
import gc
import traceback
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
# Enable cuDNN autotuning and TF32 matmul paths: trades bit-exact FP32
# reproducibility for throughput on Ampere+ GPUs (harmless no-ops elsewhere).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# CPU thread settings (safe): intra-op pool sized to all cores, inter-op to half.
try:
    ncpu = os.cpu_count() or 2  # os.cpu_count() may return None
    torch.set_num_threads(ncpu)
    torch.set_num_interop_threads(max(1, ncpu // 2))
except Exception:
    # set_num_interop_threads raises if parallel work has already started;
    # this is best-effort tuning only, so swallow the error.
    pass
try:
import sentencepiece as spm
except ImportError:
raise RuntimeError("Please install sentencepiece: pip install sentencepiece")
try:
import matplotlib.pyplot as plt
except Exception:
plt = None
# ---------- Paths / defaults ----------
DATA_PATH = "worldsim.txt"                 # training corpus (plain UTF-8 text)
SP_MODEL_PREFIX = "tokenizer"
SP_MODEL_FILE = f"{SP_MODEL_PREFIX}.model"
SP_VOCAB_DEFAULT = 14000
TOKENS_PT = "tokens.pt"  # pre-tokenized cached file
# Device default: set to 'cuda' to use the GPU; will auto-check availability
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# GPU-optimized defaults (tuned for T4: speed + reasonable model capacity)
DEFAULTS = dict(
    BLOCK_SIZE=512,      # context length; shorter sequence = more steps/sec
    BATCH_SIZE=8,        # bigger batch to saturate GPU (tune down if OOM)
    EMBED_DIM=448,       # embed size (must divide by NUM_HEADS: 448 / 7 = 64)
    NUM_HEADS=7,
    NUM_LAYERS=10,       # fewer layers => much faster
    DROPOUT=0.05,
    EPOCHS=6,
    LR=2e-5,
    PRINT_EVERY=50,      # steps between log/plot yields
    ACCUM_STEPS=1,       # don't accumulate unless needed
    SP_VOCAB=SP_VOCAB_DEFAULT,
    GRADIENT_CHECKPOINT=True,  # NOTE: enabled here (memory over speed), despite the module docstring calling it opt-in
    PRETOKENIZE=True,    # create tokens.pt to avoid tokenization overhead during training
)
# ---------- stop flag helpers ----------
_TRAIN_STOP_REQUESTED = False

def request_stop():
    """Ask the training loop to exit at its next stop-flag check."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = True

def clear_stop_request():
    """Reset the stop flag (called at the start of a training run)."""
    global _TRAIN_STOP_REQUESTED
    _TRAIN_STOP_REQUESTED = False

def stop_requested():
    """Return True if request_stop() has been called since the last clear."""
    return _TRAIN_STOP_REQUESTED
# ---------- sentencepiece helpers ----------
def ensure_sp_model(data_path=DATA_PATH, model_prefix=SP_MODEL_PREFIX, vocab_size=SP_VOCAB_DEFAULT):
if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
sp = spm.SentencePieceProcessor()
sp.load(f"{model_prefix}.model")
return sp
# train sentencepiece
spm.SentencePieceTrainer.train(
input=data_path,
model_prefix=model_prefix,
vocab_size=vocab_size,
model_type="bpe",
character_coverage=1.0,
unk_id=0, bos_id=-1, eos_id=-1,
)
sp = spm.SentencePieceProcessor()
sp.load(f"{model_prefix}.model")
return sp
# ---------- model ----------
# ---------- model ----------
class BabyGPT(nn.Module):
    """Minimal GPT-style decoder-only LM built from TransformerEncoderLayers.

    FIX: a causal (upper-triangular) attention mask is now applied in
    forward(). Without it each position could attend to future tokens, so the
    next-token training objective was able to read its own targets.
    """

    def __init__(self, vocab_size, embed_dim, block_size, num_heads, num_layers, dropout, use_checkpoint=False):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)  # learned absolute positions
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
                dropout=dropout, batch_first=True
            ) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self._use_gradient_checkpointing = use_checkpoint

    def enable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self._use_gradient_checkpointing = False

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None unless `targets` is given.

        idx / targets: (B, T) long tensors, T <= block_size.
        logits: (B, T, vocab_size).
        """
        B, T = idx.shape
        device = idx.device
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        # Causal mask: True above the diagonal marks disallowed (future) keys.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1
        )
        for layer in self.layers:
            if self._use_gradient_checkpointing and self.training:
                # must pass tensors only; explicit use_reentrant for PyTorch >=2.5 compatibility
                def run_layer(x_local, layer_local=layer, mask_local=causal_mask):
                    return layer_local(x_local, src_mask=mask_local)
                # pass use_reentrant explicitly to avoid PyTorch warning in 2.5+
                x = checkpoint.checkpoint(run_layer, x, use_reentrant=False)
            else:
                x = layer(x, src_mask=causal_mask)
        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=50):
        """Autoregressively sample max_new_tokens ids; returns the extended idx.

        NOTE: top_k must not exceed vocab_size (torch.topk would raise);
        pass top_k=0 to disable top-k filtering.
        """
        self.eval()
        device = next(self.parameters()).device
        idx = idx.to(device)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:].to(device)
            logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :].to(torch.float32)
            # temperature <= 0 would divide by zero; treat it as 1.0
            last_logits = last_logits / (temperature if temperature > 0 else 1.0)
            if top_k > 0:
                v, _ = torch.topk(last_logits, top_k)
                last_logits[last_logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(last_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id.to(idx.device)), dim=1)
        return idx
# ---------- data streaming & token caching ----------
# ---------- data streaming & token caching ----------
def stream_text_chunks(path=DATA_PATH, chunk_size=128_000):
    """Lazily yield the text file at `path` in pieces of chunk_size characters."""
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        # iter(callable, sentinel): stop when read() returns the empty string
        for piece in iter(lambda: fh.read(chunk_size), ""):
            yield piece
def build_or_load_token_cache(sp, path=DATA_PATH, tokens_pt=TOKENS_PT, chunk_size=128_000):
    """
    Return a 1-D LongTensor of token ids for the dataset at `path`.

    Loads the cached tensor from `tokens_pt` when present and readable;
    otherwise tokenizes the file in streaming chunks with `sp` and writes
    the cache so subsequent runs start fast.
    """
    if os.path.exists(tokens_pt):
        try:
            cached = torch.load(tokens_pt, map_location="cpu")
        except Exception:
            cached = None  # corrupt/unreadable cache: fall through and rebuild
        if isinstance(cached, torch.Tensor):
            return cached
    print("Pre-tokenizing dataset to", tokens_pt, " — this may take a while (one-time).")
    ids = []
    for piece in stream_text_chunks(path, chunk_size):
        ids.extend(sp.encode(piece, out_type=int))
    tokens = torch.tensor(ids, dtype=torch.long)
    torch.save(tokens, tokens_pt)
    return tokens
def get_batch_from_ids(ids, block_size, batch_size, device="cpu"):
    """
    Sample a random batch of contiguous (input, target) windows from a token stream.

    ids: 1-D list or 1-D torch.Tensor of token ids.
    Returns (x, y): each a (batch_size, block_size) long tensor on `device`,
    where y is x shifted left by one token. Uses pinned memory + non_blocking
    copies when targeting CUDA.

    Raises ValueError if fewer than block_size + 1 ids are available.
    """
    dev = torch.device(device)
    use_cuda = (dev.type == "cuda")
    # normalize to a CPU long tensor once so window gathering is a single index op
    if isinstance(ids, torch.Tensor):
        ids_t = ids.to(dtype=torch.long, device="cpu")
    else:
        ids_t = torch.tensor(ids, dtype=torch.long)
    L = ids_t.numel()
    if L < block_size + 1:
        raise ValueError("ids too short for block size")
    # Valid start offsets are [0, L - block_size - 1]; randint's high is
    # exclusive, so high = L - block_size. (The previous high of
    # L - block_size - 1 skipped the last valid window and crashed with
    # randint(0, 0) when L == block_size + 1.)
    ix = torch.randint(0, L - block_size, (batch_size,))
    # offsets[b, t] = ix[b] + t  -> gather all windows in one vectorized index
    offsets = ix.unsqueeze(1) + torch.arange(block_size + 1)
    windows = ids_t[offsets]  # (batch_size, block_size + 1)
    x = windows[:, :-1].contiguous()
    y = windows[:, 1:].contiguous()
    if use_cuda:
        # pinned staging buffers allow asynchronous host-to-device copies
        return (x.pin_memory().to(dev, non_blocking=True),
                y.pin_memory().to(dev, non_blocking=True))
    return x.to(dev), y.to(dev)
# ---------- plotting ----------
def plot_loss(history):
if not plt:
return b""
plt.switch_backend("Agg")
fig, ax = plt.subplots(figsize=(5,3))
ax.plot(history)
ax.set_title("Loss")
buf = io.BytesIO()
fig.savefig(buf, format="png")
buf.seek(0)
plt.close(fig)
return buf.read()
# ---------- checkpoint utils ----------
def _unwrap_model_for_saving(model: nn.Module) -> nn.Module:
"""Return underlying module if a compiled wrapper added attributes like _orig_mod."""
if hasattr(model, "_orig_mod"):
return model._orig_mod
# some wrappers use .module
if hasattr(model, "module"):
return model.module
return model
def save_checkpoint(model, cfg, path):
    """Save model weights + config dict to `path`.

    Weights are stored on CPU with floating-point tensors cast to float32 so
    the checkpoint loads on any device regardless of training dtype.
    FIX: integer/bool state tensors (e.g. BatchNorm's num_batches_tracked)
    are kept as-is — the previous unconditional .to(torch.float32) silently
    corrupted them.
    """
    try:
        real_model = _unwrap_model_for_saving(model)
        cpu_state = {}
        for k, v in real_model.state_dict().items():
            t = v.detach().cpu()
            if t.is_floating_point():
                t = t.to(torch.float32)
            cpu_state[k] = t
        torch.save({'model_state_dict': cpu_state, 'config': cfg}, path)
    except Exception:
        # fallback: try a plain save (best effort) before giving up
        try:
            torch.save({'model_state_dict': model.state_dict(), 'config': cfg}, path)
        except Exception as e2:
            print("Failed to save checkpoint:", e2)
            raise
def latest_checkpoint():
    """Return the newest epoch checkpoint filename in the cwd.

    Falls back to "baby_gpt_final.pth" when no epoch checkpoints exist;
    returns None when nothing has been saved yet.
    """
    epoch_ckpts = [
        name for name in os.listdir(".")
        if name.startswith("baby_gpt_epoch") and name.endswith(".pth")
    ]
    if epoch_ckpts:
        # newest by modification time
        return sorted(epoch_ckpts, key=os.path.getmtime)[-1]
    if os.path.exists("baby_gpt_final.pth"):
        return "baby_gpt_final.pth"
    return None
def _strip_orig_mod_prefix(state_dict: dict) -> dict:
    """Strip a leading '_orig_mod.' (torch.compile) or 'module.' (DataParallel)
    prefix from state_dict keys, returning a new dict when stripping applies.

    FIX: str.replace was previously used, which would also mangle any key
    containing the prefix text mid-string; only the leading occurrence is
    removed now.
    """
    keys = list(state_dict.keys())
    if any(k.startswith("_orig_mod.") for k in keys):
        prefix = "_orig_mod."
    elif any(k.startswith("module.") for k in keys) and not any(k.startswith("tok_emb") for k in keys):
        # 'module.' is only stripped when keys don't already look native
        prefix = "module."
    else:
        return state_dict
    n = len(prefix)
    return {(k[n:] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
def load_model_for_inference(checkpoint=None, device=DEVICE):
    """Load a saved BabyGPT checkpoint for generation.

    checkpoint: path to a .pth file; defaults to the newest checkpoint on disk.
    Returns (model, sp, cfg): the eval-mode model on `device`, the
    SentencePiece processor, and the config dict stored in the checkpoint
    (falling back to DEFAULTS when the checkpoint carries no config).
    Raises FileNotFoundError when no checkpoint exists.
    """
    ck = checkpoint or latest_checkpoint()
    if not ck:
        raise FileNotFoundError("No checkpoint.")
    # checkpoints are written as CPU float32 (see save_checkpoint), so map to CPU first
    data = torch.load(ck, map_location="cpu")
    sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, DEFAULTS["SP_VOCAB"])
    cfg = data.get("config", DEFAULTS)
    model = BabyGPT(sp.get_piece_size(), cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
                    cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
                    use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False)).to(device)
    state = data["model_state_dict"]
    state = _strip_orig_mod_prefix(state)
    # try load, try forgiving missing/unexpected keys by strict=False first, then stricter load
    # NOTE(review): strict=False rarely raises (only on shape mismatch), so the
    # strict fallback below is mostly unreachable — consider reversing the order.
    try:
        model.load_state_dict(state, strict=False)
    except Exception as e:
        # final fallback: try to match only exact keys
        try:
            model.load_state_dict(state)
        except Exception as e2:
            raise RuntimeError(f"Failed to load state_dict: {e2}")
    model.eval()
    return model, sp, cfg
# ---------- autocast helper ----------
def autocast_ctx():
"""
Return a context manager for autocast. Robust to PyTorch versions which may accept
different signatures for torch.cuda.amp.autocast.
"""
if not (torch.cuda.is_available() and hasattr(torch.cuda, "amp")):
# dummy context
class _NopCtx:
def __enter__(self): return None
def __exit__(self, exc_type, exc, tb): return False
return _NopCtx()
# try modern signature with dtype
try:
return torch.cuda.amp.autocast(dtype=torch.float16)
except TypeError:
# older signatures may accept no args or device_type
try:
return torch.cuda.amp.autocast()
except Exception:
class _NopCtx:
def __enter__(self): return None
def __exit__(self, exc_type, exc, tb): return False
return _NopCtx()
# ---------- train generator ----------
def train_generator(**params):
"""
Yields (log_text, plot_png_bytes) pairs (same behavior as your Gradio app expects).
Accepts overrides for any DEFAULTS keys by passing them into train_generator(...).
"""
clear_stop_request()
cfg = {**DEFAULTS, **{k: v for k, v in params.items() if v is not None}}
# SentencePiece
sp = ensure_sp_model(DATA_PATH, SP_MODEL_PREFIX, cfg.get("SP_VOCAB", SP_VOCAB_DEFAULT))
# Optional pre-tokenize
tokens_cache = None
if cfg.get("PRETOKENIZE", True):
tokens_cache = build_or_load_token_cache(sp, DATA_PATH, TOKENS_PT, chunk_size=128_000)
# Build model (vocab from sp)
vocab_size = sp.get_piece_size()
model = BabyGPT(vocab_size, cfg["EMBED_DIM"], cfg["BLOCK_SIZE"],
cfg["NUM_HEADS"], cfg["NUM_LAYERS"], cfg["DROPOUT"],
use_checkpoint=cfg.get("GRADIENT_CHECKPOINT", False))
# gradient checkpointing opt-in
if cfg.get("GRADIENT_CHECKPOINT", False):
try:
model.enable_gradient_checkpointing()
print("Gradient checkpointing enabled.")
except Exception as e:
print("Could not enable gradient checkpointing:", e)
# Choose dtype/device strategy
# On CUDA: keep params float32 and use FP16 autocast + GradScaler (T4 prefers FP16).
chosen_dtype = torch.float32
if DEVICE.startswith("cuda") and torch.cuda.is_available():
model = model.to(torch.float32) # keep params float32
chosen_dtype = torch.float32
print("Using float32 params on CUDA. Mixed FP16 autocast will be used during forward.")
else:
# On CPU, try bfloat16 if available (rare)
try:
_ = torch.empty(1, dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)
chosen_dtype = torch.bfloat16
print("Using bfloat16 on CPU.")
except Exception:
model = model.to(dtype=torch.float32)
chosen_dtype = torch.float32
print("Using float32 on CPU.")
# Move model to device
model = model.to(DEVICE)
try:
dev = next(model.parameters()).device
device_info = str(dev)
if "cuda" in device_info:
dev_idx = torch.cuda.current_device()
dev_name = torch.cuda.get_device_name(dev_idx)
mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
device_info = f"{device_info} ({dev_name}) alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
print("Model moved to device:", device_info)
except Exception:
pass
# Build optimizer
opt = torch.optim.AdamW(model.parameters(), lr=cfg["LR"])
# CUDA-specific optimizations
use_amp = (DEVICE.startswith("cuda") and torch.cuda.is_available())
scaler = torch.cuda.amp.GradScaler() if use_amp else None
if use_amp:
torch.backends.cudnn.benchmark = True
# optional TF32 speedup
try:
torch.backends.cuda.matmul.allow_tf32 = True
except Exception:
pass
# Try torch.compile for speedups (best effort)
try:
model = torch.compile(model, mode="reduce-overhead")
print("torch.compile applied.")
torch.cuda.synchronize()
t_start = time.time()
last_print_time = t_start
steps_since = 0
except Exception as e:
print("torch.compile unavailable:", e)
# training prep
loss_hist = []
total_params_m = sum(p.numel() for p in model.parameters())/1e6
log = f"🚀 Training started | {total_params_m:.2f}M params | dtype={chosen_dtype}\n"
yield (log, plot_loss(loss_hist))
model.train()
global_step = 0
# timing
t_start = time.time()
last_print_time = t_start
steps_since = 0
# read ids source: prefer pre-tokenized cache
ids_source = tokens_cache if tokens_cache is not None else None
for epoch in range(cfg["EPOCHS"]):
log += f"\n=== Epoch {epoch+1}/{cfg['EPOCHS']} ===\n"
# If we have tokens cache: iterate over chunks of the token stream
if ids_source is not None:
L = ids_source.numel()
chunk_length = cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]) * 2
num_chunks = max(1, L // chunk_length)
for ci in range(num_chunks):
if stop_requested(): return
start = (ci * chunk_length) % max(1, L - chunk_length)
chunk_ids = ids_source[start:start + chunk_length].tolist()
steps = max(1, len(chunk_ids) // (cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"])))
for step in range(steps):
if stop_requested(): return
xb, yb = get_batch_from_ids(chunk_ids, cfg["BLOCK_SIZE"], cfg["BATCH_SIZE"], device=DEVICE)
# forward/backward with AMP if available
try:
if use_amp:
with autocast_ctx():
logits, loss = model(xb, yb)
scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
else:
logits, loss = model(xb, yb)
(loss / cfg["ACCUM_STEPS"]).backward()
except RuntimeError as e:
# handle occasional OOM gracefully by reducing batch
if "out of memory" in str(e).lower():
torch.cuda.empty_cache()
print("OOM during training step - skipping step (reduce BATCH_SIZE).")
continue
else:
raise
if (step + 1) % cfg["ACCUM_STEPS"] == 0:
if use_amp:
try:
scaler.step(opt)
scaler.update()
except Exception as e:
print("Scaler step failed:", e)
opt.step()
else:
opt.step()
opt.zero_grad(set_to_none=True)
loss_hist.append(loss.item())
global_step += 1
steps_since += 1
if step % cfg["PRINT_EVERY"] == 0:
torch.cuda.synchronize()
now = time.time()
elapsed = now - last_print_time
overall_elapsed = now - t_start
sps = steps_since / elapsed if elapsed > 0 else 0.0
avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
# device stats
dev_stats = ""
try:
dev = next(model.parameters()).device
if "cuda" in str(dev):
dev_idx = torch.cuda.current_device()
mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
dev_name = torch.cuda.get_device_name(dev_idx)
dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
else:
dev_stats = f" | device={dev}"
except Exception:
dev_stats = ""
log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
last_print_time = now
steps_since = 0
yield (log, plot_loss(loss_hist))
# free references but avoid overusing gc.collect()
del xb, yb, logits, loss
# end steps loop
# end chunk loop
else:
# tokenization-on-the-fly (slower) - fallback to streaming chunks
for ci, chunk in enumerate(stream_text_chunks(DATA_PATH, chunk_size=128_000)):
ids = sp.encode(chunk, out_type=int)
if len(ids) < cfg["BLOCK_SIZE"]:
continue
steps = min(1000, max(1, len(ids)//(cfg["BLOCK_SIZE"] * max(1, cfg["BATCH_SIZE"]))))
for step in range(steps):
if stop_requested(): return
xb, yb = get_batch_from_ids(ids, cfg["BLOCK_SIZE"], cfg["BATCH_SIZE"], device=DEVICE)
try:
if use_amp:
with autocast_ctx():
logits, loss = model(xb, yb)
scaler.scale(loss / cfg["ACCUM_STEPS"]).backward()
else:
logits, loss = model(xb, yb)
(loss / cfg["ACCUM_STEPS"]).backward()
except RuntimeError as e:
if "out of memory" in str(e).lower():
torch.cuda.empty_cache()
print("OOM during training step - skipping step (reduce BATCH_SIZE).")
continue
else:
raise
if (step + 1) % cfg["ACCUM_STEPS"] == 0:
if use_amp:
try:
scaler.step(opt)
scaler.update()
except Exception as e:
print("Scaler step failed:", e)
opt.step()
else:
opt.step()
opt.zero_grad(set_to_none=True)
loss_hist.append(loss.item())
global_step += 1
steps_since += 1
if step % cfg["PRINT_EVERY"] == 0:
now = time.time()
elapsed = now - last_print_time
overall_elapsed = now - t_start
sps = steps_since / elapsed if elapsed > 0 else 0.0
avg_sps = global_step / overall_elapsed if overall_elapsed > 0 else 0.0
dev_stats = ""
try:
dev = next(model.parameters()).device
if "cuda" in str(dev):
dev_idx = torch.cuda.current_device()
mem_alloc = torch.cuda.memory_allocated(dev_idx) / 1024**2
mem_reserved = torch.cuda.memory_reserved(dev_idx) / 1024**2
dev_name = torch.cuda.get_device_name(dev_idx)
dev_stats = f" | {dev_name} alloc={mem_alloc:.1f}MB reserved={mem_reserved:.1f}MB"
else:
dev_stats = f" | device={dev}"
except Exception:
dev_stats = ""
log += (f"[E{epoch+1}C{ci}] step {step}/{steps} loss={loss.item():.4f} "
f"| steps/s={sps:.2f} (avg {avg_sps:.2f}){dev_stats}\n")
last_print_time = now
steps_since = 0
yield (log, plot_loss(loss_hist))
del xb, yb, logits, loss
# epoch end: save checkpoint
ck = f"baby_gpt_epoch{epoch+1}.pth"
try:
save_checkpoint(model, {**cfg}, ck)
log += f"💾 Saved checkpoint {ck}\n"
except Exception as e:
log += f"❌ Failed saving checkpoint {ck}: {e}\n"
yield (log, plot_loss(loss_hist))
# final save
try:
save_checkpoint(model, {**cfg}, "baby_gpt_final.pth")
log += "\n🎉 Training complete! Saved baby_gpt_final.pth\n"
except Exception as e:
log += f"\n❌ Failed final save: {e}\n"
yield (log, plot_loss(loss_hist))
# ---------- minimal CLI ----------
# ---------- minimal CLI ----------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--vocab", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--bs", type=int, default=None)
    parser.add_argument("--layers", type=int, default=None)
    parser.add_argument("--pretok", action="store_true", help="Force pretokenization")
    args = parser.parse_args()

    # Map provided CLI flags onto DEFAULTS override keys.
    # (Note: --batch sets BATCH_SIZE while --bs sets BLOCK_SIZE.)
    overrides = {}
    for value, key in (
        (args.epochs, "EPOCHS"),
        (args.vocab, "SP_VOCAB"),
        (args.batch, "BATCH_SIZE"),
        (args.bs, "BLOCK_SIZE"),
        (args.layers, "NUM_LAYERS"),
    ):
        if value:
            overrides[key] = value
    if args.pretok:
        overrides["PRETOKENIZE"] = True

    gen = train_generator(**overrides)
    try:
        # drain the generator: stream log text to stdout, dump the latest plot
        for text, img in gen:
            os.write(1, text.encode("utf-8"))
            if img:
                with open("latest_loss.png", "wb") as f:
                    f.write(img)
        print("\nDone.")
    except KeyboardInterrupt:
        request_stop()
        print("\nInterrupted and requested stop.")