# NanoGPT-X_Base / train.py
# Author avatar: luxopes
# Last commit: "Update train.py" (776b5b5, verified)
# -- coding: utf-8 --
# Author: Antonín Tomeček
# Date: 3 Jan. 2026
# Description: GPT-style Transformer with Flash Attention 2, Memmap dataset,
# correct gradient accumulation, and clean English logging.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm
import sentencepiece as spm
# Allow TF32 tensor-core math for float32 matmuls / convolutions on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# =========================
# FLASH ATTENTION 2
# =========================
print(f"[Info] Torch version: {torch.__version__}")
print(f"[Info] CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"[Info] CUDA version: {torch.version.cuda}")

# Module-level switch read by Attention.forward. It is only True when a CUDA
# device exists AND flash_attn imports cleanly.
FLASH_ATTENTION_2 = False
try:
    # flash-attn kernels are CUDA-only, so never enable them on a CPU-only host.
    if torch.cuda.is_available():
        from flash_attn import flash_attn_func
        FLASH_ATTENTION_2 = True
except Exception as exc:  # best effort: any import failure falls back to SDPA
    print(f"[WARN] flash_attn import failed: {exc!r}")

if FLASH_ATTENTION_2:
    print("[OK] Flash Attention 2 enabled")
else:
    print("[WARN] Flash Attention 2 not available – using PyTorch SDPA")
# =========================
# CONFIG
# =========================
@dataclass
class ModelArgs:
    """Hyper-parameters of the decoder-only Transformer.

    Raises:
        ValueError: from ``__post_init__`` if the head configuration cannot
            work with grouped-query attention or interleaved rotary embeddings.
    """

    dim: int = 768  # residual-stream / embedding width
    n_layers: int = 12  # number of TransformerBlocks
    n_heads: int = 12  # query heads; head_dim = dim // n_heads
    n_kv_heads: int = 4  # key/value heads, each shared by n_heads // n_kv_heads query heads
    vocab_size: int = 32000  # replaced by tokenizer.vocab_size() in the __main__ block
    multiple_of: int = 256  # FFN hidden size is rounded up to a multiple of this
    ffn_dim_multiplier: float = 3.0  # FFN hidden size ~= dim * ffn_dim_multiplier
    norm_eps: float = 1e-5  # RMSNorm epsilon
    max_seq_len: int = 1024  # training window; the RoPE cache covers 2x this

    def __post_init__(self):
        # Validate once, up front. Otherwise these mistakes surface later as
        # shape errors deep inside the forward pass.
        if self.n_heads <= 0 or self.n_kv_heads <= 0:
            raise ValueError("n_heads and n_kv_heads must be positive")
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        # Rotary embeddings rotate interleaved (even, odd) pairs of each head.
        if (self.dim // self.n_heads) % 2 != 0:
            raise ValueError(
                f"head_dim = dim // n_heads = {self.dim // self.n_heads} must be even"
            )
# Checkpoint + sample-generation interval, counted in training micro-batches.
SAVE_EVERY_STEPS = 100 * 1000

# Input artifacts: SentencePiece model and the train/validation token files
# (read as int32 memmaps by MemmapDataset by default).
TOKENIZER_MODEL_PATH = "tokenizer.model"
TRAIN_BIN = "dataset.bin"
VALID_BIN = "valid.bin"

# Output directory for checkpoints. It is created eagerly at import time.
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, mode=0o777, exist_ok=True)
# =========================
# MODEL
# =========================
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: ``x / sqrt(mean(x**2) + eps) * weight``.

    The statistic is computed over the last dimension. It is always computed
    in float32, so fp16/bf16 inputs do not lose precision in the reduction.
    The normalised value is cast back to the input dtype before the learnable
    gain ``weight`` (one entry per feature, initialised to 1) is applied.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x_f32 = x.float()
        normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Build rotary-embedding cosine/sine tables.

    Args:
        dim: per-head feature size. One frequency is produced per (even, odd) pair.
        seq_len: number of positions to tabulate.
        theta: RoPE base.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, ceil(dim / 2))``, where entry
        ``[p, i]`` is the cosine/sine of ``p / theta ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len)
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)
def apply_rotary_emb(x, cos, sin):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by the given angles.

    ``cos``/``sin`` get a new leading axis and a new axis at position 2 before
    broadcasting. With the ``(T, head_dim // 2)`` tables used in this file,
    that implies ``x`` laid out as ``(B, T, H, head_dim)``. The result has the
    same shape and dtype as ``x``.
    """
    cos_b = cos.unsqueeze(0).unsqueeze(2)
    sin_b = sin.unsqueeze(0).unsqueeze(2)
    even = x[..., 0::2]
    odd = x[..., 1::2]
    new_even = even * cos_b - odd * sin_b
    new_odd = even * sin_b + odd * cos_b
    # Re-interleave as [e0, o0, e1, o1, ...], then cast back to x's dtype
    # (cos/sin may be float32 while x is half precision).
    return torch.stack((new_even, new_odd), dim=-1).flatten(-2).to(x.dtype)
class Attention(nn.Module):
    """Causal self-attention with rotary embeddings and grouped-query KV heads.

    ``n_heads`` query heads attend over ``n_kv_heads`` key/value heads. Each KV
    head is shared by ``n_heads // n_kv_heads`` query heads.
    """

    def __init__(self, args):
        super().__init__()
        if args.n_kv_heads <= 0 or args.n_heads % args.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({args.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({args.n_kv_heads})"
            )
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.repeat_kv = args.n_heads // args.n_kv_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        """Apply causal attention to ``x``.

        Args:
            x: hidden states. The ``view`` calls below require shape (B, T, dim).
            cos, sin: rotary tables for the first T positions, as produced by
                ``precompute_freqs_cis(head_dim, ...)[:T]``.

        Returns:
            Tensor of shape (B, T, dim).
        """
        B, T, _ = x.shape
        # Projections in (B, T, heads, head_dim) layout.
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # RoPE uses the same tables for every head, so rotating K before
        # repeating its heads gives the same result with less work.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # flash-attn only runs half-precision CUDA tensors. Anything else
        # (CPU, or fp32 outside autocast) takes the SDPA path.
        use_flash = (
            FLASH_ATTENTION_2
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            and q.dtype == k.dtype == v.dtype
        )
        if use_flash:
            # flash_attn_func takes (B, T, H, D) tensors directly. It also
            # supports fewer KV heads than query heads (GQA), so no transpose
            # or repeat is needed.
            out = flash_attn_func(q, k, v, causal=True)
        else:
            k = k.repeat_interleave(self.repeat_kv, dim=2)
            v = v.repeat_interleave(self.repeat_kv, dim=2)
            # SDPA expects (B, H, T, D). Transpose in, then back out.
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True,
            ).transpose(1, 2)
        out = out.reshape(B, T, -1)
        return self.wo(out)
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is ``int(dim * mult)`` rounded up to the next multiple of
    ``multiple_of``.
    """

    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        target_width = int(dim * mult)
        n_chunks = (target_width + multiple_of - 1) // multiple_of
        hidden = n_chunks * multiple_of
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden, bias=False)  # value projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention residual, then SwiGLU residual.

    When ``gradient_checkpointing`` is set and the module is in training mode,
    only the feed-forward sub-layer (norm + FFN) is checkpointed. It is then
    recomputed during backward. The attention sub-layer always runs normally.
    """

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)
        self.gradient_checkpointing = False

    def forward(self, x, cos, sin):
        x = x + self.attn(self.attn_norm(x), cos, sin)
        if not (self.gradient_checkpointing and self.training):
            return x + self._ffn(x)
        ffn_out = torch.utils.checkpoint.checkpoint(self._ffn, x, use_reentrant=False)
        return x + ffn_out

    def _ffn(self, x):
        """Normalised feed-forward branch (the unit that gets checkpointed)."""
        normed = self.ffn_norm(x)
        return self.ffn(normed)
class Transformer(nn.Module):
    """Decoder-only language model.

    Pipeline: token embedding -> ``args.n_layers`` TransformerBlocks ->
    RMSNorm -> linear projection to ``args.vocab_size`` logits.

    Rotary cos/sin tables are precomputed for ``2 * args.max_seq_len``
    positions and registered as non-persistent buffers, so they follow the
    module across devices but are not written to ``state_dict``.
    """

    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)
        self.apply(self._init)

    def gradient_checkpointing_enable(self):
        """Turn on per-block gradient checkpointing (effective only in training mode)."""
        for layer in self.layers:
            layer.gradient_checkpointing = True
        print("[OK] Gradient checkpointing enabled")

    def _init(self, m):
        """Initialise every Linear/Embedding weight from a normal distribution with std 0.02."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, tokens):
        """Compute next-token logits.

        Args:
            tokens: integer token ids of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T is longer than the precomputed rotary tables.
        """
        _, T = tokens.shape
        max_positions = self.cos_cached.size(0)
        if T > max_positions:
            # Without this check the short cos/sin slice fails later with an
            # opaque broadcasting error inside apply_rotary_emb.
            raise ValueError(
                f"Sequence length {T} exceeds the rotary cache size {max_positions}"
            )
        h = self.tok_emb(tokens)
        cos = self.cos_cached[:T]
        sin = self.sin_cached[:T]
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)
        return self.out(h)

    def get_num_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# =========================
# MEMMAP DATASET (FIXED)
# =========================
class MemmapDataset(Dataset):
    """Next-token-prediction windows over a flat binary token file.

    The file at ``path`` is memory-mapped read-only as a 1-D array of
    ``dtype`` (int32 by default). Each item is a pair ``(inputs, targets)`` of
    int64 tensors of length ``max_seq_len``, with ``targets`` shifted by one
    token. Windows start every ``stride`` tokens (default
    ``max_seq_len // 2``). One extra window is placed flush against the end of
    the file so the tail is never dropped.

    Raises:
        ValueError: if the file holds fewer than ``max_seq_len + 1`` tokens,
            or if the stride is not positive.
    """

    def __init__(self, path: str, max_seq_len: int, stride: Optional[int] = None,
                 dtype=np.int32):
        self.path = path
        self.dtype = dtype
        self.tokens = np.memmap(path, dtype=dtype, mode="r")
        self.max_seq_len = max_seq_len
        # max(1, ...) keeps the default stride non-zero when max_seq_len == 1.
        self.stride = stride or max(1, max_seq_len // 2)
        if self.stride <= 0:
            raise ValueError("stride must be positive")

        # Last start index whose (max_seq_len + 1)-token window still fits.
        # A value of 0 is valid: exactly one window.
        max_start = len(self.tokens) - (max_seq_len + 1)
        if max_start < 0:
            raise ValueError("Dataset too small for the given max_seq_len")

        starts = np.arange(0, max_start + 1, self.stride, dtype=np.int64)
        if starts[-1] != max_start:
            starts = np.append(starts, np.int64(max_start))
        # A compact NumPy array instead of a list of Python ints: far smaller
        # for large corpora, and no per-element refcount writes that make
        # forked DataLoader workers duplicate memory pages.
        self.starts = starts

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = int(self.starts[idx])
        # Copy the window out of the memmap, converting straight to int64.
        seq = torch.from_numpy(
            np.array(self.tokens[i:i + self.max_seq_len + 1], dtype=np.int64)
        )
        return seq[:-1], seq[1:]

    def __getstate__(self):
        # Pickling an np.memmap serialises its full contents. DataLoader
        # workers started with "spawn" would copy the whole corpus, so drop
        # the mapping and re-open it in __setstate__ instead.
        state = self.__dict__.copy()
        state["tokens"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.tokens = np.memmap(self.path, dtype=self.dtype, mode="r")
# =========================
# TEXT GENERATION
# =========================
@torch.no_grad()
def generate_text(model, tokenizer, prompts,
                  max_new_tokens=128, temperature=0.8, top_p=0.95, eos_id=1):
    """Sample a continuation for each prompt using temperature + nucleus (top-p) sampling.

    The model is switched to eval mode while sampling. Its previous
    train/eval mode is restored on exit, so calling this in the middle of
    training does not leave training-only behaviour (e.g. gradient
    checkpointing) disabled.

    Args:
        model: module whose output is indexed as ``model(x)[0, -1]``, i.e.
            logits of shape (1, T, vocab) for ids ``x`` of shape (1, T).
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``. ``eos_id()`` is also needed when
            ``eos_id`` is None.
        prompts: iterable of prompt strings.
        max_new_tokens: maximum number of tokens appended per prompt.
        temperature: logits are divided by this. A value ``<= 0`` selects
            greedy (argmax) decoding.
        top_p: keep the smallest set of most-likely tokens whose cumulative
            probability exceeds ``top_p``. At least one token is always kept.
        eos_id: stop a prompt's generation once this id is produced. ``None``
            means ``tokenizer.eos_id()``. NOTE(review): SentencePiece's default
            id layout is unk=0, bos=1, eos=2, so the default of 1 may be BOS for
            this tokenizer. Confirm against tokenizer.model.

    Returns:
        dict mapping each prompt to the decoded prompt + generated tokens.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        if eos_id is None:
            eos_id = tokenizer.eos_id()
        results = {}
        for prompt in prompts:
            ids = tokenizer.encode(prompt)
            x = torch.tensor([ids], dtype=torch.long, device=device)
            for _ in range(max_new_tokens):
                logits = model(x)[0, -1]
                if temperature <= 0:
                    # Greedy decoding; also avoids dividing by zero.
                    next_tok = logits.argmax().view(1)
                else:
                    logits = logits / temperature
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    cum_probs = torch.softmax(sorted_logits, dim=0).cumsum(dim=0)
                    # Shift the mask right by one so the token that crosses
                    # top_p is kept. Always keep the single most likely token.
                    mask = cum_probs > top_p
                    mask[1:] = mask[:-1].clone()
                    mask[0] = False
                    logits[sorted_idx[mask]] = -float("inf")
                    probs = torch.softmax(logits, dim=0)
                    next_tok = torch.multinomial(probs, 1)
                x = torch.cat([x, next_tok.view(1, 1)], dim=1)
                if next_tok.item() == eos_id:
                    break
            results[prompt] = tokenizer.decode(x[0].tolist())
        return results
    finally:
        model.train(was_training)
# =========================
# TRAINING
# =========================
def _save_checkpoint(accelerator, model, optimizer, scheduler, args, step, path):
    """Write step counter, model/optimizer/scheduler state and ``args`` to ``path`` via torch.save."""
    checkpoint = {
        "step": step,
        "model_state_dict": accelerator.unwrap_model(model).state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "model_args": args,
    }
    torch.save(checkpoint, path)


def _print_samples(accelerator, model, tokenizer, header):
    """Print sampled completions for a fixed prompt set, keeping the model's train/eval mode."""
    prompts = [
        "Once upon a time",
        "In a distant future",
        "First step to build a rocket",
        "Capital city of France",
        "Artificial intelligence will",
    ]
    was_training = model.training
    try:
        samples = generate_text(
            accelerator.unwrap_model(model),
            tokenizer,
            prompts,
            max_new_tokens=100,
            temperature=0.8,
            top_p=0.95,
        )
    finally:
        # generate_text calls .eval() on the unwrapped model. Restore the
        # caller's mode so per-block gradient checkpointing (active only in
        # training mode) stays on for the rest of the epoch.
        model.train(was_training)
    print(header)
    for prompt, text in samples.items():
        print(f"Prompt: {prompt}")
        print(f"Generated: {text}")
        print("-" * 50)


def train(
    model,
    train_ds,
    valid_ds,
    tokenizer,
    args,
    batch_size=1,
    grad_accum=8,
    epochs=1,
    lr=1e-5,
    warmup_steps=500,
):
    """Train ``model`` with Accelerate (mixed precision + gradient accumulation).

    Steps:
    - Each epoch runs AdamW with linear warmup followed by cosine decay, clipping
      gradients to norm 1.0 on every synchronised (optimizer) step.
    - Every ``SAVE_EVERY_STEPS`` micro-batches the main process saves a
      checkpoint and prints sample generations.
    - After each epoch the validation loss is computed and printed, followed
      by sample generations.
    - A final checkpoint is written to ``CHECKPOINT_DIR/final_model.pt``.

    Args:
        model: the Transformer to train (must expose ``gradient_checkpointing_enable``
            and ``get_num_params``).
        train_ds, valid_ds: datasets yielding ``(inputs, targets)`` token tensors.
        tokenizer: SentencePiece processor. ``pad_id()`` is used as the loss
            ``ignore_index``.
        args: model hyper-parameters, stored verbatim in every checkpoint.
        batch_size: per-process micro-batch size.
        grad_accum: micro-batches accumulated per optimizer step.
        epochs: number of passes over ``train_ds``.
        lr: peak learning rate.
        warmup_steps: linear warmup length in scheduler steps.
    """
    if torch.cuda.is_available():
        mixed_precision = "bf16" if torch.cuda.is_bf16_supported() else "fp16"
    else:
        # fp16 autocast needs a GPU, so fall back to full precision on CPU hosts.
        mixed_precision = "no"
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=grad_accum,
    )
    model.gradient_checkpointing_enable()
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.01,
    )
    # Scheduler horizon, computed from the un-sharded loader length.
    # NOTE(review): this assumes Accelerate's prepared scheduler advances once
    # per process per optimizer step (its default when split_batches=False).
    # Confirm for the installed accelerate version.
    total_steps = math.ceil(len(train_loader) / grad_accum) * epochs

    def lr_lambda(step):
        """LR multiplier: linear warmup to 1, then cosine decay to 0, held at 0 afterwards."""
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Rounding at epoch boundaries can overshoot total_steps. Clamping
        # stops the cosine from rising back toward the peak LR.
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    model, optimizer, train_loader, valid_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, valid_loader, scheduler
    )
    ignore_index = tokenizer.pad_id()

    if accelerator.is_main_process:
        eff_bs = batch_size * grad_accum * accelerator.num_processes
        # prepare() may wrap the model (e.g. DDP), which hides custom methods.
        print(f"Model params: {accelerator.unwrap_model(model).get_num_params():,}")
        print(f"Effective batch size: {eff_bs}")
        print(f"Total optimizer steps: {total_steps}")
        print(f"Flash Attention: {FLASH_ATTENTION_2}")
        print("-" * 60)

    global_step = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        pbar = tqdm(
            train_loader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        for step, (x, y) in enumerate(pbar):
            with accelerator.accumulate(model):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=ignore_index,
                )
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                # Inside accumulate(), the prepared optimizer/scheduler only
                # take effect on synchronised steps.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # global_step counts training micro-batches, not optimizer steps.
            global_step += 1

            # Periodic checkpoint + text generation (main process only).
            if accelerator.is_main_process and global_step % SAVE_EVERY_STEPS == 0:
                ckpt_path = f"{CHECKPOINT_DIR}/step_{global_step}.pt"
                _save_checkpoint(
                    accelerator, model, optimizer, scheduler, args, global_step, ckpt_path
                )
                print(f"[Checkpoint] Saved complete checkpoint at step {global_step}")
                _print_samples(
                    accelerator, model, tokenizer,
                    f"[Sample generation @ step {global_step}]",
                )

            running_loss += loss.item()
            pbar.set_postfix(
                loss=f"{running_loss/(step+1):.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.2e}",
            )

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in valid_loader:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=ignore_index,
                )
                val_loss += loss.item()
        val_loss /= len(valid_loader)
        # Average the per-process means so the reported value covers every shard.
        val_loss = accelerator.reduce(
            torch.tensor(val_loss, device=accelerator.device), reduction="mean"
        ).item()
        accelerator.print(
            f"[Epoch {epoch+1}] Train Loss: {running_loss/len(train_loader):.6f} | "
            f"Val Loss: {val_loss:.6f}"
        )

        # End-of-epoch generation
        if accelerator.is_main_process:
            _print_samples(accelerator, model, tokenizer, "[Sample generation]")

    # Final save. Wait until every process has finished training first.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        _save_checkpoint(
            accelerator, model, optimizer, scheduler, args, global_step,
            f"{CHECKPOINT_DIR}/final_model.pt",
        )
        print("Training complete.")
# =========================
# MAIN
# =========================
if __name__ == "__main__":
    args = ModelArgs()
    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    args.vocab_size = tokenizer.vocab_size()
    train_ds = MemmapDataset(TRAIN_BIN, args.max_seq_len)
    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)
    model = Transformer(args)

    # Optional warm start: set the RESUME_FROM environment variable to a
    # checkpoint path (e.g. checkpoints/step_200000.pt). Only the model weights
    # are restored; the optimizer, scheduler and step counter start fresh.
    RESUME_FROM = os.environ.get("RESUME_FROM")
    if RESUME_FROM:
        if not os.path.exists(RESUME_FROM):
            raise FileNotFoundError(f"RESUME_FROM checkpoint not found: {RESUME_FROM}")
        print(f"[Resume] Loading checkpoint from {RESUME_FROM}")
        # weights_only=False is needed because checkpoints pickle ModelArgs.
        # SECURITY: this runs arbitrary pickle code, so only load trusted files.
        checkpoint = torch.load(RESUME_FROM, map_location="cpu", weights_only=False)
        # Support both the old format (bare state_dict) and the new format (checkpoint dict).
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            print(f"[Resume] Loaded model from step {checkpoint.get('step', 'unknown')}")
        else:
            model.load_state_dict(checkpoint)
            print(f"[Resume] Loaded model (old format)")

    train(
        model,
        train_ds,
        valid_ds,
        tokenizer,
        args,
        batch_size=1,
        grad_accum=8,
        epochs=1,
        lr=1e-5,
    )