File size: 7,507 Bytes
3b97420 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 | from __future__ import annotations
import inspect
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import functional as F
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model.

    Only ``vocab_size`` must be supplied; all other fields have defaults.
    Instances are mutable: ``GPT.crop_block_size`` rewrites ``block_size``.
    """

    # Number of distinct token ids (rows of the token embedding / width of the head).
    vocab_size: int
    # Maximum sequence length; also the number of rows in the position embedding.
    block_size: int = 512
    # Number of transformer blocks stacked in the model.
    n_layer: int = 12
    # Attention heads per block; n_embd must be divisible by this.
    n_head: int = 12
    # Width of the residual stream / embeddings.
    n_embd: int = 768
    # Dropout probability shared by embeddings, attention and MLP layers.
    dropout: float = 0.0
    # Whether Linear and LayerNorm layers carry a bias term.
    bias: bool = False
    # Recompute block activations during backward (only while training).
    gradient_checkpointing: bool = False
class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension with an optional bias.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when true, a learnable shift is added after scaling; when false,
            ``self.bias`` is ``None`` and no shift is applied.
        eps: value added to the variance for numerical stability
            (defaults to the previously hard-coded ``1e-5``).
    """

    def __init__(self, ndim: int, bias: bool, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # None (not a zero tensor) when disabled: F.layer_norm then skips the shift.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` over its last dimension, then scale (and shift)."""
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on F.scaled_dot_product_attention.

    Raises:
        ValueError: if ``config.n_head`` is not positive or does not divide
            ``config.n_embd``.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Explicit exception instead of ``assert`` so the check survives ``python -O``
        # (otherwise a bad config only surfaces later as an opaque view() error).
        if config.n_head <= 0 or config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by a positive "
                f"n_head ({config.n_head})"
            )
        # A single projection yields queries, keys and values concatenated on dim 2.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Dropout on attention weights happens inside SDPA, so only the probability is kept.
        self.attn_dropout = config.dropout
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Not read by forward() (it uses attn_dropout); kept for attribute compatibility.
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend each position to itself and earlier positions only.

        ``x`` is unpacked as (batch, seq_len, channels); the output has the same shape.
        """
        batch, seq_len, channels = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        head_dim = channels // self.n_head
        # (batch, seq_len, channels) -> (batch, n_head, seq_len, head_dim)
        q = q.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
        k = k.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            # Attention dropout only while training.
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )
        # Merge heads back: (batch, n_head, seq_len, head_dim) -> (batch, seq_len, channels)
        y = y.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.resid_dropout(self.c_proj(y))
class MLP(nn.Module):
    """Feed-forward sublayer: widen 4x, tanh-approximated GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # Registration order (c_fc, gelu, c_proj, dropout) fixes parameter order.
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform position-wise; last dim stays n_embd."""
        expanded = self.c_fc(x)
        activated = self.gelu(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual branch."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Keep this order: it determines parameter order in named_parameters()/state_dict.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add normalized attention output, then normalized MLP output, to ``x``."""
        attn_out = self.attn(self.ln_1(x))
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln_2(hidden))
        return hidden + mlp_out
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned position embeddings feed a stack of
    ``Block`` layers, a final ``LayerNorm`` and a bias-free linear head whose
    weight is tied to the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "wpe": nn.Embedding(config.block_size, config.n_embd),
                "drop": nn.Dropout(config.dropout),
                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                "ln_f": LayerNorm(config.n_embd, bias=config.bias),
            }
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embedding and output head share a single tensor.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # Residual output projections (attention and MLP ``c_proj``) get a smaller
        # std, scaled by depth: each layer adds two residual branches.
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module: nn.Module) -> None:
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self, idx: torch.Tensor, targets: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the model on token ids.

        Args:
            idx: integer token ids, unpacked as (batch, seq_len).
            targets: optional target ids with the same number of elements as
                ``idx``; entries equal to -100 are ignored by the loss. Need not
                be contiguous.

        Returns:
            ``(logits, loss)``. Without targets, logits cover only the last
            position, shape (batch, 1, vocab_size), and loss is ``None``. With
            targets, logits cover every position and loss is the cross-entropy
            (PyTorch's default mean reduction over non-ignored targets).

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.block_size``.
        """
        _, seq_len = idx.size()
        if seq_len > self.config.block_size:
            raise ValueError(f"Sequence length {seq_len} exceeds block size {self.config.block_size}")
        pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
        for block in self.transformer.h:
            if self.config.gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)
        if targets is None:
            # Inference shortcut: only the last position's logits are needed.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        else:
            logits = self.lm_head(x)
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 0.8,
        top_k: int | None = 50,
        eos_id: int | None = None,
    ) -> torch.Tensor:
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (batch, seq_len) conditioning token ids. Only the last
                ``config.block_size`` tokens are fed to the model at each step.
            max_new_tokens: maximum number of tokens to append.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: if a positive int, sample only among the ``top_k`` highest
                logits (ties at the cutoff are kept); ``None`` or ``<= 0``
                disables the filter.
            eos_id: if given, a row that samples this id is finished and keeps
                receiving ``eos_id`` on later steps; sampling stops early once
                every row in the batch is finished.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.

        Note:
            This does not change the module's train/eval mode; call ``eval()``
            first if dropout should be disabled during sampling.
        """
        # Per-row "has emitted eos" flags; replaces idx_next.item(), which only
        # worked for batch size 1.
        finished = None
        if eos_id is not None:
            finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            block_size = self.config.block_size
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            if finished is not None:
                # Rows that already finished are padded with eos_id.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_id)
                finished |= idx_next.squeeze(1) == eos_id
            idx = torch.cat((idx, idx_next), dim=1)
            if finished is not None and bool(finished.all()):
                break
        return idx

    def crop_block_size(self, block_size: int) -> None:
        """Shrink the maximum context length to ``block_size``.

        Keeps the first ``block_size`` rows of the position embedding and sets
        ``self.config.block_size`` in place (the config object is mutated).

        Raises:
            ValueError: if ``block_size`` is not in ``1..config.block_size``.
        """
        if not 0 < block_size <= self.config.block_size:
            raise ValueError(
                f"block_size must be in 1..{self.config.block_size}, got {block_size}"
            )
        wpe = self.transformer.wpe
        # Copy the kept rows so the new parameter owns its storage rather than
        # being a view that keeps the whole original table alive.
        with torch.no_grad():
            cropped = wpe.weight[:block_size].clone()
        wpe.weight = nn.Parameter(cropped, requires_grad=wpe.weight.requires_grad)
        wpe.num_embeddings = block_size
        self.config.block_size = block_size

    def configure_optimizers(
        self, weight_decay: float, learning_rate: float, betas: tuple[float, float], device_type: str
    ) -> torch.optim.Optimizer:
        """Build AdamW with weight decay applied only to parameters with ``dim() >= 2``.

        Matrices (Linear weights, embeddings) are decayed; vectors (biases,
        LayerNorm weights) use ``weight_decay=0``. The fused kernel is used when
        this PyTorch's AdamW supports it and ``device_type == "cuda"``.
        """
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = {"fused": True} if use_fused else {}
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

    def num_parameters(self, non_embedding: bool = False) -> int:
        """Return the number of parameters; shared (tied) tensors count once.

        Args:
            non_embedding: when true, exclude the position-embedding table. The
                token embedding stays counted because it doubles as the output
                head weight.
        """
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            total -= self.transformer.wpe.weight.numel()
        return total
|