import math

import torch
from torch import nn


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Each position may attend only to itself and to earlier positions.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        block_size: Largest sequence length the causal mask covers.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, block_size, dropout):
        super().__init__()
        if d_model % n_heads != 0:
            # Fail at construction rather than with an opaque reshape error
            # on the first forward pass.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One projection yields queries, keys and values concatenated on the
        # last dimension; they are split apart with ``chunk`` in forward.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Shape (1, 1, block_size, block_size) so it broadcasts over the
        # (batch, heads, query, key) score tensor.
        mask = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape.
        """
        batch, seq_len, channels = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, head_dim)
        q = q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, n_heads, seq_len, seq_len).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Block attention to future positions; the diagonal is always kept,
        # so no row is fully masked and softmax stays finite.
        scores = scores.masked_fill(
            self.mask[:, :, :seq_len, :seq_len] == 0, float("-inf")
        )
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ v
        # Merge heads back: (batch, seq_len, n_heads * head_dim).
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * d_model, GELU, project back, dropout."""

    def __init__(self, d_model, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP independently at every position of ``x``."""
        return self.net(x)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block with residual connections.

    Computes ``x + attn(ln1(x))`` and then ``x + ff(ln2(x))``.
    """

    def __init__(self, d_model, n_heads, block_size, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, block_size, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, dropout)

    def forward(self, x):
        """Run attention then feed-forward, each wrapped in a residual add."""
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class TinyTransformerLM(nn.Module):
    """Decoder-only transformer language model with learned position embeddings.

    Args:
        vocab_size: Number of distinct token ids.
        block_size: Maximum context length (number of position embeddings).
        d_model: Model width.
        n_heads: Attention heads per block; must divide ``d_model``.
        n_layers: Number of stacked transformer blocks.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(d_model, n_heads, block_size, dropout)
                for _ in range(n_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape (batch, seq_len), with
                ``seq_len <= block_size``.
            targets: Optional token ids with the same number of elements as
                ``idx``; need not be contiguous in memory.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq_len,
            vocab_size). ``loss`` is a scalar tensor, or ``None`` when
            ``targets`` is not given.

        Raises:
            ValueError: If ``seq_len`` exceeds ``block_size``.
        """
        _, seq_len = idx.shape
        if seq_len > self.block_size:
            # Otherwise pos_emb would be indexed out of range, which fails
            # obscurely (a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so that non-contiguous targets, e.g.
            # transposed or sliced tensors, are accepted.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=16):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Runs without gradient tracking. Dropout remains active if the module
        is in training mode, so call ``.eval()`` first for inference.

        Args:
            idx: Integer context of shape (batch, t) with ``t >= 1``. Only
                the last ``block_size`` tokens are fed to the model.
            max_new_tokens: Number of tokens to sample.
            temperature: Logits are divided by ``max(temperature, 1e-4)``;
                smaller values sharpen the distribution.
            top_k: If a positive int, sample only among the ``top_k``
                highest logits (ties at the cutoff are kept). ``None`` or
                ``<= 0`` disables the filter.

        Returns:
            Tensor of shape (batch, t + max_new_tokens): the input followed
            by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Crop to the context window the position embeddings support.
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-4)
            if top_k is not None and top_k > 0:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # values[:, [-1]] keeps a (batch, 1) shape for broadcasting.
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx