| | """ |
| | Full definition of a GPT Language Model, all of it in this single file. |
| | References: |
| | 1) nanoGPT by Karpathy: |
| | https://github.com/karpathy/nanoGPT/tree/eba36e84649f3c6d840a93092cb779a260544d08 |
| | 2) the official GPT-2 TensorFlow implementation released by OpenAI: |
| | https://github.com/openai/gpt-2/blob/master/src/model.py |
| | 3) huggingface/transformers PyTorch implementation: |
| | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py |
| | """ |
| |
|
import math
import inspect

import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from huggingface_hub import PyTorchModelHubMixin
from types import SimpleNamespace

from .moe import (
    MaskedMoE,
    TimeDependantMoE,
    MoE,
)

from .aux_losses import (
    entropy_reg,
    load_balancing_loss,
    router_z_loss,
)


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
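        # key, query and value projections for all heads, in one batched matmul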
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
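        # output projection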
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
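        # regularization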
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
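        # use PyTorch's fused scaled_dot_product_attention when available (>= 2.0)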
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
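        # causal mask to ensure attention is only applied to the left in the
        # input sequence; only consulted by the slow (non-flash) path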
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(config.sequence_length, config.sequence_length)
            ).view(1, 1, config.sequence_length, config.sequence_length),
        )

    def forward(self, x):
        # accept unbatched (T, C) input by adding a batch dimension
        if x.ndim != 3:
            x = x.unsqueeze(0)
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = x.size()

        # compute q, k, v for all heads and move the head dim forward to be a batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.flash:
            # efficient attention using fused kernels; apply dropout only in training
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.dropout if self.training else 0, is_causal=True
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        # the empty dict keeps the interface aligned with the MoE modules,
        # which return (output, router_logits_and_experts)
        return y, {}


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # hidden width: the standard 4x expansion, scaled by mlp_dim_exp_factor
        self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)

        self.c_fc = nn.Linear(
            config.n_embd, self.dim_exp_factor * config.n_embd, bias=config.bias
        )
        self.c_proj = nn.Linear(
            self.dim_exp_factor * config.n_embd, config.n_embd, bias=config.bias
        )
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        # the empty dict mirrors the (output, router_info) interface of the MoE modules
        return x, {}


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.moe_config = config.moe_routing
        self.shared_attention = config.shared_attention
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        if not config.moe and not config.shared_attention:
            raise ValueError(
                "If not using MoE, shared attention must be set to True"
            )

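        # with shared attention, one dense attention module serves all experts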
        if self.shared_attention:
            self.attn = CausalSelfAttention(config)

        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
                if not self.shared_attention:
                    self.attn = MoE(config, CausalSelfAttention)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
                if not self.shared_attention:
                    self.attn = TimeDependantMoE(config, CausalSelfAttention)
            else:
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        # "masked" (time-dependent) routing additionally conditions on `date`;
        # extra args belong to the attention/MLP modules, not the layer norms
        if self.moe_config == "masked":
            if self.shared_attention:
                attn_output, attn_logits_and_experts = self.attn(
                    self.ln_1(x), *args, **kwargs
                )
            else:
                attn_output, attn_logits_and_experts = self.attn(
                    self.ln_1(x), date, *args, **kwargs
                )
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x), date, *args, **kwargs)
        else:
            attn_output, attn_logits_and_experts = self.attn(
                self.ln_1(x), *args, **kwargs
            )
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x), *args, **kwargs)
        x = x + x_
        return x, mlp_logits_and_experts, attn_logits_and_experts


class GPTBase(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        if isinstance(config, dict):
            # the HF hub mixin hands the config back as a plain dict
            config = SimpleNamespace(**config)

        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        self.tokenizer = tiktoken.get_encoding("gpt2")

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the token embedding and the output projection share weights
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply a scaled init to the residual projections, following GPT-2
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # normalize the router weights to sum to one along `dim`, then
                # rescale to restore the original standard deviation
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

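    # A minimal config sketch (assumption: field names inferred from their
    # usages in this file; the values shown are purely illustrative):
    #
    #     config = SimpleNamespace(
    #         vocab_size=50304, sequence_length=512, n_layer=12, n_head=12,
    #         n_embd=768, dropout=0.1, bias=False, mlp_dim_exp_factor=1.0,
    #         moe=True, moe_routing="standard_gating", shared_attention=True,
    #         moe_router_loss="load_balancing_z_loss",
    #         moe_aux_loss_factor=0.01, moe_z_loss_factor=0.001,
    #     )
    #     model = GPTBase(config)
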
    def get_router_losses(self, logits, selected_experts, eval=False):
        # at eval time, return all router losses for logging
        if eval:
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        return {}

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, date, targets=None, get_logits=False, moe=False):
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings, (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings, (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # collect router logits and selected experts from every block
        mlp_router_logits = []
        attn_router_logits = []
        mlp_experts = []
        attn_experts = []

        for block in self.transformer.h:
            x, mlp_logits_and_experts, attn_logits_and_experts = block(x, date)
            if len(mlp_logits_and_experts) > 0:
                mlp_router_logits.append(mlp_logits_and_experts["router_logits"])
                mlp_experts.append(mlp_logits_and_experts["selected_experts"])
            if len(attn_logits_and_experts) > 0:
                attn_router_logits.append(attn_logits_and_experts["router_logits"])
                attn_experts.append(attn_logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        # per-loss-type sums of the auxiliary router losses across layers
        aux_losses_mlp = {}
        aux_losses_attn = {}

        if targets is not None:
            # if given targets, also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
            # log the pure cross-entropy loss, before adding the auxiliary terms
            loss_to_log = loss.item()
            if moe and self.config.moe_routing in ("standard_gating", "masked"):
                for logit, expert_choice in zip(mlp_router_logits, mlp_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_mlp[k] = aux_losses_mlp.get(k, 0.0) + v
                        if self.training:
                            # each term is weighted by its config factor,
                            # e.g. "moe_aux_loss" by config.moe_aux_loss_factor
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
                for logit, expert_choice in zip(attn_router_logits, attn_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_attn[k] = aux_losses_attn.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # no targets: compute logits only, no loss
            logits = self.lm_head(x)
            loss = None
            loss_to_log = None
        logits = logits if get_logits else None
        # stack per-layer router logits into a single (n_layer, ...) tensor
        mlp_router_logits = (
            torch.stack(mlp_router_logits, dim=0) if len(mlp_router_logits) > 0 else None
        )
        attn_router_logits = (
            torch.stack(attn_router_logits, dim=0) if len(attn_router_logits) > 0 else None
        )
        return {
            "loss_to_log": loss_to_log,
            "logits": logits,
            "loss": loss,
            "aux_losses_mlp": aux_losses_mlp,
            "aux_losses_attn": aux_losses_attn,
            "mlp_router_logits": mlp_router_logits,
            "attn_router_logits": attn_router_logits,
        }

    def crop_sequence_length(self, sequence_length):
        # model surgery to decrease the sequence length if necessary,
        # e.g. when loading a checkpoint trained with a longer context
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            # guard: MoE attention wrappers do not expose the causal-mask buffer
            if hasattr(block.attn, "bias"):
                block.attn.bias = block.attn.bias[
                    :, :, :sequence_length, :sequence_length
                ]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        # not implemented; note that this stub also shadows
        # PyTorchModelHubMixin.from_pretrained
        pass

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the parameter group specs for the PyTorch optimizer.
        """

        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # because named_modules and named_parameters are recursive,
                # we will encounter the same tensors many times, which is fine
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelisted modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklisted modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: because of weight tying, "lm_head.weight" and
        # "transformer.wte.weight" point to the same tensor; keep it only in
        # the no_decay bucket (via the embedding) and drop the decayed copy
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # return the parameter group specs: decayed and non-decayed
        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]

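    # Usage sketch (assumption: the training loop resolves the returned
    # parameter *names* to tensors before building the optimizer):
    #
    #     groups = model.get_parameter_group_specs()
    #     params = dict(model.named_parameters())
    #     for g in groups:
    #         g["params"] = [params[name] for name in g["params"]]
    #     optimizer = torch.optim.AdamW(groups, lr=3e-4, weight_decay=0.1)
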
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, date=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context grows too long, crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # forward the model to get the logits; `date` is passed through to
            # the blocks since time-dependent routing requires it
            logits = self(idx_cond, date, get_logits=True)["logits"]
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample the next token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append the sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_from_string(
        self, in_str, max_new_tokens, temperature=1.0, top_k=None, date=None
    ):
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, temperature, top_k, date=date)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)
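
# Usage sketch (illustrative; assumes a trained model, and for "masked"
# routing a `date` value in whatever format TimeDependantMoE expects):
#
#     model = GPTBase(config).eval()
#     print(model.generate_from_string("Hello", max_new_tokens=20, top_k=40))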