"""Mały GPT od zera (char-level) — architektura z lekcji L07. Embedding -> [Blok: uwaga (Q/K/V + maska) + MLP/GELU + residual + LayerNorm] x N -> logity. Czysty PyTorch, bez gotowych warstw transformera. """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F @dataclass class GPTConfig: vocab_size: int = 53 block_size: int = 256 # ile znaków kontekstu (uwaga sięga tak daleko) n_layer: int = 4 n_head: int = 4 n_embd: int = 128 dropout: float = 0.1 class CausalSelfAttention(nn.Module): """Uwaga z maską przyczynową (L07): token patrzy tylko wstecz.""" def __init__(self, cfg: GPTConfig): super().__init__() assert cfg.n_embd % cfg.n_head == 0 self.qkv = nn.Linear(cfg.n_embd, 3 * cfg.n_embd) # Q, K, V naraz self.proj = nn.Linear(cfg.n_embd, cfg.n_embd) self.attn_drop = nn.Dropout(cfg.dropout) self.resid_drop = nn.Dropout(cfg.dropout) self.n_head = cfg.n_head self.n_embd = cfg.n_embd # maska dolnotrójkątna: pozycja t widzi tylko <= t self.register_buffer("mask", torch.tril(torch.ones(cfg.block_size, cfg.block_size)) .view(1, 1, cfg.block_size, cfg.block_size)) def forward(self, x): B, T, C = x.shape q, k, v = self.qkv(x).split(self.n_embd, dim=2) # rozbij na głowice: (B, nh, T, hs) hs = C // self.n_head q = q.view(B, T, self.n_head, hs).transpose(1, 2) k = k.view(B, T, self.n_head, hs).transpose(1, 2) v = v.view(B, T, self.n_head, hs).transpose(1, 2) # scores = Q·Kᵀ / √d (L07: /√d normalizuje wariancję) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs)) att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf")) att = self.attn_drop(F.softmax(att, dim=-1)) y = att @ v # ważona suma V y = y.transpose(1, 2).contiguous().view(B, T, C) return self.resid_drop(self.proj(y)) class MLP(nn.Module): """Przetwarzanie w tokenie (L07): 768->4x->768 z GELU.""" def __init__(self, cfg: GPTConfig): super().__init__() self.fc = nn.Linear(cfg.n_embd, 4 * cfg.n_embd) self.proj = nn.Linear(4 * cfg.n_embd, cfg.n_embd) self.drop = nn.Dropout(cfg.dropout) def forward(self, x): return self.drop(self.proj(F.gelu(self.fc(x)))) class Block(nn.Module): """Residual + LayerNorm wokół uwagi i MLP (L07: autostrada x + f(x)).""" def __init__(self, cfg: GPTConfig): super().__init__() self.ln1 = nn.LayerNorm(cfg.n_embd) self.attn = CausalSelfAttention(cfg) self.ln2 = nn.LayerNorm(cfg.n_embd) self.mlp = MLP(cfg) def forward(self, x): x = x + self.attn(self.ln1(x)) # komunikacja MIĘDZY tokenami x = x + self.mlp(self.ln2(x)) # przetwarzanie W tokenie return x class GPT(nn.Module): def __init__(self, cfg: GPTConfig): super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd) # ID -> wektor self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd) # pozycja self.drop = nn.Dropout(cfg.dropout) self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)]) self.ln_f = nn.LayerNorm(cfg.n_embd) self.head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False) self.head.weight = self.tok_emb.weight # weight tying self.apply(self._init) def _init(self, m): if isinstance(m, (nn.Linear, nn.Embedding)): nn.init.normal_(m.weight, mean=0.0, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) def forward(self, idx, targets=None): B, T = idx.shape pos = torch.arange(T, device=idx.device) x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) for blk in self.blocks: x = blk(x) logits = self.head(self.ln_f(x)) loss = None if targets is not None: # cross-entropy (L03): -log p(prawdziwego następnego znaku) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None): for _ in range(max_new_tokens): idx_cond = idx[:, -self.cfg.block_size:] # przytnij do okna logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = float("-inf") probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, num_samples=1) idx = torch.cat((idx, nxt), dim=1) return idx def num_params(self): return sum(p.numel() for p in self.parameters())