File size: 8,479 Bytes
672259a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 | import math
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    Every vector along the last axis is multiplied by the reciprocal square
    root of (its mean squared value + ``eps``) and then scaled elementwise by
    ``weight``. There is no mean subtraction and no bias term.

    Args:
        dim: size of the last axis of the inputs (length of ``weight``).
        eps: constant added under the square root to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at ones, so the layer is initially a pure normaliser.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return self.weight * (x * inv_rms)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the rotary-embedding cosine/sine lookup table.

    Args:
        dim: per-head feature size; ``dim // 2`` rotation frequencies are used.
        end: number of positions (rows) to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        Float tensor of shape ``(end, dim // 2, 2)``. ``[..., 0]`` is the cosine
        and ``[..., 1]`` the sine of ``position * inv_freq``.
    """
    even_idx = torch.arange(0, dim, 2)[: dim // 2].float()
    inv_freq = 1.0 / theta ** (even_idx / dim)
    positions = torch.arange(end)
    angles = torch.outer(positions, inv_freq).float()
    return torch.stack((angles.cos(), angles.sin()), dim=-1)
def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to queries and keys.

    Consecutive feature pairs ``(x[2i], x[2i+1])`` are treated as 2-D points and
    rotated by the per-position, per-frequency angle whose cos/sin are stored in
    ``freqs_cis``. The rotation is computed in float32 and cast back to each
    input's dtype.

    Args:
        xq: query tensor; the reshapes below assume ``(batch, seq, heads, head_dim)``.
        xk: key tensor with the same layout as ``xq``.
        freqs_cis: cos/sin table of shape ``(seq, head_dim // 2, 2)``.

    Returns:
        ``(rotated_q, rotated_k)`` with the same shapes and dtypes as the inputs.
    """
    seq_len, half_dim = xq.shape[1], xq.shape[-1] // 2
    # Broadcast over batch (dim 0) and heads (dim 2).
    cos = freqs_cis[:, :, 0].view(1, seq_len, 1, half_dim)
    sin = freqs_cis[:, :, 1].view(1, seq_len, 1, half_dim)

    def rotate(x):
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        return rotated.flatten(3).type_as(x)

    return rotate(xq), rotate(xk)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1(x)) * w2(x))``.

    The hidden width is ``int(8 * n_embd / 3)`` rounded up to a multiple of
    256. With three projections this gives roughly the parameter count of a
    two-matrix MLP of width ``4 * n_embd``. No biases are used.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.n_embd
        width = int(8 * d_model / 3)
        hidden = -(-width // 256) * 256  # ceil-divide, then scale: round up to multiple of 256
        self.w1 = nn.Linear(d_model, hidden, bias=False)  # gate branch (passed through SiLU)
        self.w2 = nn.Linear(d_model, hidden, bias=False)  # linear branch
        self.w3 = nn.Linear(hidden, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Queries and keys are rotated with ``apply_rotary_emb`` and then passed to
    ``F.scaled_dot_product_attention`` with ``is_causal=True``. All projections
    are bias-free and attention dropout is fixed at 0.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        self.wq = nn.Linear(width, width, bias=False)
        self.wk = nn.Linear(width, width, bias=False)
        self.wv = nn.Linear(width, width, bias=False)
        self.wo = nn.Linear(width, width, bias=False)
        self.n_head = config.n_head
        self.n_embd = width
        self.head_dim = width // config.n_head

    def forward(self, x, freqs_cis):
        batch, seq_len, channels = x.size()
        split = (batch, seq_len, self.n_head, self.head_dim)
        q = self.wq(x).view(split)
        k = self.wk(x).view(split)
        v = self.wv(x).view(split)
        q, k = apply_rotary_emb(q, k, freqs_cis)
        # Move heads ahead of the sequence axis for SDPA: (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        merged = attn.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.wo(merged)
class Block(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> SwiGLU,
    each added back onto the residual stream."""

    def __init__(self, config):
        super().__init__()
        self.rmsnorm_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rmsnorm_2 = RMSNorm(config.n_embd)
        self.mlp = SwiGLU(config)

    def forward(self, x, freqs_cis):
        attn_out = self.attn(self.rmsnorm_1(x), freqs_cis)
        h = x + attn_out
        return h + self.mlp(self.rmsnorm_2(h))
class ReflowSignalEmbedding(nn.Module):
    """Token embedding factorised through a shared bank of signal vectors.

    Each token id selects a row of ``vocab_to_signals`` (weights over
    ``n_signals`` signals). Its embedding is that row multiplied by
    ``signal_basis`` (``n_signals x n_embd``). ``get_dynamic_vocab_matrix``
    returns the same product for the whole vocabulary as a
    ``(vocab_size, n_embd)`` matrix.

    Note: ``signal_basis`` is allocated with ``torch.empty`` and holds
    uninitialised memory until ``custom_init`` is called.
    """

    def __init__(self, config):
        super().__init__()
        self.n_signals = config.n_signals
        self.n_embd = config.n_embd
        self.vocab_to_signals = nn.Embedding(config.vocab_size, config.n_signals)
        self.signal_basis = nn.Parameter(torch.empty(config.n_signals, config.n_embd))

    def custom_init(self, target_std=0.02):
        """Initialise both factors so their product has standard deviation ``target_std``.

        Both factors are drawn i.i.d. from N(0, s^2). Each entry of their
        product is a sum of ``n_signals`` terms of variance s^4, so its variance
        is ``n_signals * s^4``. Choosing ``s^2 = target_std / sqrt(n_signals)``
        makes that variance ``target_std ** 2``.

        Args:
            target_std: desired std of the embedding entries (0.02 by default).
                This is a standard deviation, not a variance.
        """
        factor_std = math.sqrt(target_std / math.sqrt(self.n_signals))
        torch.nn.init.normal_(self.vocab_to_signals.weight, mean=0.0, std=factor_std)
        torch.nn.init.normal_(self.signal_basis, mean=0.0, std=factor_std)

    def get_dynamic_vocab_matrix(self):
        """Return the full ``(vocab_size, n_embd)`` embedding matrix."""
        return self.vocab_to_signals.weight @ self.signal_basis

    def forward(self, idx):
        """Embed integer token ids: look up signal weights, then mix the basis."""
        recipes = self.vocab_to_signals(idx)
        return recipes @ self.signal_basis
@dataclass
class GPTConfig:
    """Hyperparameters shared by ``GPT`` and its submodules."""
    block_size: int = 1024  # max sequence length accepted by GPT.forward; the RoPE table is built for 2x this
    vocab_size: int = 50304  # number of token ids (rows of ReflowSignalEmbedding.vocab_to_signals)
    n_layer: int = 32  # number of transformer Blocks; also scales the residual-projection init std
    n_head: int = 16  # attention heads; n_embd must be divisible by it
    n_embd: int = 1024  # model width
    n_signals: int = 1024  # size of the shared signal bank in ReflowSignalEmbedding
    dropout: float = 0.0  # NOTE(review): not read by any module in this file (attention uses dropout_p=0.0)
    bias: bool = False  # NOTE(review): not read by any module in this file; every Linear hard-codes bias=False
class GPT(nn.Module):
    """Decoder-only transformer language model with a factorised, tied embedding.

    Tokens are embedded with ``transformer.wte`` (a ``ReflowSignalEmbedding``).
    Logits are produced against the same matrix via
    ``wte.get_dynamic_vocab_matrix()``. Positions are encoded with rotary
    embeddings inside every attention layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=ReflowSignalEmbedding(config),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        # The RoPE table covers 2 * block_size positions; forward slices the first t rows.
        # persistent=False keeps it out of state_dict (it is rebuilt here).
        freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size * 2)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        # Init order:
        #   1. N(0, 0.02) for every Linear;
        #   2. the embedding's own scheme;
        #   3. a depth-scaled std for the attention output (wo) and MLP output (w3) projections.
        self.apply(self._init_weights)
        self.transformer.wte.custom_init()
        residual_std = 0.02 / math.sqrt(2 * config.n_layer)
        for pn, p in self.named_parameters():
            if pn.endswith(('wo.weight', 'w3.weight')):
                torch.nn.init.normal_(p, mean=0.0, std=residual_std)
        print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M")

    def get_num_params(self):
        """Return the number of trainable (``requires_grad``) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        """Re-initialise every ``nn.Linear`` weight from N(0, 0.02); other modules are untouched."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def estimate_mfu(self, fwdbwd_per_iter, dt, flops_promised=65e12):
        """Estimate model FLOPs utilisation as a fraction of ``flops_promised``.

        FLOPs per token are counted as ``6*N + 12*L*H*Q*T`` (the estimate from
        the PaLM paper, Appendix B). Sequences are assumed to be full
        ``block_size`` length.

        Args:
            fwdbwd_per_iter: forward+backward passes per optimiser iteration.
            dt: wall-clock seconds per iteration.
            flops_promised: peak FLOP/s of the target hardware (default 65e12).

        Returns:
            Achieved FLOP/s divided by ``flops_promised``.
        """
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        return flops_achieved / flops_promised

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: integer token ids of shape ``(b, t)`` with ``t <= block_size``.
            targets: optional ids of the same shape; positions equal to -1 are
                ignored by the loss.

        Returns:
            ``(logits, loss)``.
            With ``targets``: logits of shape ``(b, t, vocab_size)`` and the mean
            cross-entropy.
            Without ``targets``: logits for the last position only, shape
            ``(b, 1, vocab_size)``, and ``None``.
        """
        _, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
        x = self.transformer.wte(idx)
        freqs_cis = self.freqs_cis[:t]
        for block in self.transformer.h:
            x = block(x, freqs_cis)
        x = self.transformer.ln_f(x)
        # The output projection is tied to the factorised input embedding; build it once.
        vocab_matrix = self.transformer.wte.get_dynamic_vocab_matrix()
        if targets is None:
            # Inference: only the last position is needed; [-1] keeps the time axis.
            logits = F.linear(x[:, [-1], :], vocab_matrix)
            return logits, None
        logits = F.linear(x, vocab_matrix)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build AdamW that applies weight decay only to parameters with ``dim() >= 2``.

        1-D parameters (e.g. RMSNorm gains) get ``weight_decay=0``. The fused
        kernel is requested only on CUDA and only when this PyTorch's AdamW
        accepts a ``fused`` argument.
        """
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        # Pass `fused` only when supported: older AdamW signatures reject the
        # keyword entirely, even as fused=False.
        use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device_type == 'cuda'
        extra_args = {'fused': True} if use_fused else {}
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: ``(b, t)`` tensor of conditioning token ids.
            max_new_tokens: number of tokens to generate.
            temperature: logits are divided by this before sampling. ``0``
                selects greedy (argmax) decoding instead of sampling.
            top_k: if given, sample only among the ``top_k`` highest logits.

        Returns:
            ``idx`` extended to shape ``(b, t + max_new_tokens)``.
        """
        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # forward rejects inputs longer than block_size, so crop the context.
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                # Greedy decoding; dividing by zero would yield inf/NaN probabilities.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
|