# tiny_shakespeare / model.py
# (Hugging Face upload residue: uploaded by cazyundee, "Upload 3 files", commit 8004885 verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ==========================================
# MODEL CONFIG (Matching your 1.2M Llama)
# ==========================================
n_embd = 128      # embedding / residual-stream width
n_head = 4        # attention heads (head dim = n_embd // n_head)
n_layer = 6       # transformer blocks
block_size = 256  # maximum context length
dropout = 0.2     # dropout rate shared by attention and MLP

# Tiny Shakespeare Vocab (65 characters)
chars = ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
vocab_size = len(chars)

# Character <-> token-id codecs.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s):
    """Map a string to a list of token ids; characters outside the vocab are silently dropped."""
    return [stoi[ch] for ch in s if ch in stoi]

def decode(ids):
    """Map an iterable of token ids back to a string."""
    return ''.join(itos[i] for i in ids)
# ==========================================
# HELPERS (RoPE & RMSNorm)
# ==========================================
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotation table used by RoPE.

    Returns an (end, dim // 2) complex64 tensor whose [t, j] entry is the
    unit phasor exp(i * t * theta^(-2j/dim)): one rotation per position
    `t` and frequency index `j`.
    """
    half = dim // 2
    # Inverse frequencies 1 / theta^(2j/dim) for j = 0 .. half-1.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product gives the rotation angle for every (position, frequency).
    angles = torch.outer(positions, inv_freq).float()
    # Unit-magnitude complex numbers at those angles.
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors.

    xq, xk are (B, T, H, D) real tensors; freqs_cis holds T * (D // 2)
    unit phasors. Returns the rotated (xq, xk), original dtypes preserved.
    """
    def as_complex(t):
        # Reinterpret consecutive float pairs as complex: (B, T, H, D // 2).
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = as_complex(xq)
    k_c = as_complex(xk)
    # Broadcast the phase table over the batch and head dimensions.
    phases = freqs_cis.view(1, q_c.shape[1], 1, q_c.shape[-1])
    q_rot = torch.view_as_real(q_c * phases).flatten(3)
    k_rot = torch.view_as_real(k_c * phases).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps                               # numerical floor inside rsqrt
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def forward(self, x):
        # Divide each vector by its RMS over the last dim, then apply the gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
# ==========================================
# CORE LAYERS
# ==========================================
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)), as used in Llama.

    The hidden width is ~8/3 * dim, rounded up to a multiple of 4, so the
    parameter count roughly matches a conventional 4*dim MLP.
    """

    def __init__(self, dim):
        super().__init__()
        hidden_dim = int(8 / 3 * dim)
        hidden_dim += (-hidden_dim) % 4  # round up to a multiple of 4
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.dropout = nn.Dropout(dropout)  # module-level dropout rate

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE, fused QKV, and PyTorch SDPA."""

    def __init__(self):
        super().__init__()
        # One projection producing Q, K and V stacked along the last dim.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis):
        bsz, seqlen, width = x.size()
        head_dim = width // n_head

        # Project once, then slice into the three attention inputs.
        q, k, v = self.c_attn(x).chunk(3, dim=2)
        q = q.view(bsz, seqlen, n_head, head_dim)
        k = k.view(bsz, seqlen, n_head, head_dim)
        v = v.view(bsz, seqlen, n_head, head_dim)

        # Rotary position encoding on queries/keys (heads still in dim 2).
        q, k = apply_rotary_emb(q, k, freqs_cis)

        # Move heads to dim 1 for scaled_dot_product_attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=dropout if self.training else 0.0,
            is_causal=True,  # lower-triangular mask applied by the kernel
        )

        # Re-assemble heads and project back into the residual stream.
        out = attn.transpose(1, 2).contiguous().view(bsz, seqlen, width)
        return self.resid_dropout(self.c_proj(out))
class Block(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each wrapped in a residual."""

    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(n_embd)
        self.attn = CausalSelfAttention()
        self.ln_2 = RMSNorm(n_embd)
        self.ffwd = SwiGLU(n_embd)

    def forward(self, x, freqs_cis):
        # Normalize before each sub-layer; add its output back to the stream.
        x = x + self.attn(self.ln_1(x), freqs_cis)
        return x + self.ffwd(self.ln_2(x))
# ==========================================
# FINAL MODEL CLASS
# ==========================================
class LanguageModel(nn.Module):
    """Character-level Llama-style decoder.

    Token embedding -> n_layer pre-norm Blocks (RoPE attention + SwiGLU)
    -> final RMSNorm -> linear head tied to the embedding weights.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block() for _ in range(n_layer)])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.token_embedding_table.weight = self.lm_head.weight  # Weight tying
        # RoPE table for the full context window; registered as a buffer so it
        # moves with the module across devices but is not a trainable parameter.
        freqs_cis = precompute_freqs_cis(n_embd // n_head, block_size)
        self.register_buffer("freqs_cis", freqs_cis)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Gaussian init (std 0.02) for all linear and embedding weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the decoder.

        idx: (B, T) int64 token ids, T <= block_size.
        targets: optional (B, T) int64 next-token ids.
        Returns (logits, loss): logits is (B, T, vocab_size); loss is the mean
        cross-entropy over all positions when targets are given, else None.
        """
        B, T = idx.shape
        x = self.token_embedding_table(idx)
        freqs_cis = self.freqs_cis[:T]  # only the positions actually present
        for block in self.blocks:
            x = block(x, freqs_cis)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # FIX: the original accepted `targets` but always returned loss=None,
        # so a training loop calling model(xb, yb) never got a loss to backprop.
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(B * T, vocab_size), targets.view(B * T)
            )
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample max_new_tokens ids, appending them to idx.

        idx: (B, T) conditioning ids; returns (B, T + max_new_tokens).
        Sampling is multinomial over the softmax of the last position's logits.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]        # next-token distribution only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx