# Maggio33's picture
# Re-sync: warstwowy src/ (core/data/train/generate/compose/tools) + wagi + karta
# ab7c6e3 verified
# Raw
# History Blame Contribute Delete
# 5.08 kB
"""Mały GPT od zera (char-level) — architektura z lekcji L07.
Embedding -> [Blok: uwaga (Q/K/V + maska) + MLP/GELU + residual + LayerNorm] x N -> logity.
Czysty PyTorch, bez gotowych warstw transformera.
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
@dataclass
class GPTConfig:
    """Hyperparameters of the character-level GPT model.

    Raises:
        ValueError: If ``vocab_size``, ``block_size``, ``n_head`` or
            ``n_embd`` is not positive, if ``n_embd`` is not divisible by
            ``n_head``, or if ``dropout`` lies outside ``[0, 1]``.
    """

    vocab_size: int = 53  # number of distinct token ids
    block_size: int = 256  # context length in characters (attention reaches this far)
    n_layer: int = 4  # number of transformer blocks
    n_head: int = 4  # attention heads per block
    n_embd: int = 128  # embedding width
    dropout: float = 0.1  # dropout probability

    def __post_init__(self) -> None:
        """Fail early, with a clear message, on unusable hyperparameters."""
        for field_name in ("vocab_size", "block_size", "n_head", "n_embd"):
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value}")
        # Each head gets n_embd // n_head channels, so the split must be exact.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by "
                f"n_head ({self.n_head})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(
                f"dropout must be in [0, 1], got {self.dropout}"
            )
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask (lesson L07).

    Every position attends only to itself and to earlier positions.
    """

    def __init__(self, cfg: "GPTConfig"):
        """Build the projections, dropout layers and the causal mask.

        Args:
            cfg: Configuration providing ``n_embd``, ``n_head``,
                ``block_size`` and ``dropout``.

        Raises:
            ValueError: If ``cfg.n_embd`` is not divisible by ``cfg.n_head``.
        """
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if cfg.n_embd % cfg.n_head != 0:
            raise ValueError(
                f"n_embd ({cfg.n_embd}) must be divisible by "
                f"n_head ({cfg.n_head})"
            )
        self.qkv = nn.Linear(cfg.n_embd, 3 * cfg.n_embd)  # Q, K, V at once
        self.proj = nn.Linear(cfg.n_embd, cfg.n_embd)
        self.attn_drop = nn.Dropout(cfg.dropout)
        self.resid_drop = nn.Dropout(cfg.dropout)
        self.n_head = cfg.n_head
        self.n_embd = cfg.n_embd
        # Lower-triangular mask: position t sees only positions <= t.
        causal = torch.tril(torch.ones(cfg.block_size, cfg.block_size))
        self.register_buffer(
            "mask", causal.view(1, 1, cfg.block_size, cfg.block_size)
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor with three dimensions ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Raises:
            ValueError: If ``T`` exceeds the size of the causal mask.
        """
        B, T, C = x.shape
        max_len = self.mask.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds block_size {max_len}"
            )
        q, k, v = self.qkv(x).split(self.n_embd, dim=2)
        # Split into heads: (B, nh, T, hs).
        hs = C // self.n_head
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        # scores = Q @ K^T / sqrt(hs).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        # Masked positions get -inf, so softmax assigns them zero weight.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = self.attn_drop(F.softmax(att, dim=-1))
        y = att @ v  # weighted sum of V
        # Merge the heads back: (B, nh, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))
class MLP(nn.Module):
    """Per-token feed-forward network (lesson L07).

    Maps ``n_embd -> 4 * n_embd -> n_embd`` with a GELU in between,
    followed by dropout.
    """

    def __init__(self, cfg: "GPTConfig"):
        """Create both linear layers and the dropout from ``cfg``.

        Args:
            cfg: Configuration providing ``n_embd`` and ``dropout``.
        """
        super().__init__()
        self.fc = nn.Linear(cfg.n_embd, 4 * cfg.n_embd)
        self.proj = nn.Linear(4 * cfg.n_embd, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Return ``drop(proj(gelu(fc(x))))``."""
        hidden = F.gelu(self.fc(x))
        return self.drop(self.proj(hidden))
class Block(nn.Module):
    """Transformer block: residual connections with LayerNorm (lesson L07).

    Attention and the MLP are each applied to a LayerNorm-ed copy of the
    input and added back onto it (the ``x + f(x)`` highway).
    """

    def __init__(self, cfg: "GPTConfig"):
        """Create the two LayerNorms, the attention layer and the MLP.

        Args:
            cfg: Configuration passed on to the attention and MLP layers;
                ``cfg.n_embd`` sets the LayerNorm width.
        """
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embd)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.n_embd)
        self.mlp = MLP(cfg)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.ln1(x))  # communication BETWEEN tokens
        x = x + self.mlp(self.ln2(x))  # processing WITHIN a token
        return x
class GPT(nn.Module):
    """Character-level GPT language model built from plain PyTorch layers.

    Pipeline: token + position embeddings -> ``n_layer`` blocks -> final
    LayerNorm -> linear head producing logits over the vocabulary. The head
    shares its weight matrix with the token embedding.
    """

    def __init__(self, cfg: "GPTConfig"):
        """Build all layers from ``cfg`` and initialise their weights.

        Args:
            cfg: Configuration providing ``vocab_size``, ``block_size``,
                ``n_layer``, ``n_embd`` and ``dropout``.
        """
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)  # id -> vector
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)  # position
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # weight tying
        self.apply(self._init)

    def _init(self, m):
        """Initialise Linear/Embedding weights ~ N(0, 0.02); zero biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the loss.

        Args:
            idx: Two-dimensional ``(B, T)`` tensor of token ids.
            targets: Optional tensor of target token ids; flattened and
                compared against the flattened logits.

        Returns:
            Tuple ``(logits, loss)``; ``loss`` is ``None`` when ``targets``
            is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``cfg.block_size``.
        """
        B, T = idx.shape
        # The position table has only block_size rows; fail with a clear
        # message rather than an index error from the embedding lookup.
        if T > self.cfg.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size "
                f"{self.cfg.block_size}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        logits = self.head(self.ln_f(x))
        loss = None
        if targets is not None:
            # cross-entropy (lesson L03): -log p(true next character)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None):
        """Sample ``max_new_tokens`` tokens, one at a time, after ``idx``.

        Args:
            idx: Two-dimensional ``(B, T)`` tensor of token ids (the prompt).
            max_new_tokens: Number of tokens to append.
            temperature: The last-position logits are divided by this value
                before softmax.
            top_k: If given, only the ``top_k`` highest logits stay
                sampleable. Values larger than the vocabulary are clamped
                to the vocabulary size.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.block_size:]  # crop to the window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                # Everything below the k-th largest logit becomes -inf.
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, nxt), dim=1)
        return idx

    def num_params(self):
        """Return the total element count over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())