# CGPT-124m / modeling_cgpt.py
from tqdm.auto import tqdm
import tiktoken
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1
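# xpos multiplies queries and keys by reciprocal position-dependent scales so the
# q.k dot product decays smoothly with relative distance (which helps length
# extrapolation); plain RoPE is recovered with use_xpos=False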
def exists(val):
return val is not None
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
scale_base = 512,
use_xpos = True
):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.use_xpos = use_xpos
self.scale_base = scale_base
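        # per-dimension xpos decay rates in roughly (0.29, 1); the lower-frequency
        # (higher-index) dimensions get a scale closer to 1 and so decay least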
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer('scale', scale)
@property
def device(self):
return next(self.buffers()).device
def forward(self, seq_len):
device = self.device
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if not self.use_xpos:
return freqs, torch.ones(1, device = device)
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
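# standard RoPE rotation (GPT-NeoX layout: dimension i is paired with i + d/2);
# the optional xpos scale multiplies both the cos and sin terms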
def apply_rotary_pos_emb(pos, t, scale = 1.):
return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
#@title minimal GPT implementation in PyTorch (karpathy)
""" super minimal decoder-only gpt """
torch.manual_seed(1337)
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
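        # F.normalize divides by the L2 norm; multiplying by sqrt(dim) turns that
        # into RMS normalization, with gamma as the learned per-channel gain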
return F.normalize(x, dim = -1) * self.scale * self.gamma
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # model dimensions
self.n_head = config.n_head
self.n_embd = config.n_embd
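        # causal mask cached as a (1, 1, block_size, block_size) lower-triangular
        # buffer; the name "bias" follows the OpenAI/HF GPT-2 naming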
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x, rotary_emb=None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
if exists(rotary_emb):
freqs, scale = rotary_emb
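            # xpos: queries are multiplied by the scale and keys by its reciprocal,
            # so their dot product decays with the query-key distance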
q = apply_rotary_pos_emb(freqs, q, scale)
k = apply_rotary_pos_emb(freqs, k, scale ** -1)
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# apply causal mask
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.nonlin = nn.GELU()
def forward(self, x):
x = self.c_fc(x)
x = self.nonlin(x)
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln = RMSNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = MLP(config)
def forward(self, x, rotary_emb=None):
lnx = self.ln(x)
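        # parallel residual (GPT-J-style): attention and MLP both read the same
        # normalized input and are added to the residual stream together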
x = x + self.attn(lnx, rotary_emb) + self.mlp(lnx)
return x
@dataclass
class GPTConfig:
block_size: int = 1024
vocab_size: int = 50257
n_layer: int = 6
n_head: int = 8
n_embd: int = 512
bias: bool = False
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = RotaryEmbedding(config.n_embd//config.n_head),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = RMSNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters
print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
# pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos_emb = self.transformer.wpe(t)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
# pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
x = tok_emb
for block in self.transformer.h:
x = block(x, rotary_emb=pos_emb)
x = self.transformer.ln_f(x)
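        # project the final hidden states to vocabulary logits, shape (b, t, vocab_size)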
logits = self.lm_head(x)
return logits
# probably also from karpathy or maybe max woolf idk i've been copy/pasting it between my projects
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
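    # note: filters `logits` in place (removed entries set to filter_value) and
    # also returns the same tensor; expects an unbatched 1-D logits vector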
assert logits.dim() == 1
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits.float(), dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
return logits
def next_token(logits, temperature=1., top_k=0, top_p=0.9):
logits = logits / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probabilities = F.softmax(filtered_logits.float(), dim=-1)
next_token = torch.multinomial(probabilities, 1)
return next_token
def sample(gpt, input_ids, temperature=0.7, top_k=0, top_p=0, max_new_tokens=16):
    # move the prompt to the GPU once up front so the concatenation with each
    # sampled token (which lives on the model's device) never mixes devices
    input_ids = input_ids.cuda()
    for i in range(max_new_tokens):
        logits = gpt(input_ids.unsqueeze(0))[:, -1, :][0] / temperature
        filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(filtered_logits.float(), dim=-1)
        next_token = torch.multinomial(probabilities, 1)
        input_ids = torch.cat([input_ids, next_token], -1)
    return input_ids
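# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original upload): builds the model
# from the default GPTConfig, encodes a prompt with tiktoken's "gpt2" encoding
# (50257 tokens, matching vocab_size above), and samples a short continuation.
# Assumes a CUDA device is available, since sample() moves tensors to the GPU;
# with randomly initialized weights the output is of course gibberish.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    enc = tiktoken.get_encoding("gpt2")
    config = GPTConfig()
    gpt = GPT(config).cuda().eval()
    prompt = torch.tensor(enc.encode("Hello, world"), dtype=torch.long)
    with torch.no_grad():
        out = sample(gpt, prompt, temperature=0.7, top_p=0.9, max_new_tokens=16)
    print(enc.decode(out.tolist()))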