# tiny-38m / model.py
# Uploaded by darthcrawl via the upload-large-folder tool (commit 6e14144, verified).
"""Decoder-only transformer with RMSNorm, RoPE, SwiGLU. Educational, modern, single-GPU."""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import ModelConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    No mean-centering and no bias: ``y = weight * x / sqrt(mean(x**2) + eps)``.

    Args:
        dim: size of the last (normalized) dimension; length of the gain vector.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Learnable per-channel gain, initialized to 1.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the statistic in at least float32. Squaring fp16 activations
        # overflows once |x| exceeds ~256 (fp16 max is ~65504), and bf16 loses
        # precision in the mean. float32/float64 inputs are left in their own dtype.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xf = x.to(compute_dtype)
        norm = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back to the input dtype before applying the gain.
        return self.weight * norm.to(x.dtype)
def build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype):
    """Precompute the rotary-embedding cosine and sine tables.

    Angle ``[p, i]`` equals ``p / base ** (2 * i / head_dim)``. Angles are computed
    in float32 and the resulting cos/sin tables are cast to ``dtype``.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, head_dim // 2)`` for even head_dim.
    """
    even_dims = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (base ** (even_dims / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    # Outer product via broadcasting: one row per position, one column per pair.
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos().to(dtype), angles.sin().to(dtype)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved feature pairs of ``x`` by the given RoPE angles.

    ``x`` is (B, H, T, D). Pair ``(x[..., 2i], x[..., 2i+1])`` at position ``t``
    is rotated using ``cos[t, i]`` / ``sin[t, i]``. Only the first T rows of the
    tables are used, and the output has the same shape and interleaving as ``x``.
    """
    seq = x.size(-2)
    # Trim the tables to T rows and add leading broadcast axes for batch and heads.
    cos_b = cos[:seq, :][None, None]
    sin_b = sin[:seq, :][None, None]
    even, odd = x[..., ::2], x[..., 1::2]
    rotated = torch.stack(
        (even * cos_b - odd * sin_b, even * sin_b + odd * cos_b),
        dim=-1,
    )
    # (..., D/2, 2) -> (..., D): re-interleave the rotated pairs.
    return rotated.flatten(-2)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Queries, keys and values come from one fused bias-free projection. RoPE is
    applied to queries and keys, then ``F.scaled_dot_product_attention`` runs
    with a causal mask, and the merged heads pass through ``proj``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.n_head = cfg.n_head
        self.head_dim = cfg.head_dim
        embd = cfg.n_embd
        # Fused projection producing [q | k | v] along the last axis.
        self.qkv = nn.Linear(embd, 3 * embd, bias=False)
        # Output projection applied after the heads are merged.
        self.proj = nn.Linear(embd, embd, bias=False)
        # Attention-weight dropout probability; only used in training mode.
        self.dropout = cfg.dropout

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape (batch, seq, n_embd); output has the same shape."""
        batch, seq, width = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, width) -> (batch, n_head, seq, head_dim)
            return t.view(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

        queries, keys, values = map(split_heads, self.qkv(x).chunk(3, dim=-1))
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)
        attn_dropout = self.dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(
            queries, keys, values,
            is_causal=True,
            dropout_p=attn_dropout,
        )
        # (batch, n_head, seq, head_dim) -> (batch, seq, width)
        merged = out.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.proj(merged)
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``, all projections bias-free."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Hidden width = mlp_mult * n_embd. Truncate to int so a fractional
        # multiplier (e.g. the common 8/3 SwiGLU ratio) still yields a valid layer
        # size; integer multipliers are unaffected.
        hidden = int(cfg.mlp_mult * cfg.n_embd)
        # Round up to a multiple of 64 for efficiency.
        hidden = ((hidden + 63) // 64) * 64
        self.w1 = nn.Linear(cfg.n_embd, hidden, bias=False)  # gate branch (through SiLU)
        self.w3 = nn.Linear(cfg.n_embd, hidden, bias=False)  # linear (value) branch
        self.w2 = nn.Linear(hidden, cfg.n_embd, bias=False)  # projection back to n_embd

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Block(nn.Module):
    """Pre-norm transformer layer.

    Computes ``h = x + attn(norm1(x))`` followed by ``h + mlp(norm2(h))``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Submodules are registered in this order, which fixes the parameter and
        # state_dict ordering.
        self.norm1 = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention(cfg)
        self.norm2 = RMSNorm(cfg.n_embd)
        self.mlp = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        h = x + self.attn(self.norm1(x), cos, sin)
        return h + self.mlp(self.norm2(h))
class GPT(nn.Module):
    """Decoder-only language model.

    Token embedding -> ``n_layer`` pre-norm blocks sharing one RoPE table ->
    final RMSNorm -> bias-free LM head (optionally tied to the embedding).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # One matrix serves as both input embedding and output head.
            self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)
        # Scale residual-branch output projections per GPT-2 init:
        # std = 0.02 / sqrt(2 * n_layer).
        for name, p in self.named_parameters():
            if name.endswith(("proj.weight", "w2.weight")):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layer))
        # (cos, sin) RoPE tables, built lazily by _rope().
        self._rope_cache = None

    def _init_weights(self, m):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self, non_embedding: bool = True) -> int:
        """Return the number of unique parameters.

        Args:
            non_embedding: if True, exclude the token-embedding matrix. This
                applies whether or not embeddings are tied; when tied, that
                matrix is also the LM head weight.
        """
        # parameters() yields a shared (tied) tensor only once.
        n = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n -= self.tok_emb.weight.numel()
        return n

    def _rope(self, T: int, device, dtype):
        """Return the RoPE (cos, sin) tables for the first T positions.

        The cache always covers the full ``block_size``. It is rebuilt when it is
        missing, too short, or on a different device or dtype.
        """
        if (self._rope_cache is None
                or self._rope_cache[0].size(0) < T
                or self._rope_cache[0].device != device
                or self._rope_cache[0].dtype != dtype):
            self._rope_cache = build_rope_cache(
                self.cfg.block_size, self.cfg.head_dim, self.cfg.rope_base, device, dtype,
            )
        cos, sin = self._rope_cache
        return cos[:T], sin[:T]

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        """Run the model on token ids ``idx`` of shape (B, T), with T <= block_size.

        Returns:
            ``(logits, loss)``.
            - Without ``targets``: logits for the last position only, shape
              (B, 1, vocab), and ``loss`` is None.
            - With ``targets`` (B*T ids, -1 = ignored): logits of shape
              (B, T, vocab) and the mean cross-entropy.
        """
        _, T = idx.shape
        assert T <= self.cfg.block_size, f"sequence length {T} > block_size {self.cfg.block_size}"
        x = self.tok_emb(idx)
        cos, sin = self._rope(T, x.device, x.dtype)
        for block in self.blocks:
            x = block(x, cos, sin)
        x = self.norm(x)
        if targets is None:
            # Indexing with [-1] (a list) keeps the time axis: (B, 1, C).
            logits = self.lm_head(x[:, [-1], :])
            return logits, None
        logits = self.lm_head(x)
        # reshape (not view) so non-contiguous targets, e.g. transposed or
        # strided slices, are accepted.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1,
        )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0, top_k: int | None = None,
                 eos_id: int | None = None):
        """Sample up to ``max_new_tokens`` tokens and append them to ``idx`` (B, T).

        The train/eval mode is not changed here. Attention dropout is active
        while ``self.training`` is True, so call ``eval()`` first for inference.

        Args:
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if given, sample only among the ``top_k`` largest logits.
            eos_id: if given, each row is marked finished once it emits
                ``eos_id``. Finished rows are padded with ``eos_id`` afterwards,
                and generation stops as soon as every row has finished.

        Returns:
            ``idx`` with the generated tokens appended along dim 1.
        """
        block_size = self.cfg.block_size
        finished = None
        if eos_id is not None:
            finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            # Condition on at most the last block_size tokens.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if finished is not None:
                # Rows that already emitted EOS keep emitting EOS (padding).
                next_id = next_id.masked_fill(finished.unsqueeze(-1), eos_id)
                finished |= next_id.squeeze(-1) == eos_id
            idx = torch.cat((idx, next_id), dim=1)
            if finished is not None and finished.all():
                break
        return idx