# -*- coding: utf-8 -*-
# Author: Antonín Tomeček
# Date: 3 Jan. 2026
# Description: GPT-style Transformer with Flash Attention 2, Memmap dataset,
#              correct gradient accumulation, and clean English logging.

import os

# Must be set before the first CUDA allocation to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from dataclasses import dataclass
from typing import Optional
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm
import sentencepiece as spm

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# =========================
# FLASH ATTENTION 2
# =========================
try:
    print(f"[Info] Torch version: {torch.__version__}")
    print(f"[Info] CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[Info] CUDA version: {torch.version.cuda}")
    from flash_attn import flash_attn_func

    FLASH_ATTENTION_2 = True
    print("[OK] Flash Attention 2 enabled")
except Exception:
    FLASH_ATTENTION_2 = False
    print("[WARN] Flash Attention 2 not available – using PyTorch SDPA")


# =========================
# CONFIG
# =========================
@dataclass
class ModelArgs:
    """Hyper-parameters of the transformer; vocab_size is overwritten in main."""

    dim: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_kv_heads: int = 4            # grouped-query attention: n_heads must be divisible by this
    vocab_size: int = 32000
    multiple_of: int = 256         # FFN hidden size is rounded up to a multiple of this
    ffn_dim_multiplier: float = 3.0
    norm_eps: float = 1e-5
    max_seq_len: int = 1024


SAVE_EVERY_STEPS = 100_000
TOKENIZER_MODEL_PATH = "tokenizer.model"
TRAIN_BIN = "dataset.bin"
VALID_BIN = "valid.bin"
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


# =========================
# MODEL
# =========================
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, learned scale only)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Precompute RoPE cos/sin tables of shape (seq_len, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embedding to x of shape (B, T, H, D).

    Even/odd channel pairs are rotated by the per-position angle; cos/sin are
    (T, D/2) and get broadcast over batch and head dimensions.
    """
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos.unsqueeze(0).unsqueeze(2)  # (1, T, 1, D/2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


class Attention(nn.Module):
    """Causal self-attention with grouped-query KV heads and RoPE."""

    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.repeat_kv = args.n_heads // args.n_kv_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Expand KV heads so every query head has a matching KV head.
        k = k.repeat_interleave(self.repeat_kv, dim=2)
        v = v.repeat_interleave(self.repeat_kv, dim=2)

        # RoPE operates in (B, T, H, D) layout.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # BUG FIX: flash_attn_func expects (B, T, H, D); the previous code
        # transposed to (B, H, T, D) first, so the flash path attended over
        # the wrong dimension. Only SDPA needs the (B, H, T, D) layout.
        # flash-attn also supports only fp16/bf16, so fp32 activations
        # (e.g. generation outside the autocast context) fall back to SDPA.
        if FLASH_ATTENTION_2 and q.dtype in (torch.float16, torch.bfloat16):
            out = flash_attn_func(q, k, v, causal=True)          # (B, T, H, D)
        else:
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True,
            ).transpose(1, 2)                                    # back to (B, T, H, D)

        out = out.contiguous().view(B, T, -1)
        return self.wo(out)


class FeedForward(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)), hidden rounded up to multiple_of."""

    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        hidden = multiple_of * ((int(dim * mult) + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention + residual SwiGLU FFN."""

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)
        self.gradient_checkpointing = False

    def forward(self, x, cos, sin):
        x = x + self.attn(self.attn_norm(x), cos, sin)
        if self.training and self.gradient_checkpointing:
            # Recompute FFN activations during backward to save memory.
            x = x + torch.utils.checkpoint.checkpoint(
                self._ffn, x, use_reentrant=False
            )
        else:
            x = x + self.ffn(self.ffn_norm(x))
        return x

    def _ffn(self, x):
        return self.ffn(self.ffn_norm(x))


class Transformer(nn.Module):
    """Decoder-only GPT-style language model."""

    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)
        # RoPE cache is 2x max_seq_len so generation can run past training length.
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)
        self.apply(self._init)

    def gradient_checkpointing_enable(self):
        for layer in self.layers:
            layer.gradient_checkpointing = True
        print("[OK] Gradient checkpointing enabled")

    def _init(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, tokens):
        B, T = tokens.shape
        h = self.tok_emb(tokens)
        cos = self.cos_cached[:T]
        sin = self.sin_cached[:T]
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)
        return self.out(h)

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# =========================
# MEMMAP DATASET
# =========================
class MemmapDataset(Dataset):
    """Sliding-window view over a flat int32 token file (no copy until __getitem__).

    Windows of (max_seq_len + 1) tokens start every `stride` tokens; the last
    window is snapped to the end of the file so no tokens are dropped.
    """

    def __init__(self, path: str, max_seq_len: int, stride: Optional[int] = None):
        self.tokens = np.memmap(path, dtype=np.int32, mode="r")
        self.max_seq_len = max_seq_len
        self.stride = stride or max_seq_len // 2
        max_start = len(self.tokens) - (max_seq_len + 1)
        if max_start <= 0:
            raise ValueError("Dataset too small for the given max_seq_len")
        self.starts = list(range(0, max_start, self.stride))
        if self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        # .copy() detaches from the memmap so the tensor owns its memory.
        seq = torch.from_numpy(
            self.tokens[i:i + self.max_seq_len + 1].copy()
        ).long()
        return seq[:-1], seq[1:]


# =========================
# TEXT GENERATION
# =========================
@torch.no_grad()
def generate_text(model, tokenizer, prompts, max_new_tokens=128,
                  temperature=0.8, top_p=0.95, eos_id=1):
    """Sample continuations for each prompt with nucleus (top-p) sampling.

    Returns a dict {prompt: full decoded text}. Puts the model in eval mode;
    callers training the model must restore train mode themselves.
    """
    model.eval()
    device = next(model.parameters()).device
    # BUG FIX: cap the context at the RoPE cache length, otherwise long
    # generations crash with a broadcast error in apply_rotary_emb.
    max_ctx = model.cos_cached.size(0)
    results = {}
    for prompt in prompts:
        ids = tokenizer.encode(prompt)
        x = torch.tensor([ids], device=device)
        for _ in range(max_new_tokens):
            x_cond = x if x.size(1) <= max_ctx else x[:, -max_ctx:]
            logits = model(x_cond)[0, -1] / temperature
            # Nucleus filtering: drop tokens outside the top-p probability mass,
            # always keeping at least the most likely token.
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs = torch.softmax(sorted_logits, dim=0)
            cum_probs = probs.cumsum(dim=0)
            mask = cum_probs > top_p
            mask[1:] = mask[:-1].clone()
            mask[0] = False
            logits[sorted_idx[mask]] = -float("inf")
            probs = torch.softmax(logits, dim=0)
            next_tok = torch.multinomial(probs, 1)
            x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)
            if next_tok.item() == eos_id:
                break
        results[prompt] = tokenizer.decode(x[0].tolist())
    return results


# =========================
# TRAINING
# =========================
def train(
    model,
    train_ds,
    valid_ds,
    tokenizer,
    args,
    batch_size=1,
    grad_accum=8,
    epochs=1,
    lr=1e-5,
    warmup_steps=500,
):
    """Train `model` with Accelerate (bf16/fp16, gradient accumulation).

    Saves a full checkpoint every SAVE_EVERY_STEPS batches plus a final one,
    and prints sample generations periodically and at each epoch end.
    """
    accelerator = Accelerator(
        mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
        gradient_accumulation_steps=grad_accum,
    )
    model.gradient_checkpointing_enable()

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.01,
    )

    # NOTE(review): computed before prepare(); under multi-process training the
    # sharded loader is shorter, so the cosine schedule ends early — confirm
    # whether multi-GPU runs are intended.
    total_steps = math.ceil(len(train_loader) / grad_accum) * epochs

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model, optimizer, train_loader, valid_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, valid_loader, scheduler
    )

    if accelerator.is_main_process:
        eff_bs = batch_size * grad_accum * accelerator.num_processes
        # BUG FIX: the prepared model may be DDP-wrapped, which does not expose
        # custom methods; unwrap before calling get_num_params().
        print(f"Model params: {accelerator.unwrap_model(model).get_num_params():,}")
        print(f"Effective batch size: {eff_bs}")
        print(f"Total optimizer steps: {total_steps}")
        print(f"Flash Attention: {FLASH_ATTENTION_2}")
        print("-" * 60)

    global_step = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        pbar = tqdm(
            train_loader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        for step, (x, y) in enumerate(pbar):
            with accelerator.accumulate(model):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.pad_id(),
                )
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # global_step counts training batches (micro-batches), not
            # optimizer steps — SAVE_EVERY_STEPS is in the same unit.
            global_step += 1

            # ==========================================
            # PERIODIC CHECKPOINT + TEXT GENERATION
            # ==========================================
            # NOTE(review): only rank 0 enters this branch while other ranks
            # keep training — under DDP this can desynchronize collectives;
            # confirm single-GPU usage or guard with accelerator.wait_for_everyone().
            if accelerator.is_main_process and global_step % SAVE_EVERY_STEPS == 0:
                ckpt_path = f"{CHECKPOINT_DIR}/step_{global_step}.pt"
                checkpoint = {
                    "step": global_step,
                    "model_state_dict": accelerator.unwrap_model(model).state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "model_args": args,
                }
                torch.save(checkpoint, ckpt_path)
                print(f"[Checkpoint] Saved complete checkpoint at step {global_step}")

                prompts = [
                    "Once upon a time",
                    "In a distant future",
                    "First step to build a rocket",
                    "Capital city of France",
                    "Artificial intelligence will",
                ]
                samples = generate_text(
                    accelerator.unwrap_model(model),
                    tokenizer,
                    prompts,
                    max_new_tokens=100,
                    temperature=0.8,
                    top_p=0.95,
                )
                print(f"[Sample generation @ step {global_step}]")
                for prompt, text in samples.items():
                    print(f"Prompt: {prompt}")
                    print(f"Generated: {text}")
                    print("-" * 50)
                # BUG FIX: generate_text() switched the model to eval mode,
                # which silently disabled gradient checkpointing for the rest
                # of the epoch; restore train mode.
                model.train()

            running_loss += loss.item()
            pbar.set_postfix(
                loss=f"{running_loss/(step+1):.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.2e}",
            )

        # =========================
        # VALIDATION
        # =========================
        # NOTE(review): the validation loss is averaged per process only; in
        # multi-process runs use accelerator.gather for a global average.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in valid_loader:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.pad_id(),
                )
                val_loss += loss.item()
        val_loss /= len(valid_loader)

        accelerator.print(
            f"[Epoch {epoch+1}] Train Loss: {running_loss/len(train_loader):.6f} | "
            f"Val Loss: {val_loss:.6f}"
        )

        # =========================
        # END-OF-EPOCH GENERATION
        # =========================
        if accelerator.is_main_process:
            prompts = [
                "Once upon a time",
                "In a distant future",
                "First step to build a rocket",
                "Capital city of France",
                "Artificial intelligence will",
            ]
            samples = generate_text(
                accelerator.unwrap_model(model),
                tokenizer,
                prompts,
                max_new_tokens=100,
                temperature=0.8,
                top_p=0.95,
            )
            print("[Sample generation]")
            for prompt, text in samples.items():
                print(f"Prompt: {prompt}")
                print(f"Generated: {text}")
                print("-" * 50)

    # =========================
    # FINAL SAVE
    # =========================
    if accelerator.is_main_process:
        checkpoint = {
            "step": global_step,
            "model_state_dict": accelerator.unwrap_model(model).state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "model_args": args,
        }
        torch.save(checkpoint, f"{CHECKPOINT_DIR}/final_model.pt")
        print("Training complete.")


# =========================
# MAIN
# =========================
if __name__ == "__main__":
    args = ModelArgs()
    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    args.vocab_size = tokenizer.vocab_size()

    train_ds = MemmapDataset(TRAIN_BIN, args.max_seq_len)
    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)

    model = Transformer(args)

    # Optional resume: set RESUME_FROM=<path> in the environment. Replaces the
    # previous commented-out block; default behavior (no resume) is unchanged.
    resume_from = os.environ.get("RESUME_FROM", "")
    if resume_from and os.path.exists(resume_from):
        print(f"[Resume] Loading checkpoint from {resume_from}")
        checkpoint = torch.load(resume_from, map_location="cpu")
        # Support both the new checkpoint-dict format and a bare state_dict.
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            print(f"[Resume] Loaded model from step {checkpoint.get('step', 'unknown')}")
        else:
            model.load_state_dict(checkpoint)
            print("[Resume] Loaded model (old format)")

    train(
        model,
        train_ds,
        valid_ds,
        tokenizer,
        args,
        batch_size=1,
        grad_accum=8,
        epochs=1,
        lr=1e-5,
    )