import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# ==========================================
# MODEL CONFIG (Matching your 1.2M Llama)
# ==========================================
n_embd = 128      # embedding / residual stream width
n_head = 4        # attention heads; per-head width is n_embd // n_head
n_layer = 6       # number of transformer blocks
block_size = 256  # maximum sequence length the RoPE table is built for
dropout = 0.2     # dropout probability used by attention and MLP

# Tiny Shakespeare Vocab
chars = ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
         'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
         'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
         'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of token ids.

    Characters that are not in the vocabulary are silently skipped.
    """
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map an iterable of token ids back to a string.

    Raises:
        KeyError: if an id is not present in ``itos``.
    """
    return ''.join([itos[i] for i in l])


# ==========================================
# HELPERS (RoPE & RMSNorm)
# ==========================================
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotary-embedding table.

    Args:
        dim: per-head feature width; ``dim // 2`` frequencies are produced.
        end: number of positions to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, j]`` is
        the unit complex number with angle ``p * theta ** (-2j / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Create positions as float32 up front so torch.outer never has to
    # combine an int64 tensor with a float tensor.
    positions = torch.arange(
        end, device=inv_freq.device, dtype=torch.float32
    )
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key tensors by their position-dependent angles.

    Args:
        xq: query tensor; dim 1 is the sequence axis and the last dim
            (which must be even) is viewed as interleaved (real, imag) pairs.
        xk: key tensor with the same layout as ``xq``.
        freqs_cis: complex table with ``xq.shape[1] * xq.shape[-1] // 2``
            elements, e.g. the first ``T`` rows of ``precompute_freqs_cis``.

    Returns:
        ``(xq_out, xk_out)`` with the shapes and dtypes of the inputs.
    """
    # Pair up adjacent features and treat each pair as one complex number.
    q_complex = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], -1, 2)
    )
    k_complex = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], -1, 2)
    )
    # Broadcast over the leading (batch) axis and the head axis.
    rot = freqs_cis.view(1, q_complex.shape[1], 1, q_complex.shape[-1])
    xq_out = torch.view_as_real(q_complex * rot).flatten(3)
    xk_out = torch.view_as_real(k_complex * rot).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``weight * x / sqrt(mean(x**2, -1) + eps)``."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)


# ==========================================
# CORE LAYERS
# ==========================================
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` plus dropout."""

    def __init__(self, dim):
        super().__init__()
        # 8/3 * dim, rounded up to the next multiple of 4.
        hidden_dim = int(8 / 3 * dim)
        hidden_dim = 4 * ((hidden_dim + 3) // 4)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings."""

    def __init__(self):
        super().__init__()
        # Single fused projection producing q, k and v.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis):
        """Attend over ``x`` of shape ``(B, T, C)``; returns the same shape."""
        B, T, C = x.size()
        head_dim = C // n_head

        q, k, v = self.c_attn(x).split(n_embd, dim=2)
        q = q.view(B, T, n_head, head_dim)
        k = k.view(B, T, n_head, head_dim)
        v = v.view(B, T, n_head, head_dim)

        # RoPE is applied while the sequence axis is still dim 1.
        q, k = apply_rotary_emb(q, k, freqs_cis)

        # (B, T, H, D) -> (B, H, T, D) as expected by SDPA.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=dropout if self.training else 0.0,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))


class Block(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each residual."""

    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(n_embd)
        self.attn = CausalSelfAttention()
        self.ln_2 = RMSNorm(n_embd)
        self.ffwd = SwiGLU(n_embd)

    def forward(self, x, freqs_cis):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.ffwd(self.ln_2(x))
        return x


# ==========================================
# FINAL MODEL CLASS
# ==========================================
class LanguageModel(nn.Module):
    """Character-level decoder-only transformer with tied embeddings."""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block() for _ in range(n_layer)])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the embedding and the output head share one tensor.
        self.token_embedding_table.weight = self.lm_head.weight

        freqs_cis = precompute_freqs_cis(n_embd // n_head, block_size)
        self.register_buffer("freqs_cis", freqs_cis)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02**2)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer tensor of token ids, shape ``(B, T)`` with
                ``T <= block_size``.
            targets: optional integer tensor holding ``B * T`` target ids.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is the mean cross-entropy
            against ``targets``, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > block_size:
            # The rotary table only covers block_size positions.
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size}"
            )

        x = self.token_embedding_table(idx)
        freqs_cis = self.freqs_cis[:T]
        for block in self.blocks:
            x = block(x, freqs_cis)
        logits = self.lm_head(self.ln_f(x))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(B * T, -1), targets.reshape(B * T)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Runs without gradient tracking. The module's train/eval mode is left
        untouched, so call ``model.eval()`` first to sample without dropout.

        Args:
            idx: integer tensor of prompt token ids, shape ``(B, T)``.
            max_new_tokens: number of tokens to append.

        Returns:
            Integer tensor of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit in the context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx